Initial commit: Zen-Marketing MCP Server v0.1.0

- Core architecture from zen-mcp-server
- OpenRouter and Gemini provider configuration
- Content variant generator tool (first marketing tool)
- Chat tool for marketing strategy
- Version and model listing tools
- Configuration system with .env support
- Logging infrastructure
- Ready for Claude Desktop integration
Ben 2025-11-07 11:35:17 -04:00
commit 371806488d
56 changed files with 16196 additions and 0 deletions

.env.example (new file, 40 lines)
# Zen-Marketing MCP Server Configuration
# API Keys
# Required: At least one API key must be configured
OPENROUTER_API_KEY=your-openrouter-api-key-here
GEMINI_API_KEY=your-gemini-api-key-here
# Model Configuration
# DEFAULT_MODEL: Primary model for analytical and strategic work
# Options: google/gemini-2.5-pro-latest, minimax/minimax-m2, or any OpenRouter model
DEFAULT_MODEL=google/gemini-2.5-pro-latest
# FAST_MODEL: Model for quick generation (subject lines, variations)
FAST_MODEL=google/gemini-2.5-flash-preview-09-2025
# CREATIVE_MODEL: Model for creative content generation
CREATIVE_MODEL=minimax/minimax-m2
# Web Search
# Enable web search for fact-checking and current information
ENABLE_WEB_SEARCH=true
# Tool Configuration
# Comma-separated list of tools to disable
# Available tools: contentvariant, platformadapt, subjectlines, styleguide,
# seooptimize, guestedit, linkstrategy, factcheck,
# voiceanalysis, campaignmap, chat, thinkdeep, planner
DISABLED_TOOLS=
# Logging
# Options: DEBUG, INFO, WARNING, ERROR
LOG_LEVEL=INFO
# Language/Locale
# Leave empty for English, or specify locale (e.g., fr-FR, es-ES, de-DE)
LOCALE=
# Thinking Mode for Deep Analysis
# Options: minimal, low, medium, high, max
DEFAULT_THINKING_MODE_THINKDEEP=high

.gitignore (new file, vendored, 190 lines)
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
Pipfile.lock
# poetry
poetry.lock
# pdm
.pdm.toml
.pdm-python
pdm.lock
# PEP 582
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.env~
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
.idea/
# VS Code
.vscode/
# macOS
.DS_Store
# API Keys and secrets
*.key
*.pem
.env.local
.env.*.local
# Test outputs
test_output/
*.test.log
.coverage
htmlcov/
coverage.xml
.pytest_cache/
# Test simulation artifacts (dynamically created during testing)
test_simulation_files/.claude/
# Temporary test directories
test-setup/
# Scratch feature documentation files
FEATURE_*.md
# Temporary files
/tmp/
# Local user instructions
CLAUDE.local.md
# Claude Code personal settings
.claude/settings.local.json
# Standalone mode files
.zen_venv/
.docker_cleaned
logs/
*.backup
/.desktop_configured
/worktrees/
test_simulation_files/
.mcp.json

CLAUDE.md (new file, 566 lines)
# Claude Development Guide for Zen-Marketing MCP Server
This file contains essential commands and workflows for developing the Zen-Marketing MCP Server - a specialized marketing-focused fork of Zen MCP Server.
## Project Context
**What is Zen-Marketing?**
A Claude Desktop MCP server providing AI-powered marketing tools focused on:
- Content variation generation for A/B testing
- Cross-platform content adaptation
- Writing style enforcement
- SEO optimization for WordPress
- Guest content editing with voice preservation
- Technical fact verification
- Internal linking strategy
- Multi-channel campaign planning
**Target User:** Solo marketing professionals managing technical B2B content, particularly in industries like HVAC, SaaS, and technical education.
**Key Difference from Zen Code:** This is for marketing/content work, not software development. Tools generate content variations, enforce writing styles, and optimize for platforms like LinkedIn, newsletters, and WordPress - not code review or debugging.
## Quick Reference Commands
### Initial Setup
```bash
# Navigate to project directory
cd ~/mcp/zen-marketing
# Copy core files from zen-mcp-server (if starting fresh)
# We'll do this in the new session
# Create virtual environment
python3 -m venv .venv
source .venv/bin/activate
# Install dependencies (once requirements.txt is created)
pip install -r requirements.txt
# Create .env file
cp .env.example .env
# Edit .env with your API keys
```
### Development Workflow
```bash
# Activate environment
source .venv/bin/activate
# Run code quality checks (once implemented)
./code_quality_checks.sh
# Run server locally for testing
python server.py
# View logs
tail -f logs/mcp_server.log
# Run tests
python -m pytest tests/ -v
```
### Claude Desktop Configuration
Add to `~/.claude.json`:
```json
{
  "mcpServers": {
    "zen-marketing": {
      "command": "/home/ben/mcp/zen-marketing/.venv/bin/python",
      "args": ["/home/ben/mcp/zen-marketing/server.py"],
      "env": {
        "OPENROUTER_API_KEY": "your-openrouter-key",
        "GEMINI_API_KEY": "your-gemini-key",
        "DEFAULT_MODEL": "gemini-2.5-pro",
        "FAST_MODEL": "gemini-flash",
        "CREATIVE_MODEL": "minimax-m2",
        "ENABLE_WEB_SEARCH": "true",
        "DISABLED_TOOLS": "",
        "LOG_LEVEL": "INFO"
      }
    }
  }
}
```
**After modifying config:** Restart Claude Desktop for changes to take effect.
## Tool Development Guidelines
### Tool Categories
**Simple Tools** (single-shot, fast response):
- Inherit from `SimpleTool` base class
- Focus on speed and iteration
- Examples: `contentvariant`, `platformadapt`, `subjectlines`, `factcheck`
- Use fast models (gemini-flash) when possible
**Workflow Tools** (multi-step processes):
- Inherit from `WorkflowTool` base class
- Systematic step-by-step workflows
- Track progress, confidence, findings
- Examples: `styleguide`, `seooptimize`, `guestedit`, `linkstrategy`
### Temperature Guidelines for Marketing Tools
- **High (0.7-0.8)**: Content variation, creative adaptation
- **Medium (0.5-0.6)**: Balanced tasks, campaign planning
- **Low (0.3-0.4)**: Analytical work, SEO optimization
- **Very Low (0.2)**: Fact-checking, technical verification
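These bands are easiest to apply consistently if they live as named constants in `config.py`. A minimal sketch; only `TEMPERATURE_BALANCED` is actually referenced elsewhere in this guide, so the other names and the exact values here are assumptions:

```python
# config.py (sketch) -- temperature bands for marketing tools.
# Only TEMPERATURE_BALANCED appears in the existing tool example;
# the remaining names and values are illustrative placeholders.
TEMPERATURE_CREATIVE = 0.75    # content variation, creative adaptation
TEMPERATURE_BALANCED = 0.55    # balanced tasks, campaign planning
TEMPERATURE_ANALYTICAL = 0.35  # analytical work, SEO optimization
TEMPERATURE_FACTUAL = 0.2      # fact-checking, technical verification
```

Keeping the values in one module means a tool's `get_default_temperature()` stays a one-line lookup.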
### Model Selection Strategy
**Gemini 2.5 Pro** (`gemini-2.5-pro`):
- Analytical and strategic work
- SEO optimization
- Guest editing
- Internal linking analysis
- Voice analysis
- Campaign planning
- Fact-checking
**Gemini Flash** (`gemini-flash`):
- Fast bulk generation
- Subject line creation
- Quick variations
- Cost-effective iterations
**Minimax M2** (`minimax-m2`):
- Creative content generation
- Platform adaptation
- Content repurposing
- Marketing copy variations
### System Prompt Best Practices
Marketing tool prompts should:
1. **Specify output format clearly** (JSON, markdown, numbered list)
2. **Include platform constraints** (character limits, formatting rules)
3. **Emphasize preservation** (voice, expertise, technical accuracy)
4. **Request rationale** (why certain variations work, what to test)
5. **Avoid code terminology** (use "content" not "implementation")
Example prompt structure:
```python
CONTENTVARIANT_PROMPT = """
You are a marketing content strategist specializing in A/B testing and variation generation.
TASK: Generate multiple variations of marketing content for testing different approaches.
OUTPUT FORMAT:
Return variations as numbered list, each with:
1. The variation text
2. The testing angle (what makes it different)
3. Predicted audience response
CONSTRAINTS:
- Maintain core message across variations
- Respect platform character limits if specified
- Preserve brand voice characteristics
- Generate genuinely different approaches, not just word swaps
VARIATION TYPES:
- Hook variations: Different opening angles
- Length variations: Short, medium, long
- Tone variations: Professional, conversational, urgent
- Structure variations: Question, statement, story
- CTA variations: Different calls-to-action
"""
```
## Implementation Phases
### Phase 1: Foundation ✓ (You Are Here)
- [x] Create project directory
- [x] Write implementation plan (PLAN.md)
- [x] Create development guide (CLAUDE.md)
- [ ] Copy core architecture from zen-mcp-server
- [ ] Configure minimax provider
- [ ] Remove code-specific tools
- [ ] Test basic chat functionality
### Phase 2: Simple Tools (Priority: High)
Implementation order based on real-world usage:
1. **`contentvariant`** - Most frequently used (subject lines, social posts)
2. **`subjectlines`** - Specific workflow mentioned in project memories
3. **`platformadapt`** - Multi-channel content distribution
4. **`factcheck`** - Technical accuracy verification
### Phase 3: Workflow Tools (Priority: Medium)
5. **`styleguide`** - Writing rule enforcement (no em-dashes, etc.)
6. **`seooptimize`** - WordPress SEO optimization
7. **`guestedit`** - Guest content editing workflow
8. **`linkstrategy`** - Internal linking and cross-platform integration
### Phase 4: Advanced Features (Priority: Lower)
9. **`voiceanalysis`** - Voice extraction and consistency checking
10. **`campaignmap`** - Multi-touch campaign planning
## Tool Implementation Checklist
For each new tool:
**Code Files:**
- [ ] Create tool file in `tools/` (e.g., `tools/contentvariant.py`)
- [ ] Create system prompt in `systemprompts/` (e.g., `systemprompts/contentvariant_prompt.py`)
- [ ] Create test file in `tests/` (e.g., `tests/test_contentvariant.py`)
- [ ] Register tool in `server.py`
**Tool Class Requirements:**
- [ ] Inherit from `SimpleTool` or `WorkflowTool`
- [ ] Implement `get_name()` - tool name
- [ ] Implement `get_description()` - what it does
- [ ] Implement `get_system_prompt()` - behavior instructions
- [ ] Implement `get_default_temperature()` - creativity level
- [ ] Implement `get_model_category()` - FAST_RESPONSE or DEEP_THINKING
- [ ] Implement `get_request_model()` - Pydantic request schema
- [ ] Implement `get_input_schema()` - MCP tool schema
- [ ] Implement request/response formatting hooks
**Testing:**
- [ ] Unit tests for request validation
- [ ] Unit tests for response formatting
- [ ] Integration test with real model (optional)
- [ ] Add to quality checks script
**Documentation:**
- [ ] Add tool to README.md
- [ ] Create examples in docs/tools/
- [ ] Update PLAN.md progress
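The "register tool in `server.py`" step usually reduces to adding the tool instance to a name-keyed registry. A sketch under assumptions (the `TOOLS` dict and `register_tool` helper are illustrative, and `ContentVariantTool` is stubbed inline so the pattern is self-contained):

```python
# server.py (sketch) -- registering a new tool by its MCP-visible name.
# ContentVariantTool is a stub here; the real class lives in
# tools/contentvariant.py.
class ContentVariantTool:
    def get_name(self) -> str:
        return "contentvariant"

TOOLS: dict = {}

def register_tool(tool) -> None:
    """Key each tool instance by the name Claude Desktop will see."""
    TOOLS[tool.get_name()] = tool

register_tool(ContentVariantTool())
```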
## Common Development Tasks
### Adding a New Simple Tool
```python
# tools/mynewtool.py
from typing import Optional
from pydantic import Field
from tools.shared.base_models import ToolRequest
from .simple.base import SimpleTool
from systemprompts import MYNEWTOOL_PROMPT
from config import TEMPERATURE_BALANCED
class MyNewToolRequest(ToolRequest):
    """Request model for MyNewTool"""

    prompt: str = Field(..., description="What you want to accomplish")
    files: Optional[list[str]] = Field(default_factory=list)


class MyNewTool(SimpleTool):
    def get_name(self) -> str:
        return "mynewtool"

    def get_description(self) -> str:
        return "Brief description of what this tool does"

    def get_system_prompt(self) -> str:
        return MYNEWTOOL_PROMPT

    def get_default_temperature(self) -> float:
        return TEMPERATURE_BALANCED

    def get_model_category(self) -> "ToolModelCategory":
        from tools.models import ToolModelCategory

        return ToolModelCategory.FAST_RESPONSE

    def get_request_model(self):
        return MyNewToolRequest
```
### Adding a New Workflow Tool
```python
# tools/mynewworkflow.py
from typing import Optional
from pydantic import Field
from tools.shared.base_models import WorkflowRequest
from .workflow.base import WorkflowTool
from systemprompts import MYNEWWORKFLOW_PROMPT
class MyNewWorkflowRequest(WorkflowRequest):
    """Request model for workflow tool"""

    step: str = Field(description="Current step content")
    step_number: int = Field(ge=1)
    total_steps: int = Field(ge=1)
    next_step_required: bool
    findings: str = Field(description="What was discovered")
    # Add workflow-specific fields


class MyNewWorkflow(WorkflowTool):
    """Implementation mirrors a Simple Tool, plus workflow-specific
    step tracking and findings handling."""

    ...  # body elided in this sketch
```
### Testing a Tool Manually
```bash
# Start server with debug logging
LOG_LEVEL=DEBUG python server.py
# In another terminal, watch logs
tail -f logs/mcp_server.log | grep -E "(TOOL_CALL|ERROR|MyNewTool)"
# In Claude Desktop, test:
# "Use zen-marketing to generate 10 subject lines about HVAC maintenance"
```
## Project Structure
```
zen-marketing/
├── server.py # Main MCP server entry point
├── config.py # Configuration constants
├── PLAN.md                # Implementation plan
├── CLAUDE.md              # Development guide (this file)
├── README.md # User-facing documentation
├── requirements.txt # Python dependencies
├── .env.example # Environment variable template
├── .env # Local config (gitignored)
├── run-server.sh # Setup and run script
├── code_quality_checks.sh # Linting and testing
├── tools/ # Tool implementations
│ ├── __init__.py
│ ├── contentvariant.py # Bulk variation generator
│ ├── platformadapt.py # Cross-platform adapter
│ ├── subjectlines.py # Email subject line generator
│ ├── styleguide.py # Writing style enforcer
│ ├── seooptimize.py # SEO optimizer
│ ├── guestedit.py # Guest content editor
│ ├── linkstrategy.py # Internal linking strategist
│ ├── factcheck.py # Technical fact checker
│ ├── voiceanalysis.py # Voice extractor/validator
│ ├── campaignmap.py # Campaign planner
│ ├── chat.py # General chat (from zen)
│ ├── thinkdeep.py # Deep thinking (from zen)
│ ├── planner.py # Planning (from zen)
│ ├── models.py # Shared models
│ ├── simple/ # Simple tool base classes
│ │ └── base.py
│ ├── workflow/ # Workflow tool base classes
│ │ └── base.py
│ └── shared/ # Shared utilities
│ └── base_models.py
├── providers/ # AI provider implementations
│ ├── __init__.py
│ ├── base.py # Base provider interface
│ ├── gemini.py # Google Gemini
│ ├── minimax.py # Minimax (NEW)
│ ├── openrouter.py # OpenRouter fallback
│ ├── registry.py # Provider registry
│ └── shared/
├── systemprompts/ # System prompts for tools
│ ├── __init__.py
│ ├── contentvariant_prompt.py
│ ├── platformadapt_prompt.py
│ ├── subjectlines_prompt.py
│ ├── styleguide_prompt.py
│ ├── seooptimize_prompt.py
│ ├── guestedit_prompt.py
│ ├── linkstrategy_prompt.py
│ ├── factcheck_prompt.py
│ ├── voiceanalysis_prompt.py
│ ├── campaignmap_prompt.py
│ ├── chat_prompt.py # From zen
│ ├── thinkdeep_prompt.py # From zen
│ └── planner_prompt.py # From zen
├── utils/ # Utility functions
│ ├── conversation_memory.py # Conversation continuity
│ ├── file_utils.py # File handling
│ └── web_search.py # Web search integration
├── tests/ # Test suite
│ ├── __init__.py
│ ├── test_contentvariant.py
│ ├── test_platformadapt.py
│ ├── test_subjectlines.py
│ └── ...
├── logs/ # Log files (gitignored)
│ ├── mcp_server.log
│ └── mcp_activity.log
└── docs/ # Documentation
├── getting-started.md
├── tools/
│ ├── contentvariant.md
│ ├── platformadapt.md
│ └── ...
└── examples/
└── marketing-workflows.md
```
## Key Concepts from Zen Architecture
### Conversation Continuity
Every tool supports `continuation_id` to maintain context across interactions:
```python
# First call
result1 = await tool.execute({
    "prompt": "Analyze this brand voice",
    "files": ["brand_samples/post1.txt", "brand_samples/post2.txt"],
})
# Returns: continuation_id: "abc123"

# Follow-up call (remembers previous context)
result2 = await tool.execute({
    "prompt": "Now check if this new draft matches the voice",
    "files": ["new_draft.txt"],
    "continuation_id": "abc123",  # Preserves context
})
```
### File Handling
Tools automatically:
- Expand directories to individual files
- Deduplicate file lists
- Handle absolute paths
- Process images (screenshots, brand assets)
### Web Search Integration
Tools can request Claude to perform web searches:
```python
# In system prompt:
"If you need current information about [topic], request a web search from Claude."
# Claude will then use WebSearch tool and provide results
```
### Multi-Model Orchestration
Tools specify model category, server selects best available:
- `FAST_RESPONSE` → gemini-flash or equivalent
- `DEEP_THINKING` → gemini-2.5-pro or equivalent
- User can override with `model` parameter
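The resolution order above (explicit `model` parameter, then category default) can be sketched as a small helper. The env var names follow `.env.example`; the function shape and the fallback table are assumptions, not the actual zen registry code:

```python
import os

# Sketch of category-to-model resolution. Env var names match
# .env.example; the defaults and function signature are illustrative.
CATEGORY_ENV = {
    "FAST_RESPONSE": ("FAST_MODEL", "google/gemini-2.5-flash-preview-09-2025"),
    "DEEP_THINKING": ("DEFAULT_MODEL", "google/gemini-2.5-pro-latest"),
}

def resolve_model(category, override=None):
    """An explicit `model` parameter wins; otherwise map the tool's
    category to the configured model, falling back to a default."""
    if override:
        return override
    env_var, default = CATEGORY_ENV[category]
    return os.environ.get(env_var) or default
```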
## Debugging Common Issues
### Tool Not Appearing in Claude Desktop
1. Check `server.py` registers the tool
2. Verify tool is not in `DISABLED_TOOLS` env var
3. Restart Claude Desktop after config changes
4. Check logs: `tail -f logs/mcp_server.log`
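For step 2, `DISABLED_TOOLS` is a comma-separated env var (see `.env.example`), so a tool vanishes if its name appears anywhere in that list. A hypothetical helper showing the likely parsing (the function name and signature are assumptions):

```python
import os

# Sketch of DISABLED_TOOLS filtering: names in the comma-separated env
# var are hidden before tools are exposed to Claude Desktop.
def enabled(tool_name, raw=None):
    """True unless tool_name appears in the DISABLED_TOOLS list.
    `raw` overrides the env var, mainly for testing."""
    if raw is None:
        raw = os.environ.get("DISABLED_TOOLS", "")
    disabled = {name.strip() for name in raw.split(",") if name.strip()}
    return tool_name not in disabled
```

Note that stray whitespace around commas is stripped, so `"factcheck, planner"` disables both tools.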
### Model Selection Issues
1. Verify API keys in `.env`
2. Check provider registration in `providers/registry.py`
3. Test with explicit model name: `"model": "gemini-2.5-pro"`
4. Check logs for provider errors
### Response Formatting Issues
1. Validate system prompt specifies output format
2. Check response doesn't exceed token limits
3. Test with simpler input first
4. Review logs for truncation warnings
### Conversation Continuity Not Working
1. Verify `continuation_id` is being passed correctly
2. Check conversation hasn't expired (default 6 hours)
3. Validate conversation memory storage
4. Review logs: `grep "continuation_id" logs/mcp_server.log`
## Code Quality Standards
Before committing:
```bash
# Run all quality checks
./code_quality_checks.sh
# Manual checks:
ruff check . --fix # Linting
black . # Formatting
isort . # Import sorting
pytest tests/ -v # Run tests
```
## Marketing-Specific Considerations
### Character Limits by Platform
Tools should be aware of:
- **Twitter/Bluesky**: 280 characters
- **LinkedIn**: 3000 chars (1300 optimal)
- **Instagram**: 2200 characters
- **Facebook**: No hard limit (500 chars optimal)
- **Email subject**: 60 characters optimal
- **Email preview**: 90-100 characters
- **Meta description**: 156 characters
- **Page title**: 60 characters
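Tools can share this table as data rather than re-stating it in every prompt. A minimal sketch using the figures above ("optimal" values stand in where no hard cap exists; the dict name and helper are illustrative):

```python
# Limits from the table above. Facebook has no hard limit, so its
# "optimal" figure is used. Names and structure are assumptions.
PLATFORM_LIMITS = {
    "twitter": 280,
    "linkedin": 3000,
    "instagram": 2200,
    "facebook": 500,          # optimal; no hard limit
    "email_subject": 60,
    "email_preview": 100,
    "meta_description": 156,
    "page_title": 60,
}

def fits(text, platform):
    """True when the text is within the platform's limit."""
    return len(text) <= PLATFORM_LIMITS[platform]
```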
### Writing Style Rules from Project Memories
- No em-dashes (use periods or semicolons)
- No "This isn't X, it's Y" constructions
- Direct affirmative statements over negations
- Semantic variety in paragraph openings
- Concrete metrics over abstract claims
- Technical accuracy preserved
- Author voice maintained
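The first two rules are mechanical enough to lint with regexes; the rest (semantic variety, voice, accuracy) need model judgment. An illustrative sketch, with check names and patterns that are assumptions rather than the actual styleguide tool:

```python
import re

# Illustrative checks for two of the rules above. A real styleguide
# tool would delegate judgment calls to the model rather than regex.
STYLE_CHECKS = [
    ("em-dash", re.compile("\u2014")),
    ("not-X-but-Y construction",
     re.compile(r"\bisn'?t\b[^.]{0,60}?,\s*it'?s\b", re.IGNORECASE)),
]

def style_violations(text):
    """Return the names of any style rules the text breaks."""
    return [name for name, pattern in STYLE_CHECKS if pattern.search(text)]
```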
### Testing Angles for Variations
- Technical curiosity
- Contrarian/provocative
- Knowledge gap emphasis
- Urgency/timeliness
- Insider knowledge positioning
- Problem-solution framing
- Before-after transformation
- Social proof/credibility
- FOMO (fear of missing out)
- Educational value
## Next Session Goals
When you start the new session in `~/mcp/zen-marketing/`:
1. **Copy Core Files from Zen**
- Copy base architecture preserving git history
- Remove code-specific tools
- Update imports and references
2. **Configure Minimax Provider**
- Add minimax support to providers/
- Register in provider registry
- Test basic model calls
3. **Implement First Simple Tool**
- Start with `contentvariant` (highest priority)
- Create tool, system prompt, and tests
- Test end-to-end with Claude Desktop
4. **Validate Architecture**
- Ensure conversation continuity works
- Verify file handling
- Test web search integration
## Questions to Consider
Before implementing each tool:
1. What real-world workflow does this solve? (Reference project memories)
2. What's the minimum viable version?
3. What can go wrong? (Character limits, API errors, invalid input)
4. How will users test variations? (Output format)
5. Does it need web search? (Current info, fact-checking)
6. What's the right temperature? (Creative vs analytical)
7. Simple or workflow tool? (Single-shot vs multi-step)
## Resources
- **Zen MCP Server Repo**: Source for architecture and patterns
- **MCP Protocol Docs**: https://modelcontextprotocol.io
- **Claude Desktop Config**: `~/.claude.json`
- **Project Memories**: See PLAN.md for user workflow examples
- **Platform Best Practices**: Research current 2025 guidelines
---
**Ready to build?** Start the new session with:
```bash
cd ~/mcp/zen-marketing
# Then ask Claude to begin Phase 1 implementation
```

LICENSE (new file, 197 lines)
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship covered by this License,
whether in source or binary form, which is made available under the
License, as indicated by a copyright notice that is included in or
attached to the work. (The copyright notice requirement does not
apply to derivative works of the License holder.)
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based upon (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and derivative works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control
systems, and issue tracking systems that are managed by, or on behalf
of, the Licensor for the purpose of discussing and improving the Work,
but excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright notice to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Support. You may choose to offer, and to
charge a fee for, warranty, support, indemnity or other liability
obligations and/or rights consistent with this License. However,
in accepting such obligations, You may act only on Your own behalf
and on Your sole responsibility, not on behalf of any other
Contributor, and only if You agree to indemnify, defend, and hold
each Contributor harmless for any liability incurred by, or claims
asserted against, such Contributor by reason of your accepting any
such warranty or support.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2025 Beehive Innovations
https://github.com/BeehiveInnovations
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

PLAN.md (new file, 513 lines)
# Zen-Marketing MCP Server - Implementation Plan
## Project Overview
A specialized MCP server for Claude Desktop focused on marketing workflows, derived from the Zen MCP Server codebase. Designed for solo marketing professionals working in technical B2B content, particularly HVAC/technical education and software marketing.
**Target User Profile:** Marketing professional managing:
- Technical newsletters with guest contributors
- Multi-platform social media campaigns
- Educational content with SEO optimization
- A/B testing and content variation generation
- WordPress content management with internal linking
- Brand voice consistency across channels
## Core Principles
1. **Variation at Scale**: Generate 5-20 variations of content for A/B testing
2. **Platform Intelligence**: Understand character limits, tone, and best practices per platform
3. **Technical Accuracy**: Verify facts via web search before publishing
4. **Voice Preservation**: Maintain authentic voices (guest authors, brand persona)
5. **Cross-Platform Integration**: Connect content across blog, social, email, video
6. **Style Enforcement**: Apply specific writing guidelines automatically
## Planned Tools
### 1. `contentvariant` (Simple Tool)
**Purpose:** Rapid generation of multiple content variations for testing
**Real-world use case from project memory:**
- "Generate 15 email subject line variations testing different angles: technical curiosity, contrarian statements, FOMO, educational value"
- "Create 10 LinkedIn post variations of this announcement, mixing lengths and hooks"
**Key features:**
- Generate 5-20 variations in one call
- Specify variation dimensions (tone, length, hook, CTA placement)
- Fast model (gemini-flash) for speed
- Output formatted for easy copy-paste testing
**Parameters:**
- `content` (str): Base content to vary
- `variation_count` (int): Number of variations (default 10)
- `variation_types` (list): Types to vary - "tone", "length", "hook", "structure", "cta"
- `platform` (str): Target platform for context
- `constraints` (str): Character limits, style requirements
**Model:** gemini-flash (fast, cost-effective for bulk generation)
**Temperature:** 0.8 (creative variation while maintaining message)
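As a sketch, the parameter list above maps naturally onto a small request object with validation for the 5-20 range; the class and field names here are illustrative, not the tool's actual schema:

```python
from dataclasses import dataclass, field


@dataclass
class ContentVariantRequest:
    """Illustrative request shape for the contentvariant tool."""

    content: str
    variation_count: int = 10
    variation_types: list = field(default_factory=lambda: ["tone", "hook"])
    platform: str = ""
    constraints: str = ""

    def validate(self) -> None:
        # Keep requests inside the 5-20 range the tool is designed for
        if not 5 <= self.variation_count <= 20:
            raise ValueError("variation_count must be between 5 and 20")


req = ContentVariantRequest(content="Fix the board, not the budget.", variation_count=15)
req.validate()
```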
---
### 2. `platformadapt` (Simple Tool)
**Purpose:** Adapt single content piece across multiple social platforms
**Real-world use case from project memory:**
- "Take this blog post intro and create versions for LinkedIn (1300 chars), Twitter (280 chars), Instagram caption (2200 chars), Facebook (500 chars), and Bluesky (300 chars)"
**Key features:**
- Single source → multiple platform versions
- Respects character limits automatically
- Platform-specific best practices (hashtags, tone, formatting)
- Preserves core message across adaptations
**Parameters:**
- `source_content` (str): Original content
- `source_platform` (str): Where content originated
- `target_platforms` (list): ["linkedin", "twitter", "instagram", "facebook", "bluesky"]
- `preserve_urls` (bool): Keep links intact vs. adapt
- `brand_voice` (str): Voice guidelines to maintain
**Model:** minimax/minimax-m2 (creative adaptations)
**Temperature:** 0.7
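The automatic character-limit handling could look like the following sketch; the limits dict mirrors the platforms named above (LinkedIn uses the 1300-char optimal length from the use case rather than the hard cap), and the function name is hypothetical:

```python
# Character budgets per platform; commonly cited maximums, not an API contract.
PLATFORM_LIMITS = {
    "twitter": 280,
    "bluesky": 300,
    "linkedin": 1300,  # optimal post length, not the ~3000-char hard cap
    "instagram": 2200,
    "facebook": 500,
}


def fits_platform(text: str, platform: str) -> bool:
    """Return True when text is within the platform's character budget."""
    limit = PLATFORM_LIMITS.get(platform)
    if limit is None:
        raise KeyError(f"unknown platform: {platform}")
    return len(text) <= limit
```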
---
### 3. `subjectlines` (Simple Tool)
**Purpose:** Specialized tool for email subject line generation with angle testing
**Real-world use case from project memory:**
- "Generate subject lines testing: technical curiosity hook, contrarian/provocative angle, knowledge gap emphasis, urgency/timeliness, insider knowledge positioning"
**Key features:**
- Generates 15-25 subject lines by default
- Groups by psychological angle/hook
- Includes A/B testing rationale for each
- Character count validation (under 60 chars optimal)
- Emoji suggestions (optional)
**Parameters:**
- `email_topic` (str): Main topic/content
- `target_audience` (str): Who receives this
- `angles_to_test` (list): Psychological hooks to explore
- `include_emoji` (bool): Add emoji options
- `preview_text` (str): Optional - generate matching preview text
**Model:** gemini-flash (fast generation, creative but focused)
**Temperature:** 0.8
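The character-count validation step might be sketched like this; the 60-char threshold follows the "under 60 chars optimal" guideline above, and the helper name is illustrative:

```python
def check_subject_line(subject: str, optimal: int = 60) -> dict:
    """Flag subject lines that exceed the ~60-char inbox display budget."""
    return {
        "subject": subject,
        "chars": len(subject),
        "within_optimal": len(subject) <= optimal,
    }


lines = [
    "Why your PCB diagnosis is probably wrong",
    "The voltage regulation mistake almost every HVAC tech makes at least once on the job",
]
report = [check_subject_line(s) for s in lines]
```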
---
### 4. `styleguide` (Workflow Tool)
**Purpose:** Enforce specific writing guidelines and polish content
**Real-world use case from project memory:**
- "Check this content for: em-dashes (remove), 'This isn't X, it's Y' constructions (rewrite), semantic variety in paragraph structure, direct affirmative statements instead of negations"
**Key features:**
- Multi-step workflow: detection → flagging → rewriting → validation
- Tracks style violations with severity
- Provides before/after comparisons
- Custom rule definitions
**Workflow steps:**
1. Analyze content for style violations
2. Flag issues with line numbers and severity
3. Generate corrected version
4. Validate improvements
**Parameters:**
- `content` (str): Content to check
- `rules` (list): Style rules to enforce
- `rewrite_violations` (bool): Auto-fix vs. flag only
- `preserve_technical_terms` (bool): Don't change HVAC/technical terminology
**Common rules (from project memory):**
- No em-dashes (use periods or semicolons)
- No "This isn't X, it's Y" constructions
- Direct affirmative statements over negations
- Semantic variety in paragraph openings
- No clichéd structures
- Concrete metrics over abstract claims
**Model:** gemini-2.5-pro (analytical precision)
**Temperature:** 0.3 (consistency enforcement)
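Two of the common rules above can be approximated with regular expressions in the detection step; these patterns are rough heuristics for illustration, not the tool's actual rule engine:

```python
import re

# Hypothetical detectors for two of the rules listed above.
STYLE_RULES = {
    "em_dash": re.compile("\u2014"),  # literal em-dash characters
    "isnt_x_its_y": re.compile(r"\bisn't\s+\w+[^.]*?,\s*it's\b", re.IGNORECASE),
}


def find_violations(text: str) -> list:
    """Return (rule, matched_text) pairs for every style violation found."""
    hits = []
    for rule, pattern in STYLE_RULES.items():
        for m in pattern.finditer(text):
            hits.append((rule, m.group(0)))
    return hits
```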
---
### 5. `seooptimize` (Workflow Tool)
**Purpose:** Comprehensive SEO optimization for WordPress content
**Real-world use case from project memory:**
- "Create SEO title under 60 chars, excerpt under 156 chars, suggest 5-7 WordPress tags, internal linking opportunities to foundational content"
**Workflow steps:**
1. Analyze content for primary keywords and topics
2. Generate SEO title (under 60 chars)
3. Create meta description/excerpt (under 156 chars)
4. Suggest WordPress tags (5-10)
5. Identify internal linking opportunities
6. Validate character limits and keyword density
**Parameters:**
- `content` (str): Full article content
- `existing_tags` (list): Current WordPress tag library
- `target_keywords` (list): Keywords to optimize for
- `internal_links_context` (str): Description of site structure for linking
- `platform` (str): "wordpress" (default), "ghost", "medium"
**Model:** gemini-2.5-pro (analytical, search-aware)
**Temperature:** 0.4
**Web search:** Enabled for keyword research
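Step 6's limit validation could be a small helper like this sketch; the thresholds mirror the 60/156-char limits above and follow common WordPress SEO plugin guidance, and the names are illustrative:

```python
# Character budgets for the SEO fields named in the workflow steps above.
SEO_LIMITS = {"title": 60, "meta_description": 156}


def validate_seo_fields(title: str, excerpt: str) -> list:
    """Return a list of limit violations; empty when everything fits."""
    problems = []
    if len(title) > SEO_LIMITS["title"]:
        problems.append(f"title is {len(title)} chars (max {SEO_LIMITS['title']})")
    if len(excerpt) > SEO_LIMITS["meta_description"]:
        problems.append(
            f"excerpt is {len(excerpt)} chars (max {SEO_LIMITS['meta_description']})"
        )
    return problems
```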
---
### 6. `guestedit` (Workflow Tool)
**Purpose:** Edit guest content while preserving author expertise and voice
**Real-world use case from project memory:**
- "Edit this guest article on PCB components: preserve technical authority, add key takeaways section, suggest internal links to foundational content, enhance educational flow, maintain expert's voice"
**Workflow steps:**
1. Analyze guest author's voice and technical expertise level
2. Identify areas needing clarification vs. areas to preserve
3. Generate educational enhancements (key takeaways, callouts)
4. Suggest internal linking without fabricating URLs
5. Validate technical accuracy via web search
6. Present edits with rationale
**Parameters:**
- `guest_content` (str): Original article
- `author_name` (str): Guest author for voice reference
- `expertise_level` (str): "expert", "practitioner", "educator"
- `enhancements_needed` (list): ["key_takeaways", "internal_links", "clarity", "seo"]
- `site_context` (str): Description of broader content ecosystem
- `verify_technical` (bool): Use web search for fact-checking
**Model:** gemini-2.5-pro (analytical + creative balance)
**Temperature:** 0.5
**Web search:** Enabled for technical verification
---
### 7. `linkstrategy` (Workflow Tool)
**Purpose:** Internal linking and cross-platform content integration strategy
**Real-world use case from project memory:**
- "Find opportunities to link this advanced PCB article to foundational HVAC knowledge. Connect blog with related podcast episode, Instagram post, and YouTube video from our content database."
**Workflow steps:**
1. Analyze content topics and technical depth
2. Identify prerequisite concepts needing internal links
3. Search for related content across platforms (blog, podcast, video, social)
4. Generate linking strategy with anchor text suggestions
5. Validate URLs against actual content (no fabrication)
6. Create cross-platform promotion plan
**Parameters:**
- `primary_content` (str): Main piece to link from/to
- `content_type` (str): "blog", "newsletter", "social_post"
- `available_platforms` (list): ["blog", "podcast", "youtube", "instagram"]
- `technical_depth` (str): "foundational", "intermediate", "advanced"
- `verify_links` (bool): Validate URLs exist before suggesting
**Model:** gemini-2.5-pro (analytical, relationship mapping)
**Temperature:** 0.4
**Web search:** Enabled for content discovery
---
### 8. `factcheck` (Simple Tool)
**Purpose:** Quick technical fact verification via web search
**Real-world use case from project memory:**
- "Verify these technical claims about voltage regulation chains in HVAC PCBs"
- "Check if this White-Rodgers universal control model number and compatibility claims are accurate"
**Key features:**
- Web search-powered verification
- Cites sources for each claim
- Flags unsupported or questionable statements
- Quick turnaround for pre-publish checks
**Parameters:**
- `content` (str): Content with claims to verify
- `technical_domain` (str): "hvac", "software", "general"
- `claim_type` (str): "product_specs", "technical_process", "statistics", "general"
- `confidence_threshold` (str): "high_confidence_only", "balanced", "comprehensive"
**Model:** gemini-2.5-pro (search-augmented, analytical)
**Temperature:** 0.2 (precision for facts)
**Web search:** Required (core functionality)
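A first-pass claim extractor for this tool might look like the sketch below; the heuristics (digits, model-number-like tokens) are assumptions about which statements are worth sending to web search, not the tool's actual logic:

```python
import re


def extract_checkable_claims(text: str) -> list:
    """Pull out sentences containing numbers or model-number-like tokens,
    the statements most worth verifying before publication."""
    claims = []
    for sentence in re.split(r"(?<=[.!?])\s+", text):
        # Digits catch specs and statistics; the token pattern catches
        # part numbers like "50E47-843".
        if re.search(r"\d", sentence):
            claims.append(sentence.strip())
    return claims
```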
---
### 9. `voiceanalysis` (Workflow Tool)
**Purpose:** Extract and codify brand/author voice for consistency
**Real-world use case from project memory:**
- "Analyze these 10 pieces of Gary McCreadie's content to extract voice patterns, then check if this new draft matches his authentic teaching voice"
**Workflow steps:**
1. Analyze sample content for voice characteristics
2. Extract patterns: sentence structure, vocabulary, tone markers
3. Create voice profile with examples
4. Compare new content against profile
5. Flag inconsistencies with suggested rewrites
6. Track confidence in voice matching
**Parameters:**
- `sample_content` (list): 5-15 pieces showing authentic voice
- `author_name` (str): Voice owner for reference
- `new_content` (str): Content to validate against voice
- `voice_aspects` (list): ["tone", "vocabulary", "structure", "educational_style"]
- `generate_profile` (bool): Create reusable voice profile
**Model:** gemini-2.5-pro (pattern analysis)
**Temperature:** 0.3
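Step 2's pattern extraction could start from coarse metrics like these; the two metrics shown (average sentence length, vocabulary richness) are illustrative stand-ins for a fuller voice profile:

```python
import re
from statistics import mean


def voice_profile(samples: list) -> dict:
    """Compute coarse voice metrics across writing samples: average
    sentence length and vocabulary richness (unique words / total words)."""
    sentences, words = [], []
    for text in samples:
        sentences += [s for s in re.split(r"[.!?]+", text) if s.strip()]
        words += re.findall(r"[a-zA-Z']+", text.lower())
    return {
        "avg_sentence_words": mean(len(s.split()) for s in sentences),
        "vocab_richness": len(set(words)) / len(words),
    }
```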
---
### 10. `campaignmap` (Workflow Tool)
**Purpose:** Map content campaign across multiple touchpoints
**Real-world use case from project memory:**
- "Plan content campaign for measureQuick National Championship: social teaser posts, email sequence, on-site promotion, post-event follow-up, blog recap with internal links"
**Workflow steps:**
1. Define campaign goals and timeline
2. Map content across platforms and stages (awareness → consideration → action → retention)
3. Create content calendar with dependencies
4. Generate messaging for each touchpoint
5. Plan cross-promotion and internal linking
6. Set success metrics per channel
**Parameters:**
- `campaign_goal` (str): Primary objective
- `timeline_days` (int): Campaign duration
- `platforms` (list): Channels to use
- `content_pillars` (list): Key messages to reinforce
- `target_audience` (str): Audience description
- `existing_assets` (list): Content to repurpose
**Model:** gemini-2.5-pro (strategic planning)
**Temperature:** 0.6
---
## Tool Architecture
### Simple Tools (4)
Fast, single-shot interactions for rapid iteration:
- `contentvariant` - Bulk variation generation
- `platformadapt` - Cross-platform adaptation
- `subjectlines` - Email subject line testing
- `factcheck` - Quick verification
### Workflow Tools (6)
Multi-step systematic processes:
- `styleguide` - Style enforcement workflow
- `seooptimize` - Comprehensive SEO process
- `guestedit` - Guest content editing workflow
- `linkstrategy` - Internal linking analysis
- `voiceanalysis` - Voice extraction and validation
- `campaignmap` - Multi-touch campaign planning
### Keep from Original Zen
- `chat` - General brainstorming
- `thinkdeep` - Deep strategic thinking
- `planner` - Project planning
## Model Assignment Strategy
**Primary Models:**
- **minimax/minimax-m2**: Creative content generation (contentvariant, platformadapt)
- **google/gemini-2.5-pro**: Analytical and strategic work (seooptimize, guestedit, linkstrategy, voiceanalysis, campaignmap, styleguide, factcheck)
- **google/gemini-flash**: Fast bulk generation (subjectlines)
**Fallback Chain:**
1. Configured provider (minimax or gemini)
2. OpenRouter (if API key present)
3. Error (no fallback to Anthropic to avoid confusion)
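The fallback chain above can be sketched as a simple resolution loop; the function and dict shape are hypothetical, not the server's actual registry API:

```python
def resolve_provider(available: dict) -> str:
    """Walk the fallback chain: configured provider first, then OpenRouter,
    then a hard error. Deliberately no silent fallback to Anthropic."""
    for name in ("gemini", "minimax", "openrouter"):
        if available.get(name):
            return name
    raise RuntimeError("No provider configured: set GEMINI_API_KEY or OPENROUTER_API_KEY")
```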
## Configuration Defaults
```json
{
"mcpServers": {
"zen-marketing": {
"command": "/home/ben/mcp/zen-marketing/.venv/bin/python",
"args": ["/home/ben/mcp/zen-marketing/server.py"],
"env": {
"OPENROUTER_API_KEY": "your-key",
"DEFAULT_MODEL": "gemini-2.5-pro",
"FAST_MODEL": "gemini-flash",
"CREATIVE_MODEL": "minimax-m2",
"DISABLED_TOOLS": "analyze,refactor,testgen,secaudit,docgen,tracer,precommit,codereview,debug,consensus,challenge",
"ENABLE_WEB_SEARCH": "true",
"LOG_LEVEL": "INFO"
}
}
}
}
```
## Tool Temperature Defaults
| Tool | Temperature | Rationale |
|------|-------------|-----------|
| contentvariant | 0.8 | High creativity for diverse variations |
| platformadapt | 0.7 | Creative while maintaining message |
| subjectlines | 0.8 | Explore diverse psychological angles |
| styleguide | 0.3 | Precision for consistency |
| seooptimize | 0.4 | Balanced analytical + creative |
| guestedit | 0.5 | Balance preservation + enhancement |
| linkstrategy | 0.4 | Analytical relationship mapping |
| factcheck | 0.2 | Precision for accuracy |
| voiceanalysis | 0.3 | Pattern recognition precision |
| campaignmap | 0.6 | Strategic creativity |
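The table above could be centralized as a single lookup that tools share; the keys mirror the planned tool names and the values mirror the table, while the fallback default is an assumption:

```python
# Per-tool temperature defaults, taken from the table above.
TOOL_TEMPERATURE = {
    "contentvariant": 0.8, "platformadapt": 0.7, "subjectlines": 0.8,
    "styleguide": 0.3, "seooptimize": 0.4, "guestedit": 0.5,
    "linkstrategy": 0.4, "factcheck": 0.2, "voiceanalysis": 0.3,
    "campaignmap": 0.6,
}


def temperature_for(tool: str, default: float = 0.5) -> float:
    """Return the configured temperature for a tool, or a balanced default."""
    return TOOL_TEMPERATURE.get(tool, default)
```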
## Differences from Zen Code
### What's Different
1. **Content-first**: Tools designed for writing, not coding
2. **Variation generation**: Built-in support for A/B testing at scale
3. **Platform awareness**: Character limits, formatting, best practices per channel
4. **Voice preservation**: Tools explicitly maintain author authenticity
5. **SEO native**: WordPress-specific optimization built in
6. **Verification emphasis**: Web search for fact-checking integrated
7. **No code tools**: Remove debug, codereview, precommit, refactor, testgen, secaudit, tracer
### What's Kept
1. **Conversation continuity**: `continuation_id` across tools
2. **Multi-model orchestration**: Different models for different tasks
3. **Workflow architecture**: Multi-step systematic processes
4. **File handling**: Attach content files, brand guidelines
5. **Image support**: Analyze screenshots, brand assets
6. **Thinking modes**: Control depth vs. cost
7. **Web search**: Current information access
### What's Enhanced
1. **Bulk operations**: Generate 5-20 variations in one call
2. **Style enforcement**: Custom writing rules
3. **Cross-platform thinking**: Native multi-channel workflows
4. **Technical verification**: Domain-specific fact-checking
## Implementation Phases
### Phase 1: Foundation (Week 1)
- Fork Zen MCP Server codebase
- Remove code-specific tools
- Configure minimax and gemini providers
- Update system prompts for marketing context
- Basic testing with chat tool
### Phase 2: Simple Tools (Week 2)
- Implement `contentvariant`
- Implement `platformadapt`
- Implement `subjectlines`
- Implement `factcheck`
- Test variation generation workflows
### Phase 3: Workflow Tools (Week 3-4)
- Implement `styleguide`
- Implement `seooptimize`
- Implement `guestedit`
- Implement `linkstrategy`
- Test multi-step workflows
### Phase 4: Advanced Features (Week 5)
- Implement `voiceanalysis`
- Implement `campaignmap`
- Add voice profile storage
- Test complex campaign workflows
### Phase 5: Polish & Documentation (Week 6)
- User documentation
- Example workflows
- Configuration guide
- Testing with real project memories
## Success Metrics
1. **Variation Quality**: Can generate 10+ usable subject lines in one call
2. **Platform Accuracy**: Respects character limits and formatting rules
3. **Voice Consistency**: Successfully preserves guest author expertise
4. **SEO Effectiveness**: Generated titles/descriptions pass WordPress SEO checks
5. **Cross-platform Integration**: Correctly maps content across blog/social/email
6. **Fact Accuracy**: Catches technical errors before publication
7. **Speed**: Simple tools respond in <10 seconds, workflow tools in <30 seconds
## File Structure
```
zen-marketing/
├── server.py # Main MCP server (adapted from zen)
├── config.py # Marketing-specific configuration
├── PLAN.md # This file
├── CLAUDE.md # Development guide
├── README.md # User documentation
├── requirements.txt # Python dependencies
├── .env.example # Environment template
├── run-server.sh # Setup and run script
├── tools/
│ ├── contentvariant.py # Bulk variation generation
│ ├── platformadapt.py # Cross-platform adaptation
│ ├── subjectlines.py # Email subject lines
│ ├── styleguide.py # Writing style enforcement
│ ├── seooptimize.py # WordPress SEO workflow
│ ├── guestedit.py # Guest content editing
│ ├── linkstrategy.py # Internal linking strategy
│ ├── factcheck.py # Technical verification
│ ├── voiceanalysis.py # Voice extraction/validation
│ ├── campaignmap.py # Campaign planning
│ └── shared/ # Shared utilities from zen
├── providers/
│ ├── gemini.py # Google Gemini provider
│ ├── minimax.py # Minimax provider (NEW)
│ ├── openrouter.py # OpenRouter fallback
│ └── registry.py # Provider registration
├── systemprompts/
│ ├── contentvariant_prompt.py
│ ├── platformadapt_prompt.py
│ ├── subjectlines_prompt.py
│ ├── styleguide_prompt.py
│ ├── seooptimize_prompt.py
│ ├── guestedit_prompt.py
│ ├── linkstrategy_prompt.py
│ ├── factcheck_prompt.py
│ ├── voiceanalysis_prompt.py
│ └── campaignmap_prompt.py
└── tests/
├── test_contentvariant.py
├── test_platformadapt.py
└── ...
```
## Development Priorities
### High Priority (Essential)
1. `contentvariant` - Most used feature from project memories
2. `subjectlines` - Specific workflow mentioned multiple times
3. `platformadapt` - Core multi-channel need
4. `styleguide` - Explicit writing rules enforcement
5. `factcheck` - Technical accuracy verification
### Medium Priority (Important)
6. `seooptimize` - WordPress-specific SEO needs
7. `guestedit` - Guest content workflow
8. `linkstrategy` - Internal linking mentioned frequently
### Lower Priority (Nice to Have)
9. `voiceanalysis` - Advanced voice preservation
10. `campaignmap` - Strategic planning (can use planner tool initially)
## Next Steps
1. Create CLAUDE.md with development guide for zen-marketing
2. Start new session in ~/mcp/zen-marketing/ directory
3. Begin Phase 1 implementation:
- Copy core architecture from zen-mcp-server
- Configure minimax provider
- Remove code-specific tools
- Test with basic chat functionality
4. Implement Phase 2 simple tools starting with `contentvariant`

163
README.md Normal file
View file

@ -0,0 +1,163 @@
# Zen-Marketing MCP Server
> **Status:** 🚧 In Development - Phase 1
AI-powered marketing tools for Claude Desktop, specialized for technical B2B content creation, multi-platform campaigns, and content variation testing.
## What is This?
A fork of the [Zen MCP Server](https://github.com/BeehiveInnovations/zen-mcp-server) optimized for marketing workflows instead of software development. Provides Claude Desktop with specialized tools for:
- **Content Variation** - Generate 5-20 variations for A/B testing
- **Platform Adaptation** - Adapt content across LinkedIn, Twitter, newsletters, blogs
- **Style Enforcement** - Apply writing guidelines automatically
- **SEO Optimization** - WordPress-specific SEO workflows
- **Guest Editing** - Edit external content while preserving author voice
- **Fact Verification** - Technical accuracy checking via web search
- **Internal Linking** - Cross-platform content integration strategy
- **Campaign Planning** - Multi-touch campaign mapping
## For Whom?
Solo marketing professionals managing:
- Technical newsletters and educational content
- Multi-platform social media campaigns
- WordPress blogs with SEO requirements
- Guest contributor content
- A/B testing and content experimentation
- B2B SaaS or technical product marketing
## Key Differences from Zen Code
| Zen Code | Zen Marketing |
|----------|---------------|
| Code review, debugging, testing | Content variation, SEO, style guide |
| Software architecture analysis | Campaign planning, voice analysis |
| Development workflows | Marketing workflows |
| Technical accuracy for code | Technical accuracy for content |
| GitHub/git integration | WordPress/platform integration |
## Tools (Planned)
### Simple Tools (Fast, Single-Shot)
- `contentvariant` - Generate 5-20 variations for testing
- `platformadapt` - Cross-platform content adaptation
- `subjectlines` - Email subject line generation
- `factcheck` - Technical fact verification
### Workflow Tools (Multi-Step)
- `styleguide` - Writing style enforcement
- `seooptimize` - WordPress SEO optimization
- `guestedit` - Guest content editing
- `linkstrategy` - Internal linking strategy
- `voiceanalysis` - Voice extraction and validation
- `campaignmap` - Campaign planning
### Kept from Zen
- `chat` - General brainstorming
- `thinkdeep` - Deep strategic thinking
- `planner` - Project planning
## Models
**Primary Models:**
- **Gemini 2.5 Pro** - Analytical and strategic work
- **Gemini Flash** - Fast bulk generation
- **Minimax M2** - Creative content generation
**Fallback:**
- OpenRouter (if configured)
## Quick Start
> **Note:** Not yet functional - see [PLAN.md](PLAN.md) for implementation roadmap
```bash
# Clone and setup
git clone [repo-url]
cd zen-marketing
./run-server.sh
# Configure Claude Desktop (~/.claude.json)
{
"mcpServers": {
"zen-marketing": {
"command": "/home/ben/mcp/zen-marketing/.venv/bin/python",
"args": ["/home/ben/mcp/zen-marketing/server.py"],
"env": {
"GEMINI_API_KEY": "your-key",
"OPENROUTER_API_KEY": "your-key",
"DEFAULT_MODEL": "gemini-2.5-pro"
}
}
}
}
# Restart Claude Desktop
```
## Example Usage
```
# Generate subject line variations
"Use zen-marketing contentvariant to generate 15 subject lines for an HVAC newsletter about PCB diagnostics. Test angles: technical curiosity, contrarian, knowledge gap, urgency."
# Adapt content across platforms
"Use zen-marketing platformadapt to take this blog post intro and create versions for LinkedIn (1300 chars), Twitter (280), Instagram (2200), and newsletter."
# Enforce writing style
"Use zen-marketing styleguide to check this draft for em-dashes, 'This isn't X, it's Y' constructions, and ensure direct affirmative statements."
# SEO optimize
"Use zen-marketing seooptimize to create SEO title under 60 chars, excerpt under 156 chars, and suggest WordPress tags for this article about HVAC voltage regulation."
```
## Development Status
### Phase 1: Foundation (Current)
- [x] Create directory structure
- [x] Write implementation plan
- [x] Create development guide
- [ ] Copy core architecture from zen-mcp-server
- [ ] Configure minimax provider
- [ ] Test basic functionality
### Phase 2: Simple Tools (Next)
- [ ] Implement `contentvariant`
- [ ] Implement `subjectlines`
- [ ] Implement `platformadapt`
- [ ] Implement `factcheck`
See [PLAN.md](PLAN.md) for complete roadmap.
## Documentation
- **[PLAN.md](PLAN.md)** - Detailed implementation plan with tool designs
- **[CLAUDE.md](CLAUDE.md)** - Development guide for contributors
- **[Project Memories](PLAN.md#project-overview)** - Real-world usage examples
## Architecture
Based on Zen MCP Server's proven architecture:
- **Conversation continuity** via `continuation_id`
- **Multi-model orchestration** (Gemini, Minimax, OpenRouter)
- **Simple tools** for fast iteration
- **Workflow tools** for multi-step processes
- **Web search integration** for current information
- **File handling** for content and brand assets
## License
Apache 2.0 License (inherited from Zen MCP Server)
## Acknowledgments
Built on the foundation of:
- [Zen MCP Server](https://github.com/BeehiveInnovations/zen-mcp-server) by Fahad Gilani
- [Model Context Protocol](https://modelcontextprotocol.io) by Anthropic
- [Claude Desktop](https://claude.ai/download) - AI interface
---
**Status:** Planning phase complete, ready for implementation
**Next:** Start new Claude session in `~/mcp/zen-marketing/` to begin Phase 1

107
config.py Normal file
View file

@ -0,0 +1,107 @@
"""
Configuration and constants for Zen-Marketing MCP Server
This module centralizes all configuration settings for the Zen-Marketing MCP Server.
It defines model configurations, token limits, temperature defaults, and other
constants used throughout the application.
Configuration values can be overridden by environment variables where appropriate.
"""
import os
# Version and metadata
__version__ = "0.1.0"
__updated__ = "2025-11-07"
__author__ = "Ben (based on Zen MCP Server by Fahad Gilani)"
# Model configuration
# DEFAULT_MODEL: The default model used for all AI operations
# Can be overridden by setting DEFAULT_MODEL environment variable
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "google/gemini-2.5-pro-latest")
# Fast model for quick operations (variations, subject lines)
FAST_MODEL = os.getenv("FAST_MODEL", "google/gemini-2.5-flash-preview-09-2025")
# Creative model for content generation
CREATIVE_MODEL = os.getenv("CREATIVE_MODEL", "minimax/minimax-m2")
# Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model
IS_AUTO_MODE = DEFAULT_MODEL.lower() == "auto"
# Temperature defaults for different content types
# Temperature controls the randomness/creativity of model responses
# Lower values (0.0-0.3) produce more deterministic, focused responses
# Higher values (0.7-1.0) produce more creative, varied responses
# TEMPERATURE_PRECISION: Used for fact-checking and technical verification
TEMPERATURE_PRECISION = 0.2 # For factcheck, technical verification
# TEMPERATURE_ANALYTICAL: Used for style enforcement and SEO optimization
TEMPERATURE_ANALYTICAL = 0.3 # For styleguide, seooptimize, voiceanalysis
# TEMPERATURE_BALANCED: Used for strategic planning
TEMPERATURE_BALANCED = 0.5 # For guestedit, linkstrategy, campaignmap
# TEMPERATURE_CREATIVE: Used for content variation and adaptation
TEMPERATURE_CREATIVE = 0.7 # For platformadapt
# TEMPERATURE_HIGHLY_CREATIVE: Used for bulk variation generation
TEMPERATURE_HIGHLY_CREATIVE = 0.8 # For contentvariant, subjectlines
# Thinking Mode Defaults
DEFAULT_THINKING_MODE_THINKDEEP = os.getenv("DEFAULT_THINKING_MODE_THINKDEEP", "high")
# MCP Protocol Transport Limits
def _calculate_mcp_prompt_limit() -> int:
"""
Calculate MCP prompt size limit based on MAX_MCP_OUTPUT_TOKENS environment variable.
Returns:
Maximum character count for user input prompts
"""
max_tokens_str = os.getenv("MAX_MCP_OUTPUT_TOKENS")
if max_tokens_str:
try:
max_tokens = int(max_tokens_str)
# Allocate 60% of tokens for input, convert to characters (~4 chars per token)
input_token_budget = int(max_tokens * 0.6)
character_limit = input_token_budget * 4
return character_limit
except (ValueError, TypeError):
pass
# Default fallback: 60,000 characters
return 60_000
MCP_PROMPT_SIZE_LIMIT = _calculate_mcp_prompt_limit()
# Language/Locale Configuration
LOCALE = os.getenv("LOCALE", "")
# Platform character limits
PLATFORM_LIMITS = {
"twitter": 280,
"bluesky": 300,
"linkedin": 3000,
"linkedin_optimal": 1300,
"instagram": 2200,
"facebook": 500, # Optimal length
"email_subject": 60,
"email_preview": 100,
"meta_description": 156,
"page_title": 60,
}
# Web search configuration
ENABLE_WEB_SEARCH = os.getenv("ENABLE_WEB_SEARCH", "true").lower() == "true"
# Tool disabling
# Comma-separated list of tools to disable
DISABLED_TOOLS_STR = os.getenv("DISABLED_TOOLS", "")
DISABLED_TOOLS = set(tool.strip() for tool in DISABLED_TOOLS_STR.split(",") if tool.strip())
# Logging configuration
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")

20
providers/__init__.py Normal file
View file

@ -0,0 +1,20 @@
"""Model provider abstractions for supporting multiple AI providers."""
from .base import ModelProvider
from .gemini import GeminiModelProvider
from .openai_compatible import OpenAICompatibleProvider
from .openai_provider import OpenAIModelProvider
from .openrouter import OpenRouterProvider
from .registry import ModelProviderRegistry
from .shared import ModelCapabilities, ModelResponse
__all__ = [
"ModelProvider",
"ModelResponse",
"ModelCapabilities",
"ModelProviderRegistry",
"GeminiModelProvider",
"OpenAIModelProvider",
"OpenAICompatibleProvider",
"OpenRouterProvider",
]

268
providers/base.py Normal file
View file

@ -0,0 +1,268 @@
"""Base interfaces and common behaviour for model providers."""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .shared import ModelCapabilities, ModelResponse, ProviderType
logger = logging.getLogger(__name__)
class ModelProvider(ABC):
"""Abstract base class for all model backends in the MCP server.
Role
Defines the interface every provider must implement so the registry,
restriction service, and tools have a uniform surface for listing
models, resolving aliases, and executing requests.
Responsibilities
* expose static capability metadata for each supported model via
:class:`ModelCapabilities`
* accept user prompts, forward them to the underlying SDK, and wrap
responses in :class:`ModelResponse`
* report tokenizer counts for budgeting and validation logic
* advertise provider identity (``ProviderType``) so restriction
policies can map environment configuration onto providers
* validate whether a model name or alias is recognised by the provider
Shared helpers like temperature validation, alias resolution, and
restriction-aware ``list_models`` live here so concrete subclasses only
need to supply their catalogue and wire up SDK-specific behaviour.
"""
# All concrete providers must define their supported models
MODEL_CAPABILITIES: dict[str, Any] = {}
def __init__(self, api_key: str, **kwargs):
"""Initialize the provider with API key and optional configuration."""
self.api_key = api_key
self.config = kwargs
# ------------------------------------------------------------------
# Provider identity & capability surface
# ------------------------------------------------------------------
@abstractmethod
def get_provider_type(self) -> ProviderType:
"""Return the concrete provider identity."""
def get_capabilities(self, model_name: str) -> ModelCapabilities:
"""Resolve capability metadata for a model name.
This centralises the alias resolution → capability lookup → restriction check
pipeline so providers only override the pieces they genuinely need to
customise. Subclasses usually only override ``_lookup_capabilities`` to
integrate a registry or dynamic source, or ``_finalise_capabilities`` to
tweak the returned object.
"""
resolved_name = self._resolve_model_name(model_name)
capabilities = self._lookup_capabilities(resolved_name, model_name)
if capabilities is None:
self._raise_unsupported_model(model_name)
self._ensure_model_allowed(capabilities, resolved_name, model_name)
return self._finalise_capabilities(capabilities, resolved_name, model_name)
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Return statically declared capabilities when available."""
model_map = getattr(self, "MODEL_CAPABILITIES", None)
if isinstance(model_map, dict) and model_map:
return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
return {}
def list_models(
self,
*,
respect_restrictions: bool = True,
include_aliases: bool = True,
lowercase: bool = False,
unique: bool = False,
) -> list[str]:
"""Return formatted model names supported by this provider."""
model_configs = self.get_all_model_capabilities()
if not model_configs:
return []
restriction_service = None
if respect_restrictions:
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if restriction_service:
allowed_configs = {}
for model_name, config in model_configs.items():
if restriction_service.is_allowed(self.get_provider_type(), model_name):
allowed_configs[model_name] = config
model_configs = allowed_configs
if not model_configs:
return []
return ModelCapabilities.collect_model_names(
model_configs,
include_aliases=include_aliases,
lowercase=lowercase,
unique=unique,
)
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------
@abstractmethod
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using the model."""
def count_tokens(self, text: str, model_name: str) -> int:
"""Estimate token usage for a piece of text."""
resolved_model = self._resolve_model_name(model_name)
if not text:
return 0
estimated = max(1, len(text) // 4)
logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
return estimated
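As a quick sanity check, the four-characters-per-token fallback above behaves like this standalone sketch (the function name is illustrative, not part of the provider API):

```python
# Standalone sketch of the character heuristic used by count_tokens above:
# roughly four characters per token, never reporting zero for non-empty text.
def estimate_tokens(text: str) -> int:
    if not text:
        return 0
    return max(1, len(text) // 4)

print(estimate_tokens(""))             # 0
print(estimate_tokens("hi"))           # 1  (floor for non-empty text)
print(estimate_tokens("hello world"))  # 2
```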
def close(self) -> None:
"""Clean up any resources held by the provider."""
return
# ------------------------------------------------------------------
# Validation hooks
# ------------------------------------------------------------------
def validate_model_name(self, model_name: str) -> bool:
"""Return ``True`` when the model resolves to an allowed capability."""
try:
self.get_capabilities(model_name)
except ValueError:
return False
return True
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters against capabilities."""
capabilities = self.get_capabilities(model_name)
if not capabilities.temperature_constraint.validate(temperature):
constraint_desc = capabilities.temperature_constraint.get_description()
raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
# ------------------------------------------------------------------
# Preference / registry hooks
# ------------------------------------------------------------------
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get the preferred model from this provider for a given category."""
return None
def get_model_registry(self) -> Optional[dict[str, Any]]:
"""Return the model registry backing this provider, if any."""
return None
# ------------------------------------------------------------------
# Capability lookup pipeline
# ------------------------------------------------------------------
def _lookup_capabilities(
self,
canonical_name: str,
requested_name: Optional[str] = None,
) -> Optional[ModelCapabilities]:
"""Return ``ModelCapabilities`` for the canonical model name."""
return self.get_all_model_capabilities().get(canonical_name)
def _ensure_model_allowed(
self,
capabilities: ModelCapabilities,
canonical_name: str,
requested_name: str,
) -> None:
"""Raise ``ValueError`` if the model violates restriction policy."""
try:
from utils.model_restrictions import get_restriction_service
except Exception: # pragma: no cover - only triggered if service import breaks
return
restriction_service = get_restriction_service()
if not restriction_service:
return
if restriction_service.is_allowed(self.get_provider_type(), canonical_name, requested_name):
return
raise ValueError(
f"{self.get_provider_type().value} model '{canonical_name}' is not allowed by restriction policy."
)
def _finalise_capabilities(
self,
capabilities: ModelCapabilities,
canonical_name: str,
requested_name: str,
) -> ModelCapabilities:
"""Allow subclasses to adjust capability metadata before returning."""
return capabilities
def _raise_unsupported_model(self, model_name: str) -> None:
"""Raise the canonical unsupported-model error."""
raise ValueError(f"Unsupported model '{model_name}' for provider {self.get_provider_type().value}.")
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve model shorthand to full name.
This implementation uses the hook methods to support different
model configuration sources.
Args:
model_name: Model name that may be an alias
Returns:
Resolved model name
"""
# Get model configurations from the hook method
model_configs = self.get_all_model_capabilities()
# First check if it's already a base model name (case-sensitive exact match)
if model_name in model_configs:
return model_name
# Check case-insensitively for both base models and aliases
model_name_lower = model_name.lower()
# Check base model names case-insensitively
for base_model in model_configs:
if base_model.lower() == model_name_lower:
return base_model
# Check aliases from the model configurations
alias_map = ModelCapabilities.collect_aliases(model_configs)
for base_model, aliases in alias_map.items():
if any(alias.lower() == model_name_lower for alias in aliases):
return base_model
# If not found, return as-is
return model_name
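The lookup order implemented by `_resolve_model_name` (exact base name, case-insensitive base name, then aliases, else pass-through) can be sketched against a toy capability map; the model names and `configs` dict below are hypothetical:

```python
# Toy re-implementation of the resolution order in _resolve_model_name.
def resolve_model_name(model_name, model_configs, alias_map):
    if model_name in model_configs:  # 1. exact base-name match
        return model_name
    lowered = model_name.lower()
    for base in model_configs:  # 2. case-insensitive base-name match
        if base.lower() == lowered:
            return base
    for base, aliases in alias_map.items():  # 3. case-insensitive alias match
        if any(alias.lower() == lowered for alias in aliases):
            return base
    return model_name  # 4. unknown names pass through unchanged

configs = {"gemini-2.5-pro": {}}
aliases = {"gemini-2.5-pro": ["pro", "gemini-pro"]}
print(resolve_model_name("PRO", configs, aliases))      # gemini-2.5-pro
print(resolve_model_name("unknown", configs, aliases))  # unknown
```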

196
providers/custom.py Normal file
View file

@ -0,0 +1,196 @@
"""Custom API provider implementation."""
import logging
import os
from typing import Optional
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import ModelCapabilities, ModelResponse, ProviderType
class CustomProvider(OpenAICompatibleProvider):
"""Adapter for self-hosted or local OpenAI-compatible endpoints.
Role
Provide a uniform bridge between the MCP server and user-managed
OpenAI-compatible services (Ollama, vLLM, LM Studio, bespoke gateways).
By subclassing :class:`OpenAICompatibleProvider` it inherits request and
token handling, while the custom registry exposes locally defined model
metadata.
Notable behaviour
* Uses :class:`OpenRouterModelRegistry` to load model definitions and
aliases so custom deployments share the same metadata pipeline as
OpenRouter itself.
* Normalises version-tagged model names (``model:latest``) and applies
restriction policies just like cloud providers, ensuring consistent
behaviour across environments.
"""
FRIENDLY_NAME = "Custom API"
# Model registry for managing configurations and aliases (shared with OpenRouter)
_registry: Optional[OpenRouterModelRegistry] = None
def __init__(self, api_key: str = "", base_url: str = "", **kwargs):
"""Initialize Custom provider for local/self-hosted models.
This provider supports any OpenAI-compatible API endpoint including:
- Ollama (typically no API key required)
- vLLM (may require API key)
- LM Studio (may require API key)
- Text Generation WebUI (may require API key)
- Enterprise/self-hosted APIs (typically require API key)
Args:
api_key: API key for the custom endpoint. Can be empty string for
providers that don't require authentication (like Ollama).
Falls back to CUSTOM_API_KEY environment variable if not provided.
base_url: Base URL for the custom API endpoint (e.g., 'http://localhost:11434/v1').
Falls back to CUSTOM_API_URL environment variable if not provided.
**kwargs: Additional configuration passed to parent OpenAI-compatible provider
Raises:
ValueError: If no base_url is provided via parameter or environment variable
"""
# Fall back to environment variables only if not provided
if not base_url:
base_url = os.getenv("CUSTOM_API_URL", "")
if not api_key:
api_key = os.getenv("CUSTOM_API_KEY", "")
if not base_url:
raise ValueError(
"Custom API URL must be provided via base_url parameter or CUSTOM_API_URL environment variable"
)
# For Ollama and other providers that don't require authentication,
# set a dummy API key to avoid OpenAI client header issues
if not api_key:
api_key = "dummy-key-for-unauthenticated-endpoint"
logging.debug("Using dummy API key for unauthenticated custom endpoint")
logging.info(f"Initializing Custom provider with endpoint: {base_url}")
super().__init__(api_key, base_url=base_url, **kwargs)
# Initialize model registry (shared with OpenRouter for consistent aliases)
if CustomProvider._registry is None:
CustomProvider._registry = OpenRouterModelRegistry()
# Log loaded models and aliases only on first load
models = self._registry.list_models()
aliases = self._registry.list_aliases()
logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases")
# ------------------------------------------------------------------
# Capability surface
# ------------------------------------------------------------------
def _lookup_capabilities(
self,
canonical_name: str,
requested_name: Optional[str] = None,
) -> Optional[ModelCapabilities]:
"""Return capabilities for models explicitly marked as custom."""
builtin = super()._lookup_capabilities(canonical_name, requested_name)
if builtin is not None:
return builtin
registry_entry = self._registry.resolve(canonical_name)
if registry_entry and getattr(registry_entry, "is_custom", False):
registry_entry.provider = ProviderType.CUSTOM
return registry_entry
logging.debug(
"Custom provider cannot resolve model '%s'; ensure it is declared with 'is_custom': true in custom_models.json",
canonical_name,
)
return None
def get_provider_type(self) -> ProviderType:
"""Identify this provider for restriction and logging logic."""
return ProviderType.CUSTOM
# ------------------------------------------------------------------
# Validation
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using the custom API.
Args:
prompt: User prompt to send to the model
model_name: Name of the model to use
system_prompt: Optional system prompt for model behavior
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
# Resolve model alias to actual model name
resolved_model = self._resolve_model_name(model_name)
# Call parent method with resolved model name
return super().generate_content(
prompt=prompt,
model_name=resolved_model,
system_prompt=system_prompt,
temperature=temperature,
max_output_tokens=max_output_tokens,
**kwargs,
)
# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve registry aliases and strip version tags for local models."""
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
if ":" in model_name:
base_model = model_name.split(":")[0]
logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
base_config = self._registry.resolve(base_model)
if base_config:
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
return base_config.model_name
return base_model
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
return model_name
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Expose registry capabilities for models marked as custom."""
if not self._registry:
return {}
capabilities: dict[str, ModelCapabilities] = {}
for model_name in self._registry.list_models():
config = self._registry.resolve(model_name)
if config and getattr(config, "is_custom", False):
capabilities[model_name] = config
return capabilities
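The version-tag handling in `CustomProvider._resolve_model_name` above can be sketched with a plain dict standing in for `OpenRouterModelRegistry`; the registry entries here are hypothetical:

```python
# Sketch of CustomProvider's name resolution: registry hit first, then retry
# with any Ollama-style ":tag" stripped, else pass the name through as-is.
registry = {"llama3.1": "llama3.1-70b-instruct"}  # alias -> canonical (hypothetical)

def resolve_local_name(name):
    if name in registry:  # direct registry hit
        return registry[name]
    if ":" in name:  # strip version tag, e.g. "model:latest" -> "model"
        base = name.split(":")[0]
        return registry.get(base, base)  # retry with the base name
    return name  # unknown names are used verbatim

print(resolve_local_name("llama3.1:latest"))  # llama3.1-70b-instruct
print(resolve_local_name("mistral:7b"))       # mistral
print(resolve_local_name("phi3"))             # phi3
```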

473
providers/dial.py Normal file
View file

@ -0,0 +1,473 @@
"""DIAL (Data & AI Layer) model provider implementation."""
import logging
import os
import threading
import time
from typing import Optional
from .openai_compatible import OpenAICompatibleProvider
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
logger = logging.getLogger(__name__)
class DIALModelProvider(OpenAICompatibleProvider):
"""Client for the DIAL (Data & AI Layer) aggregation service.
DIAL exposes several third-party models behind a single OpenAI-compatible
endpoint. This provider wraps the service, publishes capability metadata
for the known deployments, and centralises retry/backoff settings tailored
to DIAL's latency characteristics.
"""
FRIENDLY_NAME = "DIAL"
# Retry configuration for API calls
MAX_RETRIES = 4
RETRY_DELAYS = [1, 3, 5, 8] # seconds
# Model configurations using ModelCapabilities objects
MODEL_CAPABILITIES = {
"o3-2025-04-16": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="o3-2025-04-16",
friendly_name="DIAL (O3)",
context_window=200_000,
max_output_tokens=100_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=False, # O3 models don't accept temperature
temperature_constraint=TemperatureConstraint.create("fixed"),
description="OpenAI O3 via DIAL - Strong reasoning model",
aliases=["o3"],
),
"o4-mini-2025-04-16": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="o4-mini-2025-04-16",
friendly_name="DIAL (O4-mini)",
context_window=200_000,
max_output_tokens=100_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=False, # O4 models don't accept temperature
temperature_constraint=TemperatureConstraint.create("fixed"),
description="OpenAI O4-mini via DIAL - Fast reasoning model",
aliases=["o4-mini"],
),
"anthropic.claude-sonnet-4.1-20250805-v1:0": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-sonnet-4.1-20250805-v1:0",
friendly_name="DIAL (Sonnet 4.1)",
context_window=200_000,
max_output_tokens=64_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
            supports_function_calling=False,  # Function calling not exposed via DIAL
supports_json_mode=False, # Claude doesn't have JSON mode
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Claude Sonnet 4.1 via DIAL - Balanced performance",
aliases=["sonnet-4.1", "sonnet-4"],
),
"anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking",
friendly_name="DIAL (Sonnet 4.1 Thinking)",
context_window=200_000,
max_output_tokens=64_000,
supports_extended_thinking=True, # Thinking mode variant
supports_system_prompts=True,
supports_streaming=True,
            supports_function_calling=False,  # Function calling not exposed via DIAL
supports_json_mode=False, # Claude doesn't have JSON mode
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Claude Sonnet 4.1 with thinking mode via DIAL",
aliases=["sonnet-4.1-thinking", "sonnet-4-thinking"],
),
"anthropic.claude-opus-4.1-20250805-v1:0": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-opus-4.1-20250805-v1:0",
friendly_name="DIAL (Opus 4.1)",
context_window=200_000,
max_output_tokens=64_000,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
            supports_function_calling=False,  # Function calling not exposed via DIAL
supports_json_mode=False, # Claude doesn't have JSON mode
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Claude Opus 4.1 via DIAL - Most capable Claude model",
aliases=["opus-4.1", "opus-4"],
),
"anthropic.claude-opus-4.1-20250805-v1:0-with-thinking": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="anthropic.claude-opus-4.1-20250805-v1:0-with-thinking",
friendly_name="DIAL (Opus 4.1 Thinking)",
context_window=200_000,
max_output_tokens=64_000,
supports_extended_thinking=True, # Thinking mode variant
supports_system_prompts=True,
supports_streaming=True,
            supports_function_calling=False,  # Function calling not exposed via DIAL
supports_json_mode=False, # Claude doesn't have JSON mode
supports_images=True,
max_image_size_mb=5.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Claude Opus 4.1 with thinking mode via DIAL",
aliases=["opus-4.1-thinking", "opus-4-thinking"],
),
"gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="gemini-2.5-pro-preview-03-25-google-search",
friendly_name="DIAL (Gemini 2.5 Pro Search)",
context_window=1_000_000,
max_output_tokens=65_536,
supports_extended_thinking=False, # DIAL doesn't expose thinking mode
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Gemini 2.5 Pro with Google Search via DIAL",
aliases=["gemini-2.5-pro-search"],
),
"gemini-2.5-pro-preview-05-06": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="gemini-2.5-pro-preview-05-06",
friendly_name="DIAL (Gemini 2.5 Pro)",
context_window=1_000_000,
max_output_tokens=65_536,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Gemini 2.5 Pro via DIAL - Deep reasoning",
aliases=["gemini-2.5-pro"],
),
"gemini-2.5-flash-preview-05-20": ModelCapabilities(
provider=ProviderType.DIAL,
model_name="gemini-2.5-flash-preview-05-20",
friendly_name="DIAL (Gemini Flash 2.5)",
context_window=1_000_000,
max_output_tokens=65_536,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False, # DIAL may not expose function calling
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Gemini 2.5 Flash via DIAL - Ultra-fast",
aliases=["gemini-2.5-flash"],
),
}
def __init__(self, api_key: str, **kwargs):
"""Initialize DIAL provider with API key and host.
Args:
api_key: DIAL API key for authentication
**kwargs: Additional configuration options
"""
# Get DIAL API host from environment or kwargs
dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai"
# DIAL uses /openai endpoint for OpenAI-compatible API
if not dial_host.endswith("/openai"):
dial_host = f"{dial_host.rstrip('/')}/openai"
kwargs["base_url"] = dial_host
# Get API version from environment or use default
self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview")
# Add DIAL-specific headers
# DIAL uses Api-Key header instead of Authorization: Bearer
# Reference: https://dialx.ai/dial_api#section/Authorization
self.DEFAULT_HEADERS = {
"Api-Key": api_key,
}
# Store the actual API key for use in Api-Key header
self._dial_api_key = api_key
# Pass a placeholder API key to OpenAI client - we'll override the auth header in httpx
# The actual authentication happens via the Api-Key header in the httpx client
super().__init__("placeholder-not-used", **kwargs)
# Cache for deployment-specific clients to avoid recreating them on each request
self._deployment_clients = {}
# Lock to ensure thread-safe client creation
self._client_lock = threading.Lock()
# Create a SINGLE shared httpx client for the provider instance
import httpx
# Create custom event hooks to remove Authorization header
def remove_auth_header(request):
"""Remove Authorization header that OpenAI client adds."""
# httpx headers are case-insensitive, so we need to check all variations
headers_to_remove = []
for header_name in request.headers:
if header_name.lower() == "authorization":
headers_to_remove.append(header_name)
for header_name in headers_to_remove:
del request.headers[header_name]
self._http_client = httpx.Client(
timeout=self.timeout_config,
verify=True,
follow_redirects=True,
headers=self.DEFAULT_HEADERS.copy(), # Include DIAL headers including Api-Key
limits=httpx.Limits(
max_keepalive_connections=5,
max_connections=10,
keepalive_expiry=30.0,
),
event_hooks={"request": [remove_auth_header]},
)
logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.DIAL
def _get_deployment_client(self, deployment: str):
"""Get or create a cached client for a specific deployment.
This avoids recreating OpenAI clients on every request, improving performance.
Reuses the shared HTTP client for connection pooling.
Args:
deployment: The deployment/model name
Returns:
OpenAI client configured for the specific deployment
"""
# Check if client already exists without locking for performance
if deployment in self._deployment_clients:
return self._deployment_clients[deployment]
# Use lock to ensure thread-safe client creation
with self._client_lock:
# Double-check pattern: check again inside the lock
if deployment not in self._deployment_clients:
from openai import OpenAI
# Build deployment-specific URL
base_url = str(self.client.base_url)
if base_url.endswith("/"):
base_url = base_url[:-1]
# Remove /openai suffix if present to reconstruct properly
if base_url.endswith("/openai"):
base_url = base_url[:-7]
deployment_url = f"{base_url}/openai/deployments/{deployment}"
# Create and cache the client, REUSING the shared http_client
# Use placeholder API key - Authorization header will be removed by http_client event hook
self._deployment_clients[deployment] = OpenAI(
api_key="placeholder-not-used",
base_url=deployment_url,
http_client=self._http_client, # Pass the shared client with Api-Key header
default_query={"api-version": self.api_version}, # Add api-version as query param
)
return self._deployment_clients[deployment]
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
images: Optional[list[str]] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using DIAL's deployment-specific endpoint.
DIAL uses Azure OpenAI-style deployment endpoints:
/openai/deployments/{deployment}/chat/completions
Args:
prompt: User prompt
model_name: Model name or alias
system_prompt: Optional system prompt
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
# Validate model name against allow-list
if not self.validate_model_name(model_name):
raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")
# Validate parameters and fetch capabilities
self.validate_parameters(model_name, temperature)
capabilities = self.get_capabilities(model_name)
# Prepare messages
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# Build user message content
user_message_content = []
if prompt:
user_message_content.append({"type": "text", "text": prompt})
if images and capabilities.supports_images:
for img_path in images:
processed_image = self._process_image(img_path)
if processed_image:
user_message_content.append(processed_image)
elif images:
logger.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")
# Add user message. If only text, content will be a string, otherwise a list.
if len(user_message_content) == 1 and user_message_content[0]["type"] == "text":
messages.append({"role": "user", "content": prompt})
else:
messages.append({"role": "user", "content": user_message_content})
# Resolve model name
resolved_model = self._resolve_model_name(model_name)
# Build completion parameters
completion_params = {
"model": resolved_model,
"messages": messages,
}
# Determine temperature support from capabilities
supports_temperature = capabilities.supports_temperature
# Add temperature parameter if supported
if supports_temperature:
completion_params["temperature"] = temperature
        # Add max tokens if specified. Fixed-temperature reasoning models
        # (O3/O4) reject the max_tokens parameter, so the temperature support
        # flag doubles as a proxy and max_tokens is skipped for them.
        if max_output_tokens and supports_temperature:
            completion_params["max_tokens"] = max_output_tokens
# Add additional parameters
for key, value in kwargs.items():
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]:
continue
completion_params[key] = value
# DIAL-specific: Get cached client for deployment endpoint
deployment_client = self._get_deployment_client(resolved_model)
# Retry logic with progressive delays
last_exception = None
for attempt in range(self.MAX_RETRIES):
try:
# Generate completion using deployment-specific client
response = deployment_client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"finish_reason": response.choices[0].finish_reason,
"model": response.model,
"id": response.id,
"created": response.created,
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error
is_retryable = self._is_error_retryable(e)
            if not is_retryable:
                # Non-retryable error: raise immediately, preserving the original cause
                raise ValueError(f"DIAL API error for model {model_name}: {e}") from e
# If this isn't the last attempt and error is retryable, wait and retry
if attempt < self.MAX_RETRIES - 1:
delay = self.RETRY_DELAYS[attempt]
logger.info(
f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}"
)
time.sleep(delay)
continue
        # All retries exhausted
        raise ValueError(
            f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {last_exception}"
        ) from last_exception
def close(self) -> None:
"""Clean up HTTP clients when provider is closed."""
logger.info("Closing DIAL provider HTTP clients...")
# Clear the deployment clients cache
# Note: We don't need to close individual OpenAI clients since they
# use the shared httpx.Client which we close separately
self._deployment_clients.clear()
# Close the shared HTTP client
if hasattr(self, "_http_client"):
try:
self._http_client.close()
logger.debug("Closed shared HTTP client")
except Exception as e:
logger.warning(f"Error closing shared HTTP client: {e}")
# Also close the client created by the superclass (OpenAICompatibleProvider)
# as it holds its own httpx.Client instance that is not used by DIAL's generate_content
if hasattr(self, "client") and self.client and hasattr(self.client, "close"):
try:
self.client.close()
logger.debug("Closed superclass's OpenAI client")
except Exception as e:
logger.warning(f"Error closing superclass's OpenAI client: {e}")
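The deployment-URL reconstruction performed in `_get_deployment_client` boils down to the following sketch (host values are illustrative):

```python
# Sketch of how DIAL deployment endpoints are derived from the base URL:
# drop any trailing slash and "/openai" suffix, then append the
# Azure-style "/openai/deployments/{deployment}" path.
def deployment_url(base_url, deployment):
    base = base_url.rstrip("/")
    if base.endswith("/openai"):
        base = base[: -len("/openai")]
    return f"{base}/openai/deployments/{deployment}"

print(deployment_url("https://core.dialx.ai/openai", "o3-2025-04-16"))
# https://core.dialx.ai/openai/deployments/o3-2025-04-16
```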

578
providers/gemini.py Normal file
View file

@ -0,0 +1,578 @@
"""Gemini model provider implementation."""
import base64
import logging
import time
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from google import genai
from google.genai import types
from utils.image_utils import validate_image
from .base import ModelProvider
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
logger = logging.getLogger(__name__)
class GeminiModelProvider(ModelProvider):
"""First-party Gemini integration built on the official Google SDK.
The provider advertises detailed thinking-mode budgets, handles optional
custom endpoints, and performs image pre-processing before forwarding a
request to the Gemini APIs.
"""
# Model configurations using ModelCapabilities objects
MODEL_CAPABILITIES = {
"gemini-2.5-pro": ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.5-pro",
friendly_name="Gemini (Pro 2.5)",
context_window=1_048_576, # 1M tokens
max_output_tokens=65_536,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # Vision capability
max_image_size_mb=32.0, # Higher limit for Pro model
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
max_thinking_tokens=32768, # Max thinking tokens for Pro model
description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
aliases=["pro", "gemini pro", "gemini-pro"],
),
"gemini-2.0-flash": ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.0-flash",
friendly_name="Gemini (Flash 2.0)",
context_window=1_048_576, # 1M tokens
max_output_tokens=65_536,
supports_extended_thinking=True, # Experimental thinking mode
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # Vision capability
max_image_size_mb=20.0, # Conservative 20MB limit for reliability
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
max_thinking_tokens=24576, # Same as 2.5 flash for consistency
description="Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
aliases=["flash-2.0", "flash2"],
),
"gemini-2.0-flash-lite": ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.0-flash-lite",
            friendly_name="Gemini (Flash Lite 2.0)",
context_window=1_048_576, # 1M tokens
max_output_tokens=65_536,
supports_extended_thinking=False, # Not supported per user request
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=False, # Does not support images
max_image_size_mb=0.0, # No image support
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
aliases=["flashlite", "flash-lite"],
),
"gemini-2.5-flash": ModelCapabilities(
provider=ProviderType.GOOGLE,
model_name="gemini-2.5-flash",
friendly_name="Gemini (Flash 2.5)",
context_window=1_048_576, # 1M tokens
max_output_tokens=65_536,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # Vision capability
max_image_size_mb=20.0, # Conservative 20MB limit for reliability
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
max_thinking_tokens=24576, # Flash 2.5 thinking budget limit
description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
aliases=["flash", "flash2.5"],
),
}
# Thinking mode configurations - percentages of model's max_thinking_tokens
# These percentages work across all models that support thinking
THINKING_BUDGETS = {
"minimal": 0.005, # 0.5% of max - minimal thinking for fast responses
"low": 0.08, # 8% of max - light reasoning tasks
"medium": 0.33, # 33% of max - balanced reasoning (default)
"high": 0.67, # 67% of max - complex analysis
"max": 1.0, # 100% of max - full thinking budget
}
# Model-specific thinking token limits
MAX_THINKING_TOKENS = {
"gemini-2.0-flash": 24576, # Same as 2.5 flash for consistency
"gemini-2.0-flash-lite": 0, # No thinking support
"gemini-2.5-flash": 24576, # Flash 2.5 thinking budget limit
"gemini-2.5-pro": 32768, # Pro 2.5 thinking budget limit
}
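Taken together, the two tables above determine an absolute thinking-token budget per request; a minimal sketch of that computation (the helper name is illustrative):

```python
# Combine the percentage table with a model's hard limit to get the
# absolute thinking-token budget; models without thinking support resolve to 0.
THINKING_BUDGETS = {"minimal": 0.005, "low": 0.08, "medium": 0.33, "high": 0.67, "max": 1.0}
MAX_THINKING_TOKENS = {"gemini-2.5-flash": 24576, "gemini-2.5-pro": 32768}

def thinking_budget(model, mode):
    limit = MAX_THINKING_TOKENS.get(model, 0)
    return int(limit * THINKING_BUDGETS[mode])

print(thinking_budget("gemini-2.5-pro", "medium"))  # 10813 (33% of 32768)
print(thinking_budget("gemini-2.5-flash", "max"))   # 24576
```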
def __init__(self, api_key: str, **kwargs):
"""Initialize Gemini provider with API key and optional base URL."""
super().__init__(api_key, **kwargs)
self._client = None
self._token_counters = {} # Cache for token counting
self._base_url = kwargs.get("base_url", None) # Optional custom endpoint
# ------------------------------------------------------------------
# Capability surface
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# Client access
# ------------------------------------------------------------------
@property
def client(self):
"""Lazy initialization of Gemini client."""
if self._client is None:
# Check if custom base URL is provided
if self._base_url:
# Use HttpOptions to set custom endpoint
                http_options = types.HttpOptions(base_url=self._base_url)
logger.debug(f"Initializing Gemini client with custom endpoint: {self._base_url}")
self._client = genai.Client(api_key=self.api_key, http_options=http_options)
else:
# Use default Google endpoint
self._client = genai.Client(api_key=self.api_key)
return self._client
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
thinking_mode: str = "medium",
images: Optional[list[str]] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using Gemini model."""
# Validate parameters and fetch capabilities
resolved_name = self._resolve_model_name(model_name)
self.validate_parameters(model_name, temperature)
capabilities = self.get_capabilities(model_name)
# Prepare content parts (text and potentially images)
parts = []
# Add system and user prompts as text
if system_prompt:
full_prompt = f"{system_prompt}\n\n{prompt}"
else:
full_prompt = prompt
parts.append({"text": full_prompt})
# Add images if provided and model supports vision
if images and capabilities.supports_images:
for image_path in images:
try:
image_part = self._process_image(image_path)
if image_part:
parts.append(image_part)
except Exception as e:
logger.warning(f"Failed to process image {image_path}: {e}")
# Continue with other images and text
continue
elif images and not capabilities.supports_images:
logger.warning(f"Model {resolved_name} does not support images, ignoring {len(images)} image(s)")
# Create contents structure
contents = [{"parts": parts}]
# Prepare generation config
generation_config = types.GenerateContentConfig(
temperature=temperature,
candidate_count=1,
)
# Add max output tokens if specified
if max_output_tokens:
generation_config.max_output_tokens = max_output_tokens
# Add thinking configuration for models that support it
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
# Get model's max thinking tokens and calculate actual budget
model_config = self.MODEL_CAPABILITIES.get(resolved_name)
if model_config and model_config.max_thinking_tokens > 0:
max_thinking_tokens = model_config.max_thinking_tokens
actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
generation_config.thinking_config = types.ThinkingConfig(thinking_budget=actual_thinking_budget)
# Retry logic with progressive delays
max_retries = 4 # Total of 4 attempts
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
last_exception = None
for attempt in range(max_retries):
try:
# Generate content
response = self.client.models.generate_content(
model=resolved_name,
contents=contents,
config=generation_config,
)
# Extract usage information if available
usage = self._extract_usage(response)
# Intelligently determine finish reason and safety blocks
finish_reason_str = "UNKNOWN"
is_blocked_by_safety = False
safety_feedback_details = None
if response.candidates:
candidate = response.candidates[0]
# Safely get finish reason
try:
finish_reason_enum = candidate.finish_reason
if finish_reason_enum:
# Handle both enum objects and string values
try:
finish_reason_str = finish_reason_enum.name
except AttributeError:
finish_reason_str = str(finish_reason_enum)
else:
finish_reason_str = "STOP"
except AttributeError:
finish_reason_str = "STOP"
# If content is empty, check safety ratings for the definitive cause
if not response.text:
try:
safety_ratings = candidate.safety_ratings
if safety_ratings: # Check it's not None or empty
for rating in safety_ratings:
try:
if rating.blocked:
is_blocked_by_safety = True
# Provide details for logging/debugging
category_name = "UNKNOWN"
probability_name = "UNKNOWN"
try:
category_name = rating.category.name
except (AttributeError, TypeError):
pass
try:
probability_name = rating.probability.name
except (AttributeError, TypeError):
pass
safety_feedback_details = (
f"Category: {category_name}, Probability: {probability_name}"
)
break
except (AttributeError, TypeError):
# Individual rating doesn't have expected attributes
continue
except (AttributeError, TypeError):
# candidate doesn't have safety_ratings or it's not iterable
pass
# Also check for prompt-level blocking (request rejected entirely)
elif response.candidates is not None and len(response.candidates) == 0:
# No candidates is the primary indicator of a prompt-level block
is_blocked_by_safety = True
finish_reason_str = "SAFETY"
safety_feedback_details = "Prompt blocked, reason unavailable" # Default message
try:
prompt_feedback = response.prompt_feedback
if prompt_feedback and prompt_feedback.block_reason:
try:
block_reason_name = prompt_feedback.block_reason.name
except AttributeError:
block_reason_name = str(prompt_feedback.block_reason)
safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
except (AttributeError, TypeError):
# prompt_feedback doesn't exist or has unexpected attributes; stick with the default message
pass
return ModelResponse(
content=response.text,
usage=usage,
model_name=resolved_name,
friendly_name="Gemini",
provider=ProviderType.GOOGLE,
metadata={
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
"finish_reason": finish_reason_str,
"is_blocked_by_safety": is_blocked_by_safety,
"safety_feedback": safety_feedback_details,
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
# If this is the last attempt or not retryable, give up
if attempt == max_retries - 1 or not is_retryable:
break
# Get progressive delay
delay = retry_delays[attempt]
# Log retry attempt
logger.warning(
f"Gemini API error for model {resolved_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
# If we get here, all retries failed
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
raise RuntimeError(error_msg) from last_exception
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.GOOGLE
def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
"""Get actual thinking token budget for a model and thinking mode."""
resolved_name = self._resolve_model_name(model_name)
model_config = self.MODEL_CAPABILITIES.get(resolved_name)
if not model_config or not model_config.supports_extended_thinking:
return 0
if thinking_mode not in self.THINKING_BUDGETS:
return 0
max_thinking_tokens = model_config.max_thinking_tokens
if max_thinking_tokens == 0:
return 0
return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
def _extract_usage(self, response) -> dict[str, int]:
"""Extract token usage from Gemini response."""
usage = {}
# Try to extract usage metadata from response
# Note: The actual structure depends on the SDK version and response format
try:
metadata = response.usage_metadata
if metadata:
# Extract token counts with explicit None checks
input_tokens = None
output_tokens = None
try:
value = metadata.prompt_token_count
if value is not None:
input_tokens = value
usage["input_tokens"] = value
except (AttributeError, TypeError):
pass
try:
value = metadata.candidates_token_count
if value is not None:
output_tokens = value
usage["output_tokens"] = value
except (AttributeError, TypeError):
pass
# Calculate total only if both values are available and valid
if input_tokens is not None and output_tokens is not None:
usage["total_tokens"] = input_tokens + output_tokens
except (AttributeError, TypeError):
# response doesn't have usage_metadata
pass
return usage
def _is_error_retryable(self, error: Exception) -> bool:
"""Determine if an error should be retried.
Prefers structured error details from the Gemini SDK when available,
falling back to message-text heuristics.
Args:
error: Exception from Gemini API call
Returns:
True if error should be retried, False otherwise
"""
error_str = str(error).lower()
# Check for 429 errors first - these need special handling
if "429" in error_str or "quota" in error_str or "resource_exhausted" in error_str:
# For Gemini, check for specific non-retryable error indicators
# These typically indicate permanent failures or quota/size limits
non_retryable_indicators = [
"quota exceeded",
"resource exhausted",
"context length",
"token limit",
"request too large",
"invalid request",
"quota_exceeded",
"resource_exhausted",
]
# Also check if this is a structured error from Gemini SDK
try:
# Try to access error details if available
error_details = None
try:
error_details = error.details
except AttributeError:
try:
error_details = error.reason
except AttributeError:
pass
if error_details:
error_details_str = str(error_details).lower()
# Check for non-retryable error codes/reasons
if any(indicator in error_details_str for indicator in non_retryable_indicators):
logger.debug(f"Non-retryable Gemini error: {error_details}")
return False
except Exception:
pass
# Check main error string for non-retryable patterns
if any(indicator in error_str for indicator in non_retryable_indicators):
logger.debug(f"Non-retryable Gemini error based on message: {error_str[:200]}...")
return False
# If it's a 429/quota error but doesn't match non-retryable patterns, it might be retryable rate limiting
logger.debug(f"Retryable Gemini rate limiting error: {error_str[:100]}...")
return True
# For non-429 errors, check if they're retryable
retryable_indicators = [
"timeout",
"connection",
"network",
"temporary",
"unavailable",
"retry",
"internal error",
"408", # Request timeout
"500", # Internal server error
"502", # Bad gateway
"503", # Service unavailable
"504", # Gateway timeout
"ssl", # SSL errors
"handshake", # Handshake failures
]
return any(indicator in error_str for indicator in retryable_indicators)
def _process_image(self, image_path: str) -> Optional[dict]:
"""Process an image for Gemini API."""
try:
# Use base class validation
image_bytes, mime_type = validate_image(image_path)
# For data URLs, extract the base64 data directly
if image_path.startswith("data:"):
# Extract base64 data from data URL
_, data = image_path.split(",", 1)
return {"inline_data": {"mime_type": mime_type, "data": data}}
else:
# For file paths, encode the bytes
image_data = base64.b64encode(image_bytes).decode()
return {"inline_data": {"mime_type": mime_type, "data": image_data}}
except ValueError as e:
logger.warning(str(e))
return None
except Exception as e:
logger.error(f"Error processing image {image_path}: {e}")
return None
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get Gemini's preferred model for a given category from allowed models.
Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of models allowed by restrictions
Returns:
Preferred model name or None
"""
from tools.models import ToolModelCategory
if not allowed_models:
return None
# Helper to find best model from candidates
def find_best(candidates: list[str]) -> Optional[str]:
"""Return best model from candidates (sorted for consistency)."""
return sorted(candidates, reverse=True)[0] if candidates else None
if category == ToolModelCategory.EXTENDED_REASONING:
# For extended reasoning, prefer models with thinking support
# First try Pro models that support thinking
pro_thinking = [
m
for m in allowed_models
if "pro" in m and m in self.MODEL_CAPABILITIES and self.MODEL_CAPABILITIES[m].supports_extended_thinking
]
if pro_thinking:
return find_best(pro_thinking)
# Then any model that supports thinking
any_thinking = [
m
for m in allowed_models
if m in self.MODEL_CAPABILITIES and self.MODEL_CAPABILITIES[m].supports_extended_thinking
]
if any_thinking:
return find_best(any_thinking)
# Finally, just prefer Pro models even without thinking
pro_models = [m for m in allowed_models if "pro" in m]
if pro_models:
return find_best(pro_models)
elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer Flash models for speed
flash_models = [m for m in allowed_models if "flash" in m]
if flash_models:
return find_best(flash_models)
# Default for BALANCED or as fallback
# Prefer Flash for balanced use, then Pro, then anything
flash_models = [m for m in allowed_models if "flash" in m]
if flash_models:
return find_best(flash_models)
pro_models = [m for m in allowed_models if "pro" in m]
if pro_models:
return find_best(pro_models)
# Ultimate fallback to best available model
return find_best(allowed_models)
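The thinking-budget math in `get_thinking_budget` reduces to "percentage of the model's cap, floored to an int". A minimal standalone sketch using the same tables as the provider above (the helper name `thinking_budget` is illustrative, not part of the provider API):

```python
# Mirrors THINKING_BUDGETS and MAX_THINKING_TOKENS from the Gemini provider.
THINKING_BUDGETS = {"minimal": 0.005, "low": 0.08, "medium": 0.33, "high": 0.67, "max": 1.0}
MAX_THINKING_TOKENS = {
    "gemini-2.0-flash": 24576,
    "gemini-2.0-flash-lite": 0,  # no thinking support
    "gemini-2.5-flash": 24576,
    "gemini-2.5-pro": 32768,
}

def thinking_budget(model: str, mode: str) -> int:
    """Return the actual thinking-token budget for a model/mode pair (0 if unsupported)."""
    cap = MAX_THINKING_TOKENS.get(model, 0)
    pct = THINKING_BUDGETS.get(mode)
    if not cap or pct is None:
        return 0
    return int(cap * pct)
```

For example, `thinking_budget("gemini-2.5-pro", "medium")` is `int(32768 * 0.33)`, i.e. 10813 tokens.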

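Both providers share the same retry shape: up to four attempts with progressive 1/3/5/8-second delays, bailing out early on non-retryable errors. A minimal standalone sketch of that pattern (the function names here are hypothetical, not part of either provider; `sleep` is injectable for testing):

```python
import time

def call_with_retries(fn, is_retryable, delays=(1, 3, 5, 8), sleep=time.sleep):
    # Try fn() up to len(delays) times, sleeping a progressive delay
    # between retryable failures; non-retryable errors fail immediately.
    last_exc = None
    attempts = 0
    for attempt in range(len(delays)):
        attempts = attempt + 1
        try:
            return fn()
        except Exception as e:
            last_exc = e
            if attempt == len(delays) - 1 or not is_retryable(e):
                break
            sleep(delays[attempt])
    raise RuntimeError(
        f"failed after {attempts} attempt{'s' if attempts > 1 else ''}"
    ) from last_exc
```

Note the last delay in the table is never slept: the final attempt either succeeds or raises, matching the provider loops.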

@ -0,0 +1,826 @@
"""Base class for OpenAI-compatible API providers."""
import copy
import ipaddress
import logging
import os
import time
from typing import Optional
from urllib.parse import urlparse
from openai import OpenAI
from utils.image_utils import validate_image
from .base import ModelProvider
from .shared import (
ModelCapabilities,
ModelResponse,
ProviderType,
)
class OpenAICompatibleProvider(ModelProvider):
"""Shared implementation for OpenAI API lookalikes.
The class owns HTTP client configuration (timeouts, proxy hardening,
custom headers) and normalises the OpenAI SDK responses into
:class:`~providers.shared.ModelResponse`. Concrete subclasses only need to
provide capability metadata and any provider-specific request tweaks.
"""
DEFAULT_HEADERS = {}
FRIENDLY_NAME = "OpenAI Compatible"
def __init__(self, api_key: str, base_url: str = None, **kwargs):
"""Initialize the provider with API key and optional base URL.
Args:
api_key: API key for authentication
base_url: Base URL for the API endpoint
**kwargs: Additional configuration options including timeout
"""
super().__init__(api_key, **kwargs)
self._client = None
self.base_url = base_url
self.organization = kwargs.get("organization")
self.allowed_models = self._parse_allowed_models()
# Configure timeouts - especially important for custom/local endpoints
self.timeout_config = self._configure_timeouts(**kwargs)
# Validate base URL for security
if self.base_url:
self._validate_base_url()
# Warn if using external URL without authentication
if self.base_url and not self._is_localhost_url() and not api_key:
logging.warning(
f"Using external URL '{self.base_url}' without API key. "
"This may be insecure. Consider setting an API key for authentication."
)
def _ensure_model_allowed(
self,
capabilities: ModelCapabilities,
canonical_name: str,
requested_name: str,
) -> None:
"""Respect provider-specific allowlists before default restriction checks."""
super()._ensure_model_allowed(capabilities, canonical_name, requested_name)
if self.allowed_models is not None:
requested = requested_name.lower()
canonical = canonical_name.lower()
if requested not in self.allowed_models and canonical not in self.allowed_models:
raise ValueError(
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
)
def _parse_allowed_models(self) -> Optional[set[str]]:
"""Parse allowed models from environment variable.
Returns:
Set of allowed model names (lowercase) or None if not configured
"""
# Get provider-specific allowed models
provider_type = self.get_provider_type().value.upper()
env_var = f"{provider_type}_ALLOWED_MODELS"
models_str = os.getenv(env_var, "")
if models_str:
# Parse and normalize to lowercase for case-insensitive comparison
models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
if models:
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
return models
# Log info if no allow-list configured for proxy providers
if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
logging.info(
f"Model allow-list not configured for {self.FRIENDLY_NAME} - all models permitted. "
f"To restrict access, set {env_var} with comma-separated model names."
)
return None
def _configure_timeouts(self, **kwargs):
"""Configure timeout settings based on provider type and custom settings.
Custom URLs and local models often need longer timeouts due to:
- Network latency on local networks
- Extended thinking models taking longer to respond
- Local inference being slower than cloud APIs
Returns:
httpx.Timeout object with appropriate timeout settings
"""
import httpx
# Default timeouts - more generous for custom/local endpoints
default_connect = 30.0 # 30 seconds for connection (vs OpenAI's 5s)
default_read = 600.0 # 10 minutes for reading (same as OpenAI default)
default_write = 600.0 # 10 minutes for writing
default_pool = 600.0 # 10 minutes for pool
# For custom/local URLs, use even longer timeouts
if self.base_url and self._is_localhost_url():
default_connect = 60.0 # 1 minute for local connections
default_read = 1800.0 # 30 minutes for local models (extended thinking)
default_write = 1800.0 # 30 minutes for local models
default_pool = 1800.0 # 30 minutes for local models
logging.info(f"Using extended timeouts for local endpoint: {self.base_url}")
elif self.base_url:
default_connect = 45.0 # 45 seconds for custom remote endpoints
default_read = 900.0 # 15 minutes for custom remote endpoints
default_write = 900.0 # 15 minutes for custom remote endpoints
default_pool = 900.0 # 15 minutes for custom remote endpoints
logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")
# Allow overrides via kwargs or the CUSTOM_*_TIMEOUT environment variables
connect_timeout = kwargs.get("connect_timeout", float(os.getenv("CUSTOM_CONNECT_TIMEOUT", default_connect)))
read_timeout = kwargs.get("read_timeout", float(os.getenv("CUSTOM_READ_TIMEOUT", default_read)))
write_timeout = kwargs.get("write_timeout", float(os.getenv("CUSTOM_WRITE_TIMEOUT", default_write)))
pool_timeout = kwargs.get("pool_timeout", float(os.getenv("CUSTOM_POOL_TIMEOUT", default_pool)))
timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)
logging.debug(
f"Configured timeouts - Connect: {connect_timeout}s, Read: {read_timeout}s, "
f"Write: {write_timeout}s, Pool: {pool_timeout}s"
)
return timeout
def _is_localhost_url(self) -> bool:
"""Check if the base URL points to localhost or local network.
Returns:
True if URL is localhost or local network, False otherwise
"""
if not self.base_url:
return False
try:
parsed = urlparse(self.base_url)
hostname = parsed.hostname
# Check for common localhost patterns
if hostname in ["localhost", "127.0.0.1", "::1"]:
return True
# Check for private network ranges (local network)
if hostname:
try:
ip = ipaddress.ip_address(hostname)
return ip.is_private or ip.is_loopback
except ValueError:
# Not an IP address, might be a hostname
pass
return False
except Exception:
return False
def _validate_base_url(self) -> None:
"""Validate base URL for security (SSRF protection).
Raises:
ValueError: If URL is invalid or potentially unsafe
"""
if not self.base_url:
return
try:
parsed = urlparse(self.base_url)
# Check URL scheme - only allow http/https
if parsed.scheme not in ("http", "https"):
raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")
# Check hostname exists
if not parsed.hostname:
raise ValueError("URL must include a hostname")
# Check port is valid (if specified)
port = parsed.port
if port is not None and (port < 1 or port > 65535):
raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")
@property
def client(self):
"""Lazy initialization of OpenAI client with security checks and timeout configuration."""
if self._client is None:
import os
import httpx
# Temporarily disable proxy environment variables to prevent httpx from detecting them
original_env = {}
proxy_env_vars = ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]
for var in proxy_env_vars:
if var in os.environ:
original_env[var] = os.environ[var]
del os.environ[var]
try:
# Create a custom httpx client that explicitly avoids proxy parameters
timeout_config = (
self.timeout_config
if hasattr(self, "timeout_config") and self.timeout_config
else httpx.Timeout(30.0)
)
# Create httpx client with minimal config to avoid proxy conflicts
# Note: proxies parameter was removed in httpx 0.28.0
# Check for test transport injection
if hasattr(self, "_test_transport"):
# Use custom transport for testing (HTTP recording/replay)
http_client = httpx.Client(
transport=self._test_transport,
timeout=timeout_config,
follow_redirects=True,
)
else:
# Normal production client
http_client = httpx.Client(
timeout=timeout_config,
follow_redirects=True,
)
# Keep client initialization minimal to avoid proxy parameter conflicts
client_kwargs = {
"api_key": self.api_key,
"http_client": http_client,
}
if self.base_url:
client_kwargs["base_url"] = self.base_url
if self.organization:
client_kwargs["organization"] = self.organization
# Add default headers if any
if self.DEFAULT_HEADERS:
client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()
logging.debug(f"OpenAI client initialized with custom httpx client and timeout: {timeout_config}")
# Create OpenAI client with custom httpx client
self._client = OpenAI(**client_kwargs)
except Exception as e:
# If all else fails, try absolute minimal client without custom httpx
logging.warning(f"Failed to create client with custom httpx, falling back to minimal config: {e}")
try:
minimal_kwargs = {"api_key": self.api_key}
if self.base_url:
minimal_kwargs["base_url"] = self.base_url
self._client = OpenAI(**minimal_kwargs)
except Exception as fallback_error:
logging.error(f"Even minimal OpenAI client creation failed: {fallback_error}")
raise
finally:
# Restore original proxy environment variables
for var, value in original_env.items():
os.environ[var] = value
return self._client
def _sanitize_for_logging(self, params: dict) -> dict:
"""Sanitize sensitive data from parameters before logging.
Args:
params: Dictionary of API parameters
Returns:
dict: Sanitized copy of parameters safe for logging
"""
sanitized = copy.deepcopy(params)
# Sanitize messages content
if "input" in sanitized:
for msg in sanitized.get("input", []):
if isinstance(msg, dict) and "content" in msg:
for content_item in msg.get("content", []):
if isinstance(content_item, dict) and "text" in content_item:
# Truncate long text and add ellipsis
text = content_item["text"]
if len(text) > 100:
content_item["text"] = text[:100] + "... [truncated]"
# Remove any API keys that might be in headers/auth
sanitized.pop("api_key", None)
sanitized.pop("authorization", None)
return sanitized
def _safe_extract_output_text(self, response) -> str:
"""Safely extract output_text from o3-pro response with validation.
Args:
response: Response object from OpenAI SDK
Returns:
str: The output text content
Raises:
ValueError: If output_text is missing, None, or not a string
"""
logging.debug(f"Response object type: {type(response)}")
logging.debug(f"Response attributes: {dir(response)}")
if not hasattr(response, "output_text"):
raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")
content = response.output_text
logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})")
if content is None:
raise ValueError("o3-pro returned None for output_text")
if not isinstance(content, str):
raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}")
return content
def _generate_with_responses_endpoint(
self,
model_name: str,
messages: list,
temperature: float,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using the /v1/responses endpoint for o3-pro via OpenAI library."""
# Convert messages to the correct format for responses endpoint
input_messages = []
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
if role == "system":
# For o3-pro, system messages should be handled carefully to avoid policy violations
# Instead of prefixing with "System:", we'll include the system content naturally
input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]})
elif role == "user":
input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]})
elif role == "assistant":
input_messages.append({"role": "assistant", "content": [{"type": "output_text", "text": content}]})
# Prepare completion parameters for responses endpoint
# Based on OpenAI documentation, use nested reasoning object for responses endpoint
completion_params = {
"model": model_name,
"input": input_messages,
"reasoning": {"effort": "medium"}, # Use nested object for responses endpoint
"store": True,
}
# Add max tokens if specified (the responses endpoint uses max_output_tokens)
if max_output_tokens:
completion_params["max_output_tokens"] = max_output_tokens
# For responses endpoint, we only add parameters that are explicitly supported
# Remove unsupported chat completion parameters that may cause API errors
# Retry logic with progressive delays
max_retries = 4
retry_delays = [1, 3, 5, 8]
last_exception = None
actual_attempts = 0
for attempt in range(max_retries):
actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
try:
# Log sanitized payload for debugging
import json
sanitized_params = self._sanitize_for_logging(completion_params)
logging.info(
f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
)
# Use OpenAI client's responses endpoint
response = self.client.responses.create(**completion_params)
# Extract content from responses endpoint format
# Use validation helper to safely extract output_text
content = self._safe_extract_output_text(response)
# Try to extract usage information
usage = None
if hasattr(response, "usage"):
usage = self._extract_usage(response)
elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
# Safely extract token counts with None handling
input_tokens = getattr(response, "input_tokens", 0) or 0
output_tokens = getattr(response, "output_tokens", 0) or 0
usage = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
"total_tokens": input_tokens + output_tokens,
}
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"model": getattr(response, "model", model_name),
"id": getattr(response, "id", ""),
"created": getattr(response, "created_at", 0),
"endpoint": "responses",
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
if is_retryable and attempt < max_retries - 1:
delay = retry_delays[attempt]
logging.warning(
f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
else:
break
# If we get here, all retries failed
error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
images: Optional[list[str]] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using the OpenAI-compatible API.
Args:
prompt: User prompt to send to the model
model_name: Name of the model to use
system_prompt: Optional system prompt for model behavior
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
# Validate model name against allow-list
if not self.validate_model_name(model_name):
raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")
capabilities: Optional[ModelCapabilities]
try:
capabilities = self.get_capabilities(model_name)
except Exception as exc:
logging.debug(f"Falling back to generic capabilities for {model_name}: {exc}")
capabilities = None
# Get effective temperature for this model from capabilities when available
if capabilities:
effective_temperature = capabilities.get_effective_temperature(temperature)
if effective_temperature is not None and effective_temperature != temperature:
logging.debug(
f"Adjusting temperature from {temperature} to {effective_temperature} for model {model_name}"
)
else:
effective_temperature = temperature
# Only validate if temperature is not None (meaning the model supports it)
if effective_temperature is not None:
# Validate parameters with the effective temperature
self.validate_parameters(model_name, effective_temperature)
# Prepare messages
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# Prepare user message with text and potentially images
user_content = []
user_content.append({"type": "text", "text": prompt})
# Add images if provided and model supports vision
if images and capabilities and capabilities.supports_images:
for image_path in images:
try:
image_content = self._process_image(image_path)
if image_content:
user_content.append(image_content)
except Exception as e:
logging.warning(f"Failed to process image {image_path}: {e}")
# Continue with other images and text
continue
elif images and (not capabilities or not capabilities.supports_images):
logging.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")
# Add user message
if len(user_content) == 1:
# Only text content, use simple string format for compatibility
messages.append({"role": "user", "content": prompt})
else:
# Text + images, use content array format
messages.append({"role": "user", "content": user_content})
# Prepare completion parameters
completion_params = {
"model": model_name,
"messages": messages,
}
# Check model capabilities once to determine parameter support
resolved_model = self._resolve_model_name(model_name)
# Use the effective temperature we calculated earlier
supports_sampling = effective_temperature is not None
if supports_sampling:
completion_params["temperature"] = effective_temperature
# Add max tokens if specified and model supports it
# O3/O4 models that don't support temperature also don't support max_tokens
if max_output_tokens and supports_sampling:
completion_params["max_tokens"] = max_output_tokens
# Add any additional OpenAI-specific parameters
# Use capabilities to filter parameters for reasoning models
for key, value in kwargs.items():
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
# Reasoning models (those that don't support temperature) also don't support these parameters
if not supports_sampling and key in ["top_p", "frequency_penalty", "presence_penalty"]:
continue # Skip unsupported parameters for reasoning models
completion_params[key] = value
# Check if this is o3-pro and needs the responses endpoint
if resolved_model == "o3-pro":
# This model requires the /v1/responses endpoint
# If it fails, we should not fall back to chat/completions
return self._generate_with_responses_endpoint(
model_name=resolved_model,
messages=messages,
temperature=temperature,
max_output_tokens=max_output_tokens,
**kwargs,
)
# Retry logic with progressive delays
max_retries = 4 # Total of 4 attempts
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
last_exception = None
actual_attempts = 0
for attempt in range(max_retries):
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
try:
# Generate completion
response = self.client.chat.completions.create(**completion_params)
# Extract content and usage
content = response.choices[0].message.content
usage = self._extract_usage(response)
return ModelResponse(
content=content,
usage=usage,
model_name=model_name,
friendly_name=self.FRIENDLY_NAME,
provider=self.get_provider_type(),
metadata={
"finish_reason": response.choices[0].finish_reason,
"model": response.model, # Actual model used
"id": response.id,
"created": response.created,
},
)
except Exception as e:
last_exception = e
# Check if this is a retryable error using structured error codes
is_retryable = self._is_error_retryable(e)
# If this is the last attempt or not retryable, give up
if attempt == max_retries - 1 or not is_retryable:
break
# Get progressive delay
delay = retry_delays[attempt]
# Log retry attempt
logging.warning(
f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
)
time.sleep(delay)
# If we get here, all retries failed
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
logging.error(error_msg)
raise RuntimeError(error_msg) from last_exception
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
"""Validate model parameters.
For proxy providers, this may use generic capabilities.
Args:
model_name: Model to validate for
temperature: Temperature to validate
**kwargs: Additional parameters to validate
"""
try:
capabilities = self.get_capabilities(model_name)
# Check if we're using generic capabilities
if hasattr(capabilities, "_is_generic"):
logging.debug(
f"Using generic parameter validation for {model_name}. Actual model constraints may differ."
)
# Validate temperature using parent class method
super().validate_parameters(model_name, temperature, **kwargs)
except Exception as e:
# For proxy providers, we might not have accurate capabilities
# Log warning but don't fail
logging.warning(f"Parameter validation limited for {model_name}: {e}")
def _extract_usage(self, response) -> dict[str, int]:
"""Extract token usage from OpenAI response.
Args:
response: OpenAI API response object
Returns:
Dictionary with usage statistics
"""
usage = {}
if hasattr(response, "usage") and response.usage:
# Safely extract token counts with None handling
usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) or 0
usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) or 0
usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) or 0
return usage
def count_tokens(self, text: str, model_name: str) -> int:
"""Count tokens using OpenAI-compatible tokenizer tables when available."""
resolved_model = self._resolve_model_name(model_name)
try:
import tiktoken
try:
encoding = tiktoken.encoding_for_model(resolved_model)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(text))
except Exception as exc:  # tiktoken missing, or encoding lookup failed
logging.debug("tiktoken unavailable for %s: %s", resolved_model, exc)
return super().count_tokens(text, model_name)
def _is_error_retryable(self, error: Exception) -> bool:
"""Determine if an error should be retried based on structured error codes.
Uses OpenAI API error structure instead of text pattern matching for reliability.
Args:
error: Exception from OpenAI API call
Returns:
True if error should be retried, False otherwise
"""
error_str = str(error).lower()
# Check for 429 errors first - these need special handling
if "429" in error_str:
# Try to extract structured error information
error_type = None
error_code = None
# Parse structured error from OpenAI API response
# Format: "Error code: 429 - {'error': {'type': 'tokens', 'code': 'rate_limit_exceeded', ...}}"
try:
import ast
import json
import re
# Extract JSON part from error string using regex
# Look for pattern: {...} (from first { to last })
json_match = re.search(r"\{.*\}", str(error))
if json_match:
json_like_str = json_match.group(0)
# First try: parse as Python literal (handles single quotes safely)
try:
error_data = ast.literal_eval(json_like_str)
except (ValueError, SyntaxError):
# Fallback: try JSON parsing with simple quote replacement
# (for cases where it's already valid JSON or simple replacements work)
json_str = json_like_str.replace("'", '"')
error_data = json.loads(json_str)
if "error" in error_data:
error_info = error_data["error"]
error_type = error_info.get("type")
error_code = error_info.get("code")
except (json.JSONDecodeError, ValueError, SyntaxError, AttributeError):
# Fall back to checking hasattr for OpenAI SDK exception objects
if hasattr(error, "response") and hasattr(error.response, "json"):
try:
response_data = error.response.json()
if "error" in response_data:
error_info = response_data["error"]
error_type = error_info.get("type")
error_code = error_info.get("code")
except Exception:
pass
# Determine if 429 is retryable based on structured error codes
if error_type == "tokens":
# Token-related 429s are typically non-retryable (request too large)
logging.debug(f"Non-retryable 429: token-related error (type={error_type}, code={error_code})")
return False
elif error_code in ["invalid_request_error", "context_length_exceeded"]:
# These are permanent failures
logging.debug(f"Non-retryable 429: permanent failure (type={error_type}, code={error_code})")
return False
else:
# Other 429s (like requests per minute) are retryable
logging.debug(f"Retryable 429: rate limiting (type={error_type}, code={error_code})")
return True
# For non-429 errors, check if they're retryable
retryable_indicators = [
"timeout",
"connection",
"network",
"temporary",
"unavailable",
"retry",
"408", # Request timeout
"500", # Internal server error
"502", # Bad gateway
"503", # Service unavailable
"504", # Gateway timeout
"ssl", # SSL errors
"handshake", # Handshake failures
]
return any(indicator in error_str for indicator in retryable_indicators)
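The structured-parse fallback in `_is_error_retryable` (regex-extract the payload, try `ast.literal_eval`, then fall back to JSON with quote replacement) can be exercised on its own. This is a minimal sketch of the same idea, not the provider's actual helper:

```python
import ast
import json
import re


def extract_error_fields(error_text: str) -> tuple:
    """Pull (type, code) out of an OpenAI-style error string, if present."""
    match = re.search(r"\{.*\}", error_text)
    if not match:
        return None, None
    payload = match.group(0)
    try:
        # Python-literal parse handles the single-quoted dicts the SDK prints
        data = ast.literal_eval(payload)
    except (ValueError, SyntaxError):
        # Fallback for payloads that are already valid (or near-valid) JSON
        data = json.loads(payload.replace("'", '"'))
    info = data.get("error", {})
    return info.get("type"), info.get("code")


err = "Error code: 429 - {'error': {'type': 'tokens', 'code': 'rate_limit_exceeded'}}"
print(extract_error_fields(err))  # ('tokens', 'rate_limit_exceeded')
```

Preferring `ast.literal_eval` over naive quote replacement matters when the payload itself contains apostrophes (e.g. an error message like `"don't retry"`), which would break a blanket `'` → `"` substitution.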
def _process_image(self, image_path: str) -> Optional[dict]:
"""Process an image for OpenAI-compatible API."""
try:
if image_path.startswith("data:"):
# Validate the data URL
validate_image(image_path)
# Handle data URL: data:image/png;base64,iVBORw0...
return {"type": "image_url", "image_url": {"url": image_path}}
else:
# Use base class validation
image_bytes, mime_type = validate_image(image_path)
# Read and encode the image
import base64
image_data = base64.b64encode(image_bytes).decode()
logging.debug(f"Processing image '{image_path}' as MIME type '{mime_type}'")
# Create data URL for OpenAI API
data_url = f"data:{mime_type};base64,{image_data}"
return {"type": "image_url", "image_url": {"url": data_url}}
except ValueError as e:
logging.warning(str(e))
return None
except Exception as e:
logging.error(f"Error processing image {image_path}: {e}")
return None
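The retry loop in `generate_content` above (four attempts, progressive 1/3/5/8-second delays, gated by a retryability check) is a reusable pattern. A standalone sketch under those assumptions; the callable and delay tuple are illustrative, not part of the provider's API:

```python
import time


def call_with_retries(fn, is_retryable, delays=(1, 3, 5, 8)):
    """Run fn(), retrying retryable failures with progressive delays."""
    last_exc = None
    attempt = 0
    for attempt, delay in enumerate(delays, start=1):
        try:
            return fn()
        except Exception as exc:
            last_exc = exc
            # Give up on the final attempt or on non-retryable errors
            if attempt == len(delays) or not is_retryable(exc):
                break
            time.sleep(delay)
    raise RuntimeError(f"failed after {attempt} attempt(s)") from last_exc
```

As in the provider code, the last delay is never slept: the final failed attempt breaks out of the loop and raises with the original exception chained for debugging.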


@ -0,0 +1,296 @@
"""OpenAI model provider implementation."""
import logging
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .openai_compatible import OpenAICompatibleProvider
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
logger = logging.getLogger(__name__)
class OpenAIModelProvider(OpenAICompatibleProvider):
"""Implementation that talks to api.openai.com using rich model metadata.
In addition to the built-in catalogue, the provider can surface models
defined in ``conf/custom_models.json`` (for organisations running their own
OpenAI-compatible gateways) while still respecting restriction policies.
"""
# Model configurations using ModelCapabilities objects
MODEL_CAPABILITIES = {
"gpt-5": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-5",
friendly_name="OpenAI (GPT-5)",
context_window=400_000, # 400K tokens
max_output_tokens=128_000, # 128K max output tokens
supports_extended_thinking=True, # Supports reasoning tokens
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # GPT-5 supports vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=True, # Regular models accept temperature parameter
temperature_constraint=TemperatureConstraint.create("fixed"),
description="GPT-5 (400K context, 128K output) - Advanced model with reasoning support",
aliases=["gpt5"],
),
"gpt-5-mini": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-5-mini",
friendly_name="OpenAI (GPT-5-mini)",
context_window=400_000, # 400K tokens
max_output_tokens=128_000, # 128K max output tokens
supports_extended_thinking=True, # Supports reasoning tokens
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # GPT-5-mini supports vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("fixed"),
description="GPT-5-mini (400K context, 128K output) - Efficient variant with reasoning support",
aliases=["gpt5-mini", "gpt5mini", "mini"],
),
"gpt-5-nano": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-5-nano",
friendly_name="OpenAI (GPT-5 nano)",
context_window=400_000,
max_output_tokens=128_000,
supports_extended_thinking=True,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True,
max_image_size_mb=20.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("fixed"),
description="GPT-5 nano (400K context) - Fastest, cheapest version of GPT-5 for summarization and classification tasks",
aliases=["gpt5nano", "gpt5-nano", "nano"],
),
"o3": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="o3",
friendly_name="OpenAI (O3)",
context_window=200_000, # 200K tokens
max_output_tokens=65536, # 64K max output tokens
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # O3 models support vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=False, # O3 models don't accept temperature parameter
temperature_constraint=TemperatureConstraint.create("fixed"),
description="Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
aliases=[],
),
"o3-mini": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="o3-mini",
friendly_name="OpenAI (O3-mini)",
context_window=200_000, # 200K tokens
max_output_tokens=65536, # 64K max output tokens
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # O3 models support vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=False, # O3 models don't accept temperature parameter
temperature_constraint=TemperatureConstraint.create("fixed"),
description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
aliases=["o3mini"],
),
"o3-pro": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="o3-pro",
friendly_name="OpenAI (O3-Pro)",
context_window=200_000, # 200K tokens
max_output_tokens=65536, # 64K max output tokens
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # O3 models support vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=False, # O3 models don't accept temperature parameter
temperature_constraint=TemperatureConstraint.create("fixed"),
description="Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
aliases=["o3pro"],
),
"o4-mini": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="o4-mini",
friendly_name="OpenAI (O4-mini)",
context_window=200_000, # 200K tokens
max_output_tokens=65536, # 64K max output tokens
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # O4 models support vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=False, # O4 models don't accept temperature parameter
temperature_constraint=TemperatureConstraint.create("fixed"),
description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
aliases=["o4mini"],
),
"gpt-4.1": ModelCapabilities(
provider=ProviderType.OPENAI,
model_name="gpt-4.1",
friendly_name="OpenAI (GPT 4.1)",
context_window=1_000_000, # 1M tokens
max_output_tokens=32_768,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=True,
supports_images=True, # GPT-4.1 supports vision
max_image_size_mb=20.0, # 20MB per OpenAI docs
supports_temperature=True, # Regular models accept temperature parameter
temperature_constraint=TemperatureConstraint.create("range"),
description="GPT-4.1 (1M context) - Advanced reasoning model with large context window",
aliases=["gpt4.1"],
),
}
def __init__(self, api_key: str, **kwargs):
"""Initialize OpenAI provider with API key."""
# Set default OpenAI base URL, allow override for regions/custom endpoints
kwargs.setdefault("base_url", "https://api.openai.com/v1")
super().__init__(api_key, **kwargs)
# ------------------------------------------------------------------
# Capability surface
# ------------------------------------------------------------------
def _lookup_capabilities(
self,
canonical_name: str,
requested_name: Optional[str] = None,
) -> Optional[ModelCapabilities]:
"""Look up OpenAI capabilities from built-ins or the custom registry."""
builtin = super()._lookup_capabilities(canonical_name, requested_name)
if builtin is not None:
return builtin
try:
from .openrouter_registry import OpenRouterModelRegistry
registry = OpenRouterModelRegistry()
config = registry.get_model_config(canonical_name)
if config and config.provider == ProviderType.OPENAI:
return config
except Exception as exc: # pragma: no cover - registry failures are non-critical
logger.debug(f"Could not resolve custom OpenAI model '{canonical_name}': {exc}")
return None
def _finalise_capabilities(
self,
capabilities: ModelCapabilities,
canonical_name: str,
requested_name: str,
) -> ModelCapabilities:
"""Ensure registry-sourced models report the correct provider type."""
if capabilities.provider != ProviderType.OPENAI:
capabilities.provider = ProviderType.OPENAI
return capabilities
def _raise_unsupported_model(self, model_name: str) -> None:
raise ValueError(f"Unsupported OpenAI model: {model_name}")
# ------------------------------------------------------------------
# Provider identity
# ------------------------------------------------------------------
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.OPENAI
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using OpenAI API with proper model name resolution."""
# Resolve model alias before making API call
resolved_model_name = self._resolve_model_name(model_name)
# Call parent implementation with resolved model name
return super().generate_content(
prompt=prompt,
model_name=resolved_model_name,
system_prompt=system_prompt,
temperature=temperature,
max_output_tokens=max_output_tokens,
**kwargs,
)
# ------------------------------------------------------------------
# Provider preferences
# ------------------------------------------------------------------
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get OpenAI's preferred model for a given category from allowed models.
Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of models allowed by restrictions
Returns:
Preferred model name or None
"""
from tools.models import ToolModelCategory
if not allowed_models:
return None
# Helper to find first available from preference list
def find_first(preferences: list[str]) -> Optional[str]:
"""Return first available model from preference list."""
for model in preferences:
if model in allowed_models:
return model
return None
if category == ToolModelCategory.EXTENDED_REASONING:
# Prefer the strongest reasoning models for deep analysis
preferred = find_first(["o3", "o3-pro", "gpt-5"])
return preferred if preferred else allowed_models[0]
elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer fast, cost-efficient models
preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
return preferred if preferred else allowed_models[0]
else: # BALANCED or default
# Prefer balanced performance/cost models
preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
return preferred if preferred else allowed_models[0]
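Each branch of `get_preferred_model` reduces to the same operation: scan an ordered preference list against the allowed set, else fall back to the first allowed model. A standalone sketch of that scan (model names are taken from the lists above; the function names here are illustrative):

```python
from typing import Optional


def find_first(preferences: list[str], allowed: list[str]) -> Optional[str]:
    """Return the first preferred model that restrictions allow."""
    for model in preferences:
        if model in allowed:
            return model
    return None


def preferred_or_fallback(preferences: list[str], allowed: list[str]) -> Optional[str]:
    # Mirror the provider's fallback: any allowed model beats none at all
    if not allowed:
        return None
    return find_first(preferences, allowed) or allowed[0]


print(preferred_or_fallback(["o3", "o3-pro", "gpt-5"], ["gpt-5", "o3-mini"]))  # gpt-5
```

Because `allowed_models` is pre-filtered by the restriction service, the fallback to `allowed[0]` can never reintroduce a disallowed model.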

251
providers/openrouter.py Normal file

@ -0,0 +1,251 @@
"""OpenRouter provider implementation."""
import logging
import os
from typing import Optional
from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import (
ModelCapabilities,
ModelResponse,
ProviderType,
RangeTemperatureConstraint,
)
class OpenRouterProvider(OpenAICompatibleProvider):
"""Client for OpenRouter's multi-model aggregation service.
Role
Surface OpenRouter's dynamic catalogue through the same interface as
native providers so tools can reference OpenRouter models and aliases
without special cases.
Characteristics
* Pulls live model definitions from :class:`OpenRouterModelRegistry`
(aliases, provider-specific metadata, capability hints)
* Applies alias-aware restriction checks before exposing models to the
registry or tooling
* Reuses :class:`OpenAICompatibleProvider` infrastructure for request
execution so OpenRouter endpoints behave like standard OpenAI-style
APIs.
"""
FRIENDLY_NAME = "OpenRouter"
# Custom headers required by OpenRouter
DEFAULT_HEADERS = {
"HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"),
"X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"),
}
# Model registry for managing configurations and aliases
_registry: Optional[OpenRouterModelRegistry] = None
def __init__(self, api_key: str, **kwargs):
"""Initialize OpenRouter provider.
Args:
api_key: OpenRouter API key
**kwargs: Additional configuration
"""
base_url = "https://openrouter.ai/api/v1"
super().__init__(api_key, base_url=base_url, **kwargs)
# Initialize model registry
if OpenRouterProvider._registry is None:
OpenRouterProvider._registry = OpenRouterModelRegistry()
# Log loaded models and aliases only on first load
models = self._registry.list_models()
aliases = self._registry.list_aliases()
logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
# ------------------------------------------------------------------
# Capability surface
# ------------------------------------------------------------------
def _lookup_capabilities(
self,
canonical_name: str,
requested_name: Optional[str] = None,
) -> Optional[ModelCapabilities]:
"""Fetch OpenRouter capabilities from the registry or build a generic fallback."""
capabilities = self._registry.get_capabilities(canonical_name)
if capabilities:
return capabilities
base_identifier = canonical_name.split(":", 1)[0]
if "/" in base_identifier:
logging.debug(
"Using generic OpenRouter capabilities for %s (provider/model format detected)", canonical_name
)
generic = ModelCapabilities(
provider=ProviderType.OPENROUTER,
model_name=canonical_name,
friendly_name=self.FRIENDLY_NAME,
context_window=32_768,
max_output_tokens=32_768,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=False,
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
)
generic._is_generic = True
return generic
logging.debug(
"Rejecting unknown OpenRouter model '%s' (no provider prefix); requires explicit configuration",
canonical_name,
)
return None
# ------------------------------------------------------------------
# Provider identity
# ------------------------------------------------------------------
def get_provider_type(self) -> ProviderType:
"""Identify this provider for restrictions and logging."""
return ProviderType.OPENROUTER
# ------------------------------------------------------------------
# Request execution
# ------------------------------------------------------------------
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using the OpenRouter API.
Args:
prompt: User prompt to send to the model
model_name: Name of the model (or alias) to use
system_prompt: Optional system prompt for model behavior
temperature: Sampling temperature
max_output_tokens: Maximum tokens to generate
**kwargs: Additional provider-specific parameters
Returns:
ModelResponse with generated content and metadata
"""
# Resolve model alias to actual OpenRouter model name
resolved_model = self._resolve_model_name(model_name)
# Always disable streaming for OpenRouter
# MCP doesn't use streaming, and this avoids issues with O3 model access
if "stream" not in kwargs:
kwargs["stream"] = False
# Call parent method with resolved model name
return super().generate_content(
prompt=prompt,
model_name=resolved_model,
system_prompt=system_prompt,
temperature=temperature,
max_output_tokens=max_output_tokens,
**kwargs,
)
# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------
def list_models(
self,
*,
respect_restrictions: bool = True,
include_aliases: bool = True,
lowercase: bool = False,
unique: bool = False,
) -> list[str]:
"""Return formatted OpenRouter model names, respecting alias-aware restrictions."""
if not self._registry:
return []
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
allowed_configs: dict[str, ModelCapabilities] = {}
for model_name in self._registry.list_models():
config = self._registry.resolve(model_name)
if not config:
continue
# Custom models belong to CustomProvider; skip them here so the two
# providers don't race over the same registrations (important for tests
# that stub the registry with minimal objects lacking attrs).
if hasattr(config, "is_custom") and config.is_custom is True:
continue
if restriction_service:
allowed = restriction_service.is_allowed(self.get_provider_type(), model_name)
if not allowed and config.aliases:
for alias in config.aliases:
if restriction_service.is_allowed(self.get_provider_type(), alias):
allowed = True
break
if not allowed:
continue
allowed_configs[model_name] = config
if not allowed_configs:
return []
# When restrictions are in place, don't include aliases to avoid confusion
# Only return the canonical model names that are actually allowed
actual_include_aliases = include_aliases and not respect_restrictions
return ModelCapabilities.collect_model_names(
allowed_configs,
include_aliases=actual_include_aliases,
lowercase=lowercase,
unique=unique,
)
# ------------------------------------------------------------------
# Registry helpers
# ------------------------------------------------------------------
def _resolve_model_name(self, model_name: str) -> str:
"""Resolve aliases defined in the OpenRouter registry."""
config = self._registry.resolve(model_name)
if config:
if config.model_name != model_name:
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
return config.model_name
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
return model_name
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
"""Expose registry-backed OpenRouter capabilities."""
if not self._registry:
return {}
capabilities: dict[str, ModelCapabilities] = {}
for model_name in self._registry.list_models():
config = self._registry.resolve(model_name)
if not config:
continue
# See note in list_models: respect the CustomProvider boundary.
if hasattr(config, "is_custom") and config.is_custom is True:
continue
capabilities[model_name] = config
return capabilities
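The generic fallback in `_lookup_capabilities` hinges on a single string test: an unregistered name qualifies for generic capabilities only if the part before any `:variant` suffix contains a `provider/model` slash. A minimal sketch of that gate:

```python
def qualifies_for_generic(canonical_name: str) -> bool:
    """True when an unregistered name looks like OpenRouter's provider/model form."""
    # Strip an optional ":variant" suffix (e.g. ":nitro", ":free") first
    base_identifier = canonical_name.split(":", 1)[0]
    return "/" in base_identifier


print(qualifies_for_generic("mistralai/mistral-large:nitro"))  # True
print(qualifies_for_generic("some-unknown-alias"))  # False
```

This keeps bare aliases strict (they must be configured explicitly) while still letting users pass any fully qualified OpenRouter identifier without editing `conf/custom_models.json`.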


@ -0,0 +1,292 @@
"""OpenRouter model registry for managing model configurations and aliases."""
import importlib.resources
import logging
import os
from pathlib import Path
from typing import Optional
# Configuration is read via importlib.resources.files() calls below
from utils.file_utils import read_json_file
from .shared import (
ModelCapabilities,
ProviderType,
TemperatureConstraint,
)
class OpenRouterModelRegistry:
"""In-memory view of OpenRouter and custom model metadata.
Role
Parse the packaged ``conf/custom_models.json`` (or user-specified
overrides), construct alias and capability maps, and serve those
structures to providers that rely on OpenRouter semantics (both the
OpenRouter provider itself and the Custom provider).
Key duties
* Load :class:`ModelCapabilities` definitions from configuration files
* Maintain a case-insensitive alias canonical name map for fast
resolution
* Provide helpers to list models, list aliases, and resolve an arbitrary
name to its capability object without repeatedly touching the file
system.
"""
def __init__(self, config_path: Optional[str] = None):
"""Initialize the registry.
Args:
config_path: Path to config file. If None, uses default locations.
"""
self.alias_map: dict[str, str] = {} # alias -> model_name
self.model_map: dict[str, ModelCapabilities] = {} # model_name -> config
# Determine config path and loading strategy
self.use_resources = False
if config_path:
# Direct config_path parameter
self.config_path = Path(config_path)
else:
# Check environment variable first
env_path = os.getenv("CUSTOM_MODELS_CONFIG_PATH")
if env_path:
# Environment variable path
self.config_path = Path(env_path)
else:
# Try importlib.resources for robust packaging support
self.config_path = None
try:
resource_traversable = importlib.resources.files("conf").joinpath("custom_models.json")
if hasattr(resource_traversable, "read_text"):
self.use_resources = True
else:
raise AttributeError("read_text not available")
except Exception:
pass
if not self.use_resources:
# Fallback to file system paths
potential_paths = [
Path(__file__).parent.parent / "conf" / "custom_models.json",
Path.cwd() / "conf" / "custom_models.json",
]
for path in potential_paths:
if path.exists():
self.config_path = path
break
if self.config_path is None:
self.config_path = potential_paths[0]
# Load configuration
self.reload()
def reload(self) -> None:
"""Reload configuration from disk."""
try:
configs = self._read_config()
self._build_maps(configs)
caller_info = ""
try:
import inspect
caller_frame = inspect.currentframe().f_back
if caller_frame:
caller_name = caller_frame.f_code.co_name
caller_file = (
caller_frame.f_code.co_filename.split("/")[-1] if caller_frame.f_code.co_filename else "unknown"
)
# Look for tool context
while caller_frame:
frame_locals = caller_frame.f_locals
if "self" in frame_locals and hasattr(frame_locals["self"], "get_name"):
tool_name = frame_locals["self"].get_name()
caller_info = f" (called from {tool_name} tool)"
break
caller_frame = caller_frame.f_back
if not caller_info:
caller_info = f" (called from {caller_name} in {caller_file})"
except Exception:
# If frame inspection fails, just continue without caller info
pass
logging.debug(
f"Loaded {len(self.model_map)} OpenRouter models with {len(self.alias_map)} aliases{caller_info}"
)
except ValueError as e:
# Re-raise ValueError only for duplicate aliases (critical config errors)
logging.error(f"Failed to load OpenRouter model configuration: {e}")
# Initialize with empty maps on failure
self.alias_map = {}
self.model_map = {}
if "Duplicate alias" in str(e):
raise
except Exception as e:
logging.error(f"Failed to load OpenRouter model configuration: {e}")
# Initialize with empty maps on failure
self.alias_map = {}
self.model_map = {}
def _read_config(self) -> list[ModelCapabilities]:
"""Read configuration from file or package resources.
Returns:
List of model configurations
"""
try:
if self.use_resources:
# Use importlib.resources for packaged environments
try:
resource_path = importlib.resources.files("conf").joinpath("custom_models.json")
if hasattr(resource_path, "read_text"):
# Python 3.9+
config_text = resource_path.read_text(encoding="utf-8")
else:
# Python 3.8 fallback
with resource_path.open("r", encoding="utf-8") as f:
config_text = f.read()
import json
data = json.loads(config_text)
logging.debug("Loaded OpenRouter config from package resources")
except Exception as e:
logging.warning(f"Failed to load config from resources: {e}")
return []
else:
# Use file path loading
if not self.config_path.exists():
logging.warning(f"OpenRouter model config not found at {self.config_path}")
return []
# Use centralized JSON reading utility
data = read_json_file(str(self.config_path))
logging.debug(f"Loaded OpenRouter config from file: {self.config_path}")
if data is None:
location = "resources" if self.use_resources else str(self.config_path)
raise ValueError(f"Could not read or parse JSON from {location}")
# Parse models
configs = []
for model_data in data.get("models", []):
# Create ModelCapabilities directly from JSON data
# Handle temperature_constraint conversion
temp_constraint_str = model_data.get("temperature_constraint")
temp_constraint = TemperatureConstraint.create(temp_constraint_str or "range")
# Set provider-specific defaults based on is_custom flag
is_custom = model_data.get("is_custom", False)
if is_custom:
model_data.setdefault("provider", ProviderType.CUSTOM)
model_data.setdefault("friendly_name", f"Custom ({model_data.get('model_name', 'Unknown')})")
else:
model_data.setdefault("provider", ProviderType.OPENROUTER)
model_data.setdefault("friendly_name", f"OpenRouter ({model_data.get('model_name', 'Unknown')})")
# Replace the string form with the constructed TemperatureConstraint object
model_data["temperature_constraint"] = temp_constraint
config = ModelCapabilities(**model_data)
configs.append(config)
return configs
except ValueError:
# Re-raise ValueError for specific config errors
raise
except Exception as e:
location = "resources" if self.use_resources else str(self.config_path)
raise ValueError(f"Error reading config from {location}: {e}") from e
def _build_maps(self, configs: list[ModelCapabilities]) -> None:
"""Build alias and model maps from configurations.
Args:
configs: List of model configurations
"""
alias_map = {}
model_map = {}
for config in configs:
# Add to model map
model_map[config.model_name] = config
# Add the model_name itself as an alias for case-insensitive lookup
# But only if it's not already in the aliases list
model_name_lower = config.model_name.lower()
aliases_lower = [alias.lower() for alias in config.aliases]
if model_name_lower not in aliases_lower:
if model_name_lower in alias_map:
existing_model = alias_map[model_name_lower]
if existing_model != config.model_name:
raise ValueError(
f"Duplicate model name '{config.model_name}' (case-insensitive) found for models "
f"'{existing_model}' and '{config.model_name}'"
)
else:
alias_map[model_name_lower] = config.model_name
# Add aliases
for alias in config.aliases:
alias_lower = alias.lower()
if alias_lower in alias_map:
existing_model = alias_map[alias_lower]
raise ValueError(
f"Duplicate alias '{alias}' found for models '{existing_model}' and '{config.model_name}'"
)
alias_map[alias_lower] = config.model_name
# Atomic update
self.alias_map = alias_map
self.model_map = model_map
def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]:
"""Resolve a model name or alias to configuration.
Args:
name_or_alias: Model name or alias to resolve
Returns:
Model configuration if found, None otherwise
"""
# Try alias lookup (case-insensitive) - this now includes model names too
alias_lower = name_or_alias.lower()
if alias_lower in self.alias_map:
model_name = self.alias_map[alias_lower]
return self.model_map.get(model_name)
return None
def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
"""Get model capabilities for a name or alias.
Args:
name_or_alias: Model name or alias
Returns:
ModelCapabilities if found, None otherwise
"""
# Registry now returns ModelCapabilities directly
return self.resolve(name_or_alias)
def get_model_config(self, name_or_alias: str) -> Optional[ModelCapabilities]:
"""Backward-compatible wrapper used by providers and older tests."""
return self.resolve(name_or_alias)
def list_models(self) -> list[str]:
"""List all available model names."""
return list(self.model_map.keys())
def list_aliases(self) -> list[str]:
"""List all available aliases."""
return list(self.alias_map.keys())
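The case-insensitive alias bookkeeping performed by `_build_maps` and `resolve` above can be sketched standalone. This is an illustrative reduction working over a plain name-to-aliases dict rather than `ModelCapabilities` objects; `build_alias_map` is a hypothetical helper name, not part of the codebase:

```python
def build_alias_map(models: dict[str, list[str]]) -> dict[str, str]:
    """Map lowercase names and aliases to canonical model names, rejecting collisions."""
    alias_map: dict[str, str] = {}
    for model_name, aliases in models.items():
        # The canonical name itself becomes a lookup key, alongside its aliases
        for name in [model_name] + aliases:
            key = name.lower()
            existing = alias_map.get(key)
            if existing is not None and existing != model_name:
                raise ValueError(
                    f"Duplicate alias '{name}' found for models '{existing}' and '{model_name}'"
                )
            alias_map[key] = model_name
    return alias_map


# Resolution is then a single case-insensitive dict lookup, as in resolve():
aliases = build_alias_map({"google/gemini-2.5-pro": ["pro", "gemini-pro"]})
```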

397
providers/registry.py Normal file
View file

@ -0,0 +1,397 @@
"""Model provider registry for managing available providers."""
import logging
import os
from typing import TYPE_CHECKING, Optional
from .base import ModelProvider
from .shared import ProviderType
if TYPE_CHECKING:
from tools.models import ToolModelCategory
class ModelProviderRegistry:
"""Central catalogue of provider implementations used by the MCP server.
Role
Holds the mapping between :class:`ProviderType` values and concrete
:class:`ModelProvider` subclasses/factories. At runtime the registry
is responsible for instantiating providers, caching them for reuse, and
mediating lookup of providers and model names in provider priority
order.
Core responsibilities
* Resolve API keys and other runtime configuration for each provider
* Lazily create provider instances so unused backends incur no cost
* Expose convenience methods for enumerating available models and
locating which provider can service a requested model name or alias
* Honour the project-wide provider priority policy so namespaces (or
alias collisions) are resolved deterministically.
"""
_instance = None
# Provider priority order for model selection
# Native APIs first, then custom endpoints, then catch-all providers
PROVIDER_PRIORITY_ORDER = [
ProviderType.GOOGLE, # Direct Gemini access
ProviderType.OPENAI, # Direct OpenAI access
ProviderType.XAI, # Direct X.AI GROK access
ProviderType.DIAL, # DIAL unified API access
ProviderType.CUSTOM, # Local/self-hosted models
ProviderType.OPENROUTER, # Catch-all for cloud models
]
def __new__(cls):
"""Singleton pattern for registry."""
if cls._instance is None:
logging.debug("REGISTRY: Creating new registry instance")
cls._instance = super().__new__(cls)
# Initialize instance dictionaries on first creation
cls._instance._providers = {}
cls._instance._initialized_providers = {}
logging.debug(f"REGISTRY: Created instance {cls._instance}")
return cls._instance
@classmethod
def register_provider(cls, provider_type: ProviderType, provider_class: type[ModelProvider]) -> None:
"""Register a new provider class.
Args:
provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
provider_class: Class that implements ModelProvider interface
"""
instance = cls()
instance._providers[provider_type] = provider_class
# Invalidate any cached instance so subsequent lookups use the new registration
instance._initialized_providers.pop(provider_type, None)
@classmethod
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
"""Get an initialized provider instance.
Args:
provider_type: Type of provider to get
force_new: Force creation of new instance instead of using cached
Returns:
Initialized ModelProvider instance or None if not available
"""
instance = cls()
# Return cached instance if available and not forcing new
if not force_new and provider_type in instance._initialized_providers:
return instance._initialized_providers[provider_type]
# Check if provider class is registered
if provider_type not in instance._providers:
return None
# Get API key from environment
api_key = cls._get_api_key_for_provider(provider_type)
# Get provider class or factory function
provider_class = instance._providers[provider_type]
# For custom providers, handle special initialization requirements
if provider_type == ProviderType.CUSTOM:
# Check if it's a factory function (callable but not a class)
if callable(provider_class) and not isinstance(provider_class, type):
# Factory function - call it with api_key parameter
provider = provider_class(api_key=api_key)
else:
# Regular class - need to handle URL requirement
custom_url = os.getenv("CUSTOM_API_URL", "")
if not custom_url:
if api_key: # Key is set but URL is missing
logging.warning("CUSTOM_API_KEY is set but CUSTOM_API_URL is missing; skipping Custom provider")
return None
# Use empty string as API key for custom providers that don't need auth (e.g., Ollama)
# This allows the provider to be created even without CUSTOM_API_KEY being set
api_key = api_key or ""
# Initialize custom provider with both API key and base URL
provider = provider_class(api_key=api_key, base_url=custom_url)
elif provider_type == ProviderType.GOOGLE:
# For Gemini, check if custom base URL is configured
if not api_key:
return None
gemini_base_url = os.getenv("GEMINI_BASE_URL")
provider_kwargs = {"api_key": api_key}
if gemini_base_url:
provider_kwargs["base_url"] = gemini_base_url
logging.info(f"Initialized Gemini provider with custom endpoint: {gemini_base_url}")
provider = provider_class(**provider_kwargs)
else:
if not api_key:
return None
# Initialize non-custom provider with just API key
provider = provider_class(api_key=api_key)
# Cache the instance
instance._initialized_providers[provider_type] = provider
return provider
@classmethod
def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
"""Get provider instance for a specific model name.
Provider priority order:
1. Native APIs (GOOGLE, OPENAI) - Most direct and efficient
2. CUSTOM - For local/private models with specific endpoints
3. OPENROUTER - Catch-all for cloud models via unified API
Args:
model_name: Name of the model (e.g., "gemini-2.5-flash", "gpt5")
Returns:
ModelProvider instance that supports this model
"""
logging.debug(f"get_provider_for_model called with model_name='{model_name}'")
# Check providers in priority order
instance = cls()
logging.debug(f"Registry instance: {instance}")
logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")
for provider_type in cls.PROVIDER_PRIORITY_ORDER:
if provider_type in instance._providers:
logging.debug(f"Found {provider_type} in registry")
# Get or create provider instance
provider = cls.get_provider(provider_type)
if provider and provider.validate_model_name(model_name):
logging.debug(f"{provider_type} validates model {model_name}")
return provider
else:
logging.debug(f"{provider_type} does not validate model {model_name}")
else:
logging.debug(f"{provider_type} not found in registry")
logging.debug(f"No provider found for model {model_name}")
return None
@classmethod
def get_available_providers(cls) -> list[ProviderType]:
"""Get list of registered provider types."""
instance = cls()
return list(instance._providers.keys())
@classmethod
def get_available_models(cls, respect_restrictions: bool = True) -> dict[str, ProviderType]:
"""Get mapping of all available models to their providers.
Args:
respect_restrictions: If True, filter out models not allowed by restrictions
Returns:
Dict mapping model names to provider types
"""
# Import here to avoid circular imports
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service() if respect_restrictions else None
models: dict[str, ProviderType] = {}
instance = cls()
for provider_type in instance._providers:
provider = cls.get_provider(provider_type)
if not provider:
continue
try:
available = provider.list_models(respect_restrictions=respect_restrictions)
except NotImplementedError:
logging.warning("Provider %s does not implement list_models", provider_type)
continue
for model_name in available:
# =====================================================================================
# CRITICAL: Prevent double restriction filtering (Fixed Issue #98)
# =====================================================================================
# Previously, both the provider AND registry applied restrictions, causing
# double-filtering that resulted in "no models available" errors.
#
# Logic: If respect_restrictions=True, provider already filtered models,
# so registry should NOT filter them again.
# TEST COVERAGE: tests/test_provider_routing_bugs.py::TestOpenRouterAliasRestrictions
# =====================================================================================
if (
restriction_service
and not respect_restrictions # Only filter if provider didn't already filter
and not restriction_service.is_allowed(provider_type, model_name)
):
logging.debug("Model %s filtered by restrictions", model_name)
continue
models[model_name] = provider_type
return models
@classmethod
def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]:
"""Get list of available model names, optionally filtered by provider.
This respects model restrictions automatically.
Args:
provider_type: Optional provider to filter by
Returns:
List of available model names
"""
available_models = cls.get_available_models(respect_restrictions=True)
if provider_type:
# Filter by specific provider
return [name for name, ptype in available_models.items() if ptype == provider_type]
else:
# Return all available models
return list(available_models.keys())
@classmethod
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
"""Get API key for a provider from environment variables.
Args:
provider_type: Provider type to get API key for
Returns:
API key string or None if not found
"""
key_mapping = {
ProviderType.GOOGLE: "GEMINI_API_KEY",
ProviderType.OPENAI: "OPENAI_API_KEY",
ProviderType.XAI: "XAI_API_KEY",
ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
ProviderType.CUSTOM: "CUSTOM_API_KEY", # Can be empty for providers that don't need auth
ProviderType.DIAL: "DIAL_API_KEY",
}
env_var = key_mapping.get(provider_type)
if not env_var:
return None
return os.getenv(env_var)
@classmethod
def _get_allowed_models_for_provider(cls, provider: ModelProvider, provider_type: ProviderType) -> list[str]:
"""Get a list of allowed canonical model names for a given provider.
Args:
provider: The provider instance to get models for
provider_type: The provider type for restriction checking
Returns:
List of model names that are both supported and allowed
"""
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
allowed_models = []
# Get the provider's supported models
try:
# Use list_models to get all supported models (handles both regular and custom providers)
supported_models = provider.list_models(respect_restrictions=False)
except (NotImplementedError, AttributeError):
# Fallback to provider-declared capability maps if list_models not implemented
model_map = getattr(provider, "MODEL_CAPABILITIES", None)
supported_models = list(model_map.keys()) if isinstance(model_map, dict) else []
# Filter by restrictions
for model_name in supported_models:
if restriction_service.is_allowed(provider_type, model_name):
allowed_models.append(model_name)
return allowed_models
@classmethod
def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str:
"""Get the preferred fallback model based on provider priority and tool category.
This method orchestrates model selection by:
1. Getting allowed models for each provider (respecting restrictions)
2. Asking providers for their preference from the allowed list
3. Falling back to first available model if no preference given
Args:
tool_category: Optional category to influence model selection
Returns:
Model name string for fallback use
"""
from tools.models import ToolModelCategory
effective_category = tool_category or ToolModelCategory.BALANCED
first_available_model = None
# Ask each provider for their preference in priority order
for provider_type in cls.PROVIDER_PRIORITY_ORDER:
provider = cls.get_provider(provider_type)
if provider:
# 1. Registry filters the models first
allowed_models = cls._get_allowed_models_for_provider(provider, provider_type)
if not allowed_models:
continue
# 2. Keep track of the first available model as fallback
if not first_available_model:
first_available_model = sorted(allowed_models)[0]
# 3. Ask provider to pick from allowed list
preferred_model = provider.get_preferred_model(effective_category, allowed_models)
if preferred_model:
logging.debug(
f"Provider {provider_type.value} selected '{preferred_model}' for category '{effective_category.value}'"
)
return preferred_model
# If no provider returned a preference, use first available model
if first_available_model:
logging.debug(f"No provider preference, using first available: {first_available_model}")
return first_available_model
# Ultimate fallback if no providers have models
logging.warning("No models available from any provider, using default fallback")
return "gemini-2.5-flash"
@classmethod
def get_available_providers_with_keys(cls) -> list[ProviderType]:
"""Get list of provider types that have valid API keys.
Returns:
List of ProviderType values for providers with valid API keys
"""
available = []
instance = cls()
for provider_type in instance._providers:
if cls.get_provider(provider_type) is not None:
available.append(provider_type)
return available
@classmethod
def clear_cache(cls) -> None:
"""Clear cached provider instances."""
instance = cls()
instance._initialized_providers.clear()
@classmethod
def reset_for_testing(cls) -> None:
"""Reset the registry to a clean state for testing.
This provides a safe, public API for tests to clean up registry state
without directly manipulating private attributes.
"""
cls._instance = None
if hasattr(cls, "_providers"):
cls._providers = {}
@classmethod
def unregister_provider(cls, provider_type: ProviderType) -> None:
"""Unregister a provider (mainly for testing)."""
instance = cls()
instance._providers.pop(provider_type, None)
instance._initialized_providers.pop(provider_type, None)
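The priority-ordered lookup that `get_provider_for_model` performs can be reduced to a minimal sketch. String provider names and set-based catalogs stand in for the real provider instances and their `validate_model_name` checks:

```python
from typing import Optional

# Provider priority, highest first: native APIs, then custom endpoints, then catch-alls,
# mirroring PROVIDER_PRIORITY_ORDER above.
PRIORITY = ["google", "openai", "xai", "dial", "custom", "openrouter"]


def provider_for_model(model: str, catalogs: dict) -> Optional[str]:
    """Return the first provider, in priority order, whose catalog lists the model."""
    for provider in PRIORITY:
        if model in catalogs.get(provider, ()):
            return provider
    return None
```

Because native providers are checked first, a model served by both Gemini directly and OpenRouter deterministically resolves to the native backend.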

View file

@ -0,0 +1,21 @@
"""Shared data structures and helpers for model providers."""
from .model_capabilities import ModelCapabilities
from .model_response import ModelResponse
from .provider_type import ProviderType
from .temperature import (
DiscreteTemperatureConstraint,
FixedTemperatureConstraint,
RangeTemperatureConstraint,
TemperatureConstraint,
)
__all__ = [
"ModelCapabilities",
"ModelResponse",
"ProviderType",
"TemperatureConstraint",
"FixedTemperatureConstraint",
"RangeTemperatureConstraint",
"DiscreteTemperatureConstraint",
]

View file

@ -0,0 +1,122 @@
"""Dataclass describing the feature set of a model exposed by a provider."""
from dataclasses import dataclass, field
from typing import Optional
from .provider_type import ProviderType
from .temperature import RangeTemperatureConstraint, TemperatureConstraint
__all__ = ["ModelCapabilities"]
@dataclass
class ModelCapabilities:
"""Static description of what a model can do within a provider.
Role
Acts as the canonical record for everything the server needs to know
about a model: its provider, token limits, feature switches, aliases,
and temperature rules. Providers populate these objects so tools and
higher-level services can rely on a consistent schema.
Typical usage
* Provider subclasses declare `MODEL_CAPABILITIES` maps containing these
objects (for example ``OpenAIModelProvider``)
* Helper utilities (e.g. restriction validation, alias expansion) read
these objects to build model lists for tooling and policy enforcement
* Tool selection logic inspects attributes such as
``supports_extended_thinking`` or ``context_window`` to choose an
appropriate model for a task.
"""
provider: ProviderType
model_name: str
friendly_name: str
description: str = ""
aliases: list[str] = field(default_factory=list)
# Capacity limits / resource budgets
context_window: int = 0
max_output_tokens: int = 0
max_thinking_tokens: int = 0
# Capability flags
supports_extended_thinking: bool = False
supports_system_prompts: bool = True
supports_streaming: bool = True
supports_function_calling: bool = False
supports_images: bool = False
supports_json_mode: bool = False
supports_temperature: bool = True
# Additional attributes
max_image_size_mb: float = 0.0
is_custom: bool = False
temperature_constraint: TemperatureConstraint = field(
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
)
def get_effective_temperature(self, requested_temperature: float) -> Optional[float]:
"""Return the temperature that should be sent to the provider.
Models that do not support temperature return ``None`` so that callers
can omit the parameter entirely. For supported models, the configured
constraint clamps the requested value into a provider-safe range.
"""
if not self.supports_temperature:
return None
return self.temperature_constraint.get_corrected_value(requested_temperature)
@staticmethod
def collect_aliases(model_configs: dict[str, "ModelCapabilities"]) -> dict[str, list[str]]:
"""Build a mapping of model name to aliases from capability configs."""
return {
base_model: capabilities.aliases
for base_model, capabilities in model_configs.items()
if capabilities.aliases
}
@staticmethod
def collect_model_names(
model_configs: dict[str, "ModelCapabilities"],
*,
include_aliases: bool = True,
lowercase: bool = False,
unique: bool = False,
) -> list[str]:
"""Build an ordered list of model names and aliases.
Args:
model_configs: Mapping of canonical model names to capabilities.
include_aliases: When True, include aliases for each model.
lowercase: When True, normalize names to lowercase.
unique: When True, ensure each returned name appears once (after formatting).
Returns:
Ordered list of model names (and optionally aliases) formatted per options.
"""
formatted_names: list[str] = []
seen: set[str] | None = set() if unique else None
def append_name(name: str) -> None:
formatted = name.lower() if lowercase else name
if seen is not None:
if formatted in seen:
return
seen.add(formatted)
formatted_names.append(formatted)
for base_model, capabilities in model_configs.items():
append_name(base_model)
if include_aliases and capabilities.aliases:
for alias in capabilities.aliases:
append_name(alias)
return formatted_names
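The ordering, deduplication, and formatting behaviour of `collect_model_names` above can be exercised with a standalone mirror that takes a plain name-to-aliases mapping instead of `ModelCapabilities` objects (an illustrative sketch, not the production helper):

```python
def collect_model_names(model_aliases: dict[str, list[str]], *, include_aliases: bool = True,
                        lowercase: bool = False, unique: bool = False) -> list[str]:
    """Ordered list of model names and aliases, optionally lowercased and deduplicated."""
    names: list[str] = []
    seen: set[str] = set()
    for model, aliases in model_aliases.items():
        for name in [model] + (aliases if include_aliases else []):
            formatted = name.lower() if lowercase else name
            if unique:
                if formatted in seen:
                    continue  # skip names already emitted (after formatting)
                seen.add(formatted)
            names.append(formatted)
    return names
```

Note that deduplication happens after formatting, so `lowercase=True` can merge names that differ only by case.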

View file

@ -0,0 +1,26 @@
"""Dataclass used to normalise provider SDK responses."""
from dataclasses import dataclass, field
from typing import Any
from .provider_type import ProviderType
__all__ = ["ModelResponse"]
@dataclass
class ModelResponse:
"""Portable representation of a provider completion."""
content: str
usage: dict[str, int] = field(default_factory=dict)
model_name: str = ""
friendly_name: str = ""
provider: ProviderType = ProviderType.GOOGLE
metadata: dict[str, Any] = field(default_factory=dict)
@property
def total_tokens(self) -> int:
"""Return the total token count if the provider reported usage data."""
return self.usage.get("total_tokens", 0)
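The usage accounting above degrades gracefully when a provider reports no token counts. A minimal stand-in for `ModelResponse` illustrates this (hypothetical `Response` name for the sketch):

```python
from dataclasses import dataclass, field


@dataclass
class Response:  # minimal stand-in for ModelResponse
    content: str
    usage: dict = field(default_factory=dict)

    @property
    def total_tokens(self) -> int:
        # Missing usage data yields 0 rather than raising KeyError
        return self.usage.get("total_tokens", 0)
```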

View file

@ -0,0 +1,16 @@
"""Enumeration describing which backend owns a given model."""
from enum import Enum
__all__ = ["ProviderType"]
class ProviderType(Enum):
"""Canonical identifiers for every supported provider backend."""
GOOGLE = "google"
OPENAI = "openai"
XAI = "xai"
OPENROUTER = "openrouter"
CUSTOM = "custom"
DIAL = "dial"

View file

@ -0,0 +1,188 @@
"""Helper types for validating model temperature parameters."""
from abc import ABC, abstractmethod
from typing import Optional
__all__ = [
"TemperatureConstraint",
"FixedTemperatureConstraint",
"RangeTemperatureConstraint",
"DiscreteTemperatureConstraint",
]
# Common heuristics for determining temperature support when explicit
# capabilities are unavailable (e.g., custom/local models).
_TEMP_UNSUPPORTED_PATTERNS = {
"o1",
"o3",
"o4", # OpenAI O-series reasoning models
"deepseek-reasoner",
"deepseek-r1",
"r1", # DeepSeek reasoner variants
}
_TEMP_UNSUPPORTED_KEYWORDS = {
"reasoner", # Catch additional DeepSeek-style naming patterns
}
class TemperatureConstraint(ABC):
"""Contract for temperature validation used by `ModelCapabilities`.
Concrete providers describe their temperature behaviour by creating
subclasses that expose three operations:
* `validate`: decide whether a requested temperature is acceptable.
* `get_corrected_value`: coerce out-of-range values into a safe default.
* `get_description`: provide a human-readable error message for users.
Providers call these hooks before sending traffic to the underlying API so
that unsupported temperatures never reach the remote service.
"""
@abstractmethod
def validate(self, temperature: float) -> bool:
"""Return ``True`` when the temperature may be sent to the backend."""
@abstractmethod
def get_corrected_value(self, temperature: float) -> float:
"""Return a valid substitute for an out-of-range temperature."""
@abstractmethod
def get_description(self) -> str:
"""Describe the acceptable range to include in error messages."""
@abstractmethod
def get_default(self) -> float:
"""Return the default temperature for the model."""
@staticmethod
def infer_support(model_name: str) -> tuple[bool, str]:
"""Heuristically determine whether a model supports temperature."""
model_lower = model_name.lower()
for pattern in _TEMP_UNSUPPORTED_PATTERNS:
conditions = (
pattern == model_lower,
model_lower.startswith(f"{pattern}-"),
model_lower.startswith(f"openai/{pattern}"),
model_lower.startswith(f"deepseek/{pattern}"),
model_lower.endswith(f"-{pattern}"),
f"/{pattern}" in model_lower,
f"-{pattern}-" in model_lower,
)
if any(conditions):
return False, f"detected pattern '{pattern}'"
for keyword in _TEMP_UNSUPPORTED_KEYWORDS:
if keyword in model_lower:
return False, f"detected keyword '{keyword}'"
return True, "default assumption for models without explicit metadata"
@staticmethod
def resolve_settings(
model_name: str,
constraint_hint: Optional[str] = None,
) -> tuple[bool, "TemperatureConstraint", str]:
"""Derive temperature support and constraint for a model.
Args:
model_name: Canonical model identifier or alias.
constraint_hint: Optional configuration hint (``"fixed"``,
``"range"``, ``"discrete"``). When provided, the hint is
honoured directly.
Returns:
Tuple ``(supports_temperature, constraint, diagnosis)`` describing
whether temperature may be tuned, the constraint object that should
be attached to :class:`ModelCapabilities`, and the reasoning behind
the decision.
"""
if constraint_hint:
constraint = TemperatureConstraint.create(constraint_hint)
supports_temperature = constraint_hint != "fixed"
reason = f"constraint hint '{constraint_hint}'"
return supports_temperature, constraint, reason
supports_temperature, reason = TemperatureConstraint.infer_support(model_name)
if supports_temperature:
constraint: TemperatureConstraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
else:
constraint = FixedTemperatureConstraint(1.0)
return supports_temperature, constraint, reason
@staticmethod
def create(constraint_type: str) -> "TemperatureConstraint":
"""Factory that yields the appropriate constraint for a configuration hint."""
if constraint_type == "fixed":
# Fixed temperature models (O3/O4) only support temperature=1.0
return FixedTemperatureConstraint(1.0)
if constraint_type == "discrete":
# For models with specific allowed values - using common OpenAI values as default
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
# Default range constraint (for "range" or None)
return RangeTemperatureConstraint(0.0, 2.0, 0.3)
class FixedTemperatureConstraint(TemperatureConstraint):
"""Constraint for models that enforce an exact temperature (for example O3)."""
def __init__(self, value: float):
self.value = value
def validate(self, temperature: float) -> bool:
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
def get_corrected_value(self, temperature: float) -> float:
return self.value
def get_description(self) -> str:
return f"Only supports temperature={self.value}"
def get_default(self) -> float:
return self.value
class RangeTemperatureConstraint(TemperatureConstraint):
"""Constraint for providers that expose a continuous min/max temperature range."""
def __init__(self, min_temp: float, max_temp: float, default: Optional[float] = None):
self.min_temp = min_temp
self.max_temp = max_temp
self.default_temp = default if default is not None else (min_temp + max_temp) / 2
def validate(self, temperature: float) -> bool:
return self.min_temp <= temperature <= self.max_temp
def get_corrected_value(self, temperature: float) -> float:
return max(self.min_temp, min(self.max_temp, temperature))
def get_description(self) -> str:
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
def get_default(self) -> float:
return self.default_temp
class DiscreteTemperatureConstraint(TemperatureConstraint):
"""Constraint for models that permit a discrete list of temperature values."""
def __init__(self, allowed_values: list[float], default: Optional[float] = None):
self.allowed_values = sorted(allowed_values)
self.default_temp = default if default is not None else self.allowed_values[len(self.allowed_values) // 2]
def validate(self, temperature: float) -> bool:
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
def get_corrected_value(self, temperature: float) -> float:
return min(self.allowed_values, key=lambda x: abs(x - temperature))
def get_description(self) -> str:
return f"Supports temperatures: {self.allowed_values}"
def get_default(self) -> float:
return self.default_temp
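The `get_corrected_value` behaviour of the three constraint classes above can be condensed into standalone functions (an illustrative sketch mirroring their semantics, not the production classes):

```python
def clamp_range(t: float, lo: float = 0.0, hi: float = 2.0) -> float:
    """RangeTemperatureConstraint: clamp the request into [lo, hi]."""
    return max(lo, min(hi, t))


def snap_discrete(t: float, allowed: list) -> float:
    """DiscreteTemperatureConstraint: snap to the nearest allowed value."""
    return min(allowed, key=lambda v: abs(v - t))


def fixed(t: float, value: float = 1.0) -> float:
    """FixedTemperatureConstraint: the fixed value wins regardless of the request."""
    return value
```

This is why out-of-range temperatures never reach the remote API: every request passes through exactly one of these corrections first.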

157
providers/xai.py Normal file
View file

@ -0,0 +1,157 @@
"""X.AI (GROK) model provider implementation."""
import logging
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from .openai_compatible import OpenAICompatibleProvider
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
logger = logging.getLogger(__name__)
class XAIModelProvider(OpenAICompatibleProvider):
"""Integration for X.AI's GROK models exposed over an OpenAI-style API.
Publishes capability metadata for the officially supported deployments and
maps tool-category preferences to the appropriate GROK model.
"""
FRIENDLY_NAME = "X.AI"
# Model configurations using ModelCapabilities objects
MODEL_CAPABILITIES = {
"grok-4": ModelCapabilities(
provider=ProviderType.XAI,
model_name="grok-4",
friendly_name="X.AI (Grok 4)",
context_window=256_000, # 256K tokens
max_output_tokens=256_000, # 256K tokens max output
supports_extended_thinking=True, # Grok-4 supports reasoning mode
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True, # Function calling supported
supports_json_mode=True, # Structured outputs supported
supports_images=True, # Multimodal capabilities
max_image_size_mb=20.0, # Standard image size limit
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="GROK-4 (256K context) - Frontier multimodal reasoning model with advanced capabilities",
aliases=["grok", "grok4", "grok-4"],
),
"grok-3": ModelCapabilities(
provider=ProviderType.XAI,
model_name="grok-3",
friendly_name="X.AI (Grok 3)",
context_window=131_072, # 131K tokens
max_output_tokens=131_072,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=False, # Assuming GROK doesn't have JSON mode yet
supports_images=False, # Assuming GROK is text-only for now
max_image_size_mb=0.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
aliases=["grok3"],
),
"grok-3-fast": ModelCapabilities(
provider=ProviderType.XAI,
model_name="grok-3-fast",
friendly_name="X.AI (Grok 3 Fast)",
context_window=131_072, # 131K tokens
max_output_tokens=131_072,
supports_extended_thinking=False,
supports_system_prompts=True,
supports_streaming=True,
supports_function_calling=True,
supports_json_mode=False, # Assuming GROK doesn't have JSON mode yet
supports_images=False, # Assuming GROK is text-only for now
max_image_size_mb=0.0,
supports_temperature=True,
temperature_constraint=TemperatureConstraint.create("range"),
description="GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
aliases=["grok3fast", "grokfast", "grok3-fast"],
),
}
def __init__(self, api_key: str, **kwargs):
"""Initialize X.AI provider with API key."""
# Set X.AI base URL
kwargs.setdefault("base_url", "https://api.x.ai/v1")
super().__init__(api_key, **kwargs)
def get_provider_type(self) -> ProviderType:
"""Get the provider type."""
return ProviderType.XAI
def generate_content(
self,
prompt: str,
model_name: str,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
max_output_tokens: Optional[int] = None,
**kwargs,
) -> ModelResponse:
"""Generate content using X.AI API with proper model name resolution."""
# Resolve model alias before making API call
resolved_model_name = self._resolve_model_name(model_name)
# Call parent implementation with resolved model name
return super().generate_content(
prompt=prompt,
model_name=resolved_model_name,
system_prompt=system_prompt,
temperature=temperature,
max_output_tokens=max_output_tokens,
**kwargs,
)
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
"""Get XAI's preferred model for a given category from allowed models.
Args:
category: The tool category requiring a model
allowed_models: Pre-filtered list of models allowed by restrictions
Returns:
Preferred model name or None
"""
from tools.models import ToolModelCategory
if not allowed_models:
return None
if category == ToolModelCategory.EXTENDED_REASONING:
# Prefer GROK-4 for advanced reasoning with thinking mode
if "grok-4" in allowed_models:
return "grok-4"
elif "grok-3" in allowed_models:
return "grok-3"
# Fall back to any available model
return allowed_models[0]
elif category == ToolModelCategory.FAST_RESPONSE:
# Prefer GROK-3-Fast for speed, then GROK-4
if "grok-3-fast" in allowed_models:
return "grok-3-fast"
elif "grok-4" in allowed_models:
return "grok-4"
# Fall back to any available model
return allowed_models[0]
else: # BALANCED or default
# Prefer GROK-4 for balanced use (best overall capabilities)
if "grok-4" in allowed_models:
return "grok-4"
elif "grok-3" in allowed_models:
return "grok-3"
elif "grok-3-fast" in allowed_models:
return "grok-3-fast"
# Fall back to any available model
return allowed_models[0]
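The category-based selection in `get_preferred_model` above reduces to an ordered-preference scan. A compressed sketch using string categories in place of `ToolModelCategory` (the preference orders below mirror, but slightly compress, the branching above):

```python
from typing import Optional


def preferred_grok(category: str, allowed: list) -> Optional[str]:
    """Pick a GROK model for a tool category from a pre-filtered allowed list."""
    if not allowed:
        return None
    if category == "fast_response":
        preferences = ["grok-3-fast", "grok-4"]            # speed first
    elif category == "extended_reasoning":
        preferences = ["grok-4", "grok-3"]                 # thinking-capable first
    else:  # balanced / default
        preferences = ["grok-4", "grok-3", "grok-3-fast"]
    for model in preferences:
        if model in allowed:
            return model
    return allowed[0]  # fall back to any available model
```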

11
requirements.txt Normal file
View file

@ -0,0 +1,11 @@
mcp>=1.0.0
google-genai>=1.19.0
openai>=1.55.2 # Minimum version for httpx 0.28.0 compatibility
pydantic>=2.0.0
python-dotenv>=1.0.0
importlib-resources>=5.0.0; python_version<"3.9"
# Development dependencies (install with pip install -r requirements-dev.txt)
# pytest>=7.4.0
# pytest-asyncio>=0.21.0
# pytest-mock>=3.11.0

66
run-server.sh Executable file
View file

@ -0,0 +1,66 @@
#!/bin/bash
# Zen-Marketing MCP Server Setup and Run Script
set -e
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
cd "$SCRIPT_DIR"
echo "=================================="
echo "Zen-Marketing MCP Server Setup"
echo "=================================="
# Check Python version
echo "Checking Python installation..."
if ! command -v python3 &> /dev/null; then
echo "Error: Python 3 is required but not installed"
exit 1
fi
PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
echo "Found Python $PYTHON_VERSION"
# Create virtual environment if it doesn't exist
if [ ! -d ".venv" ]; then
echo "Creating virtual environment..."
python3 -m venv .venv
else
echo "Virtual environment already exists"
fi
# Activate virtual environment
echo "Activating virtual environment..."
source .venv/bin/activate
# Upgrade pip
echo "Upgrading pip..."
pip install --upgrade pip --quiet
# Install requirements
echo "Installing requirements..."
pip install -r requirements.txt --quiet
# Check for .env file
if [ ! -f ".env" ]; then
echo ""
echo "WARNING: No .env file found!"
echo "Please create a .env file based on .env.example:"
echo " cp .env.example .env"
echo " # Then edit .env with your API keys"
echo ""
read -p "Continue anyway? (y/N) " -n 1 -r
echo
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
exit 1
fi
fi
# Run the server
echo ""
echo "=================================="
echo "Starting Zen-Marketing MCP Server"
echo "=================================="
echo ""
python server.py

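Once the server runs locally, Claude Desktop needs an entry in its `claude_desktop_config.json`. A sketch that emits such an entry; the paths are placeholders for your checkout, and the exact schema should be confirmed against the Claude Desktop MCP documentation:

```python
import json

# Hypothetical paths -- substitute the absolute location of your checkout.
config = {
    "mcpServers": {
        "zen-marketing": {
            "command": "/absolute/path/to/zen-marketing/.venv/bin/python",
            "args": ["/absolute/path/to/zen-marketing/server.py"],
        }
    }
}
print(json.dumps(config, indent=2))
```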
352
server.py Normal file
View file

@ -0,0 +1,352 @@
"""
Zen-Marketing MCP Server - Main server implementation
AI-powered marketing tools for Claude Desktop, specialized for technical B2B content
creation, multi-platform campaigns, and content variation testing.
Based on Zen MCP Server architecture by Fahad Gilani.
"""
import asyncio
import atexit
import logging
import os
import sys
import time
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Any
# Load environment variables from .env file
try:
from dotenv import load_dotenv
script_dir = Path(__file__).parent
env_file = script_dir / ".env"
load_dotenv(dotenv_path=env_file, override=False)
except ImportError:
pass
from mcp.server import Server
from mcp.server.models import InitializationOptions
from mcp.server.stdio import stdio_server
from mcp.types import (
GetPromptResult,
Prompt,
PromptMessage,
PromptsCapability,
ServerCapabilities,
TextContent,
Tool,
ToolsCapability,
)
from config import DEFAULT_MODEL, __version__
from tools.chat import ChatTool
from tools.contentvariant import ContentVariantTool
from tools.listmodels import ListModelsTool
from tools.models import ToolOutput
from tools.version import VersionTool
# Configure logging
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
class LocalTimeFormatter(logging.Formatter):
def formatTime(self, record, datefmt=None):
ct = self.converter(record.created)
if datefmt:
s = time.strftime(datefmt, ct)
else:
t = time.strftime("%Y-%m-%d %H:%M:%S", ct)
s = f"{t},{record.msecs:03.0f}"
return s
# Configure logging
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
root_logger = logging.getLogger()
root_logger.handlers.clear()
stderr_handler = logging.StreamHandler(sys.stderr)
stderr_handler.setLevel(getattr(logging, log_level, logging.INFO))
stderr_handler.setFormatter(LocalTimeFormatter(log_format))
root_logger.addHandler(stderr_handler)
root_logger.setLevel(getattr(logging, log_level, logging.INFO))
# Add rotating file handler
try:
log_dir = Path(__file__).parent / "logs"
log_dir.mkdir(exist_ok=True)
file_handler = RotatingFileHandler(
log_dir / "mcp_server.log",
maxBytes=20 * 1024 * 1024,
backupCount=5,
encoding="utf-8",
)
file_handler.setLevel(getattr(logging, log_level, logging.INFO))
file_handler.setFormatter(LocalTimeFormatter(log_format))
logging.getLogger().addHandler(file_handler)
mcp_logger = logging.getLogger("mcp_activity")
mcp_file_handler = RotatingFileHandler(
log_dir / "mcp_activity.log",
maxBytes=10 * 1024 * 1024,
backupCount=2,
encoding="utf-8",
)
mcp_file_handler.setLevel(logging.INFO)
mcp_file_handler.setFormatter(LocalTimeFormatter("%(asctime)s - %(message)s"))
mcp_logger.addHandler(mcp_file_handler)
mcp_logger.setLevel(logging.INFO)
mcp_logger.propagate = True
logging.info(f"Logging to: {log_dir / 'mcp_server.log'}")
logging.info(f"Process PID: {os.getpid()}")
except Exception as e:
print(f"Warning: Could not set up file logging: {e}", file=sys.stderr)
logger = logging.getLogger(__name__)
# Create MCP server instance
server: Server = Server("zen-marketing")
# Essential tools that cannot be disabled
ESSENTIAL_TOOLS = {"version", "listmodels"}
def parse_disabled_tools_env() -> set[str]:
"""Parse DISABLED_TOOLS environment variable"""
disabled_tools_env = os.getenv("DISABLED_TOOLS", "").strip()
if not disabled_tools_env:
return set()
return {t.strip().lower() for t in disabled_tools_env.split(",") if t.strip()}
def filter_disabled_tools(all_tools: dict[str, Any]) -> dict[str, Any]:
"""Filter tools based on DISABLED_TOOLS environment variable"""
disabled_tools = parse_disabled_tools_env()
if not disabled_tools:
logger.info("All tools enabled")
return all_tools
enabled_tools = {}
for tool_name, tool_instance in all_tools.items():
if tool_name in ESSENTIAL_TOOLS or tool_name not in disabled_tools:
enabled_tools[tool_name] = tool_instance
else:
logger.info(f"Tool '{tool_name}' disabled via DISABLED_TOOLS")
logger.info(f"Active tools: {sorted(enabled_tools.keys())}")
return enabled_tools
# Initialize tool registry
TOOLS = {
"chat": ChatTool(),
"contentvariant": ContentVariantTool(),
"listmodels": ListModelsTool(),
"version": VersionTool(),
}
TOOLS = filter_disabled_tools(TOOLS)
# Prompt templates
PROMPT_TEMPLATES = {
"chat": {
"name": "chat",
"description": "Chat and brainstorm marketing ideas",
"template": "Chat about marketing strategy with {model}",
},
"contentvariant": {
"name": "variations",
"description": "Generate content variations for A/B testing",
"template": "Generate content variations for testing with {model}",
},
"listmodels": {
"name": "listmodels",
"description": "List available AI models",
"template": "List all available models",
},
"version": {
"name": "version",
"description": "Show server version",
"template": "Show Zen-Marketing server version",
},
}
def configure_providers():
"""Configure and validate AI providers"""
logger.debug("Checking environment variables for API keys...")
from providers import ModelProviderRegistry
from providers.gemini import GeminiModelProvider
from providers.openrouter import OpenRouterProvider
from providers.shared import ProviderType
    valid_providers = []

    # Check for Gemini API key (ignore the .env.example placeholder)
    gemini_key = os.getenv("GEMINI_API_KEY")
    if gemini_key and gemini_key != "your-gemini-api-key-here":
        valid_providers.append("Gemini")
        logger.info("Gemini API key found - Gemini models available")
        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

    # Check for OpenRouter API key (ignore the .env.example placeholder)
    openrouter_key = os.getenv("OPENROUTER_API_KEY")
    if openrouter_key and openrouter_key != "your-openrouter-api-key-here":
        valid_providers.append("OpenRouter")
        logger.info("OpenRouter API key found - Multiple models available")
        ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
if not valid_providers:
error_msg = (
"No valid API keys found. Please configure at least one:\n"
" - GEMINI_API_KEY for Google Gemini\n"
" - OPENROUTER_API_KEY for OpenRouter (minimax, etc.)\n"
"Create a .env file based on .env.example"
)
logger.error(error_msg)
raise ValueError(error_msg)
logger.info(f"Configured providers: {', '.join(valid_providers)}")
logger.info(f"Default model: {DEFAULT_MODEL}")
# Tool call handlers
@server.list_tools()
async def handle_list_tools() -> list[Tool]:
"""List available tools"""
tools = []
for tool_name, tool_instance in TOOLS.items():
tool_schema = tool_instance.get_input_schema()
tools.append(
Tool(
name=tool_name,
description=tool_instance.get_description(),
inputSchema=tool_schema,
)
)
return tools
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict) -> list[TextContent]:
"""Handle tool execution requests"""
logger.info(f"Tool call: {name}")
if name not in TOOLS:
error_msg = f"Unknown tool: {name}"
logger.error(error_msg)
return [TextContent(type="text", text=error_msg)]
    try:
        tool = TOOLS[name]
        result = await tool.execute(arguments)
        # Tools may return either a ToolOutput or a ready-made list of
        # TextContent blocks (utility tools such as listmodels do the latter).
        if isinstance(result, list):
            logger.info(f"Tool {name} completed successfully")
            return result
        if result.is_error:
            logger.error(f"Tool {name} failed: {result.text}")
        else:
            logger.info(f"Tool {name} completed successfully")
        return [TextContent(type="text", text=result.text)]
    except Exception as e:
        error_msg = f"Tool execution failed: {str(e)}"
        logger.exception(error_msg)
        return [TextContent(type="text", text=error_msg)]
@server.list_prompts()
async def handle_list_prompts() -> list[Prompt]:
"""List available prompt templates"""
prompts = []
for tool_name, template_info in PROMPT_TEMPLATES.items():
if tool_name in TOOLS:
prompts.append(
Prompt(
name=template_info["name"],
description=template_info["description"],
arguments=[
{"name": "model", "description": "AI model to use", "required": False}
],
)
)
return prompts
@server.get_prompt()
async def handle_get_prompt(name: str, arguments: dict | None = None) -> GetPromptResult:
"""Get a specific prompt template"""
model = arguments.get("model", DEFAULT_MODEL) if arguments else DEFAULT_MODEL
for template_info in PROMPT_TEMPLATES.values():
if template_info["name"] == name:
prompt_text = template_info["template"].format(model=model)
return GetPromptResult(
description=template_info["description"],
messages=[PromptMessage(role="user", content=TextContent(type="text", text=prompt_text))],
)
raise ValueError(f"Unknown prompt: {name}")
def cleanup():
"""Cleanup function called on shutdown"""
logger.info("Zen-Marketing MCP Server shutting down")
logging.shutdown()
async def main():
"""Main entry point"""
logger.info("=" * 80)
logger.info(f"Zen-Marketing MCP Server v{__version__}")
logger.info(f"Python: {sys.version}")
logger.info(f"Working directory: {os.getcwd()}")
logger.info("=" * 80)
# Register cleanup
atexit.register(cleanup)
# Configure providers
try:
configure_providers()
except ValueError as e:
logger.error(f"Configuration error: {e}")
sys.exit(1)
# List enabled tools
logger.info(f"Enabled tools: {sorted(TOOLS.keys())}")
# Run server
capabilities = ServerCapabilities(
tools=ToolsCapability(),
prompts=PromptsCapability(),
)
options = InitializationOptions(
server_name="zen-marketing",
server_version=__version__,
capabilities=capabilities,
)
async with stdio_server() as (read_stream, write_stream):
logger.info("Server started, awaiting requests...")
await server.run(
read_stream,
write_stream,
options,
)
if __name__ == "__main__":
asyncio.run(main())

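The `DISABLED_TOOLS` handling in server.py reduces to two small pure functions; a standalone mirror of that logic, showing that essential tools survive even when explicitly disabled:

```python
ESSENTIAL_TOOLS = {"version", "listmodels"}

def parse_disabled(raw: str) -> set[str]:
    """Parse a comma-separated DISABLED_TOOLS value into a lowercase set."""
    return {t.strip().lower() for t in raw.split(",") if t.strip()}

def filter_tools(tools: dict, disabled: set[str]) -> dict:
    # Essential tools are kept even if listed in DISABLED_TOOLS.
    return {name: t for name, t in tools.items() if name in ESSENTIAL_TOOLS or name not in disabled}

tools = {"chat": ..., "contentvariant": ..., "listmodels": ..., "version": ...}
disabled = parse_disabled("contentvariant, version")
print(sorted(filter_tools(tools, disabled)))  # ['chat', 'listmodels', 'version']
```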
6
systemprompts/__init__.py Normal file
View file

@ -0,0 +1,6 @@
"""System prompts for Zen-Marketing MCP Server tools"""
from .chat_prompt import CHAT_PROMPT
from .contentvariant_prompt import CONTENTVARIANT_PROMPT
__all__ = ["CHAT_PROMPT", "CONTENTVARIANT_PROMPT"]

29
systemprompts/chat_prompt.py Normal file
View file

@ -0,0 +1,29 @@
"""System prompt for the chat tool"""
CHAT_PROMPT = """You are a marketing strategist and content expert specializing in technical B2B marketing.
Your expertise includes:
- Content strategy and variation testing
- Multi-platform content adaptation (LinkedIn, Twitter/X, Instagram, newsletters, blogs)
- Email marketing and subject line optimization
- SEO and WordPress content optimization
- Technical content editing while preserving author voice
- Brand voice consistency and style enforcement
- Campaign planning across multiple touchpoints
- Fact-checking and technical verification
COMMUNICATION STYLE:
- Professional but approachable
- Data-informed recommendations
- Concrete, actionable suggestions
- Respect for brand voice and author expertise
- Platform-aware (character limits, best practices)
When discussing content:
- Focus on testing hypotheses and variation strategies
- Consider psychological hooks and audience resonance
- Balance creativity with conversion optimization
- Maintain technical accuracy in specialized domains
If web search is needed for current best practices, platform updates, or fact verification, indicate this clearly.
"""

62
systemprompts/contentvariant_prompt.py Normal file
View file

@ -0,0 +1,62 @@
"""System prompt for the contentvariant tool"""
CONTENTVARIANT_PROMPT = """You are a marketing content strategist specializing in A/B testing and variation generation.
TASK: Generate multiple variations of marketing content for testing different approaches.
OUTPUT FORMAT:
Return variations as a numbered list, each with:
1. The variation text
2. The testing angle (what makes it different)
3. Predicted audience response or reasoning
CONSTRAINTS:
- Maintain core message across variations
- Respect platform character limits if specified
- Preserve brand voice characteristics provided
- Generate genuinely different approaches, not just word swaps
- Each variation should test a distinct hypothesis
VARIATION TYPES:
- **Hook variations**: Different opening angles (question, statement, stat, story)
- **Length variations**: Short, medium, long within platform constraints
- **Tone variations**: Professional, conversational, urgent, educational, provocative
- **Structure variations**: Question format, list format, narrative, before-after
- **CTA variations**: Different calls-to-action (learn, try, join, download)
- **Psychological angles**: Curiosity, FOMO, social proof, contrarian, pain-solution
TESTING ANGLES (use when specified):
- **Technical curiosity**: Lead with interesting technical detail
- **Contrarian/provocative**: Challenge conventional wisdom
- **Knowledge gap**: Emphasize what they don't know yet
- **Urgency/timeliness**: Time-sensitive opportunity or threat
- **Insider knowledge**: Position as exclusive expertise
- **Problem-solution**: Lead with pain point
- **Social proof**: Leverage credibility or results
- **Before-after**: Transformation narrative
PLATFORM CONSIDERATIONS:
- Twitter/X: 280 chars (Bluesky: 300), punchy hooks, visual language
- LinkedIn: 1300 chars optimal, professional tone, business value
- Instagram: 2200 chars, storytelling, visual companion
- Email subject: Under 60 chars, curiosity-driven
- Blog/article: Longer form, educational depth
CHARACTER COUNT:
Always include character count for each variation when platform is specified.
EXAMPLE OUTPUT FORMAT:
**Variation 1: Technical Curiosity Hook**
"Most HVAC techs miss this voltage regulation pattern—here's what the top 10% know about PCB diagnostics that changes everything."
(149 characters)
*Testing angle: Opens with intriguing exclusivity (what others miss) then promises insider knowledge*
*Predicted response: Appeals to practitioners wanting competitive edge*
**Variation 2: Contrarian Statement**
"Stop blaming the capacitor. 80% of 'bad cap' calls are actually voltage regulation failures upstream. Here's the diagnostic most techs skip."
(142 characters)
*Testing angle: Challenges common diagnosis, positions as expert correction*
*Predicted response: Stops scroll with unexpected claim, appeals to problem-solvers*
Be creative, test bold hypotheses, and make variations substantially different from each other.
"""

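The per-platform limits stated in the prompt can be enforced mechanically before variants are returned. A sketch; the limit values come from the prompt text above and should be treated as assumptions, since platforms change them:

```python
# Character limits as listed in CONTENTVARIANT_PROMPT (blog is unbounded).
PLATFORM_LIMITS = {
    "twitter": 280,
    "bluesky": 300,
    "linkedin": 1300,
    "instagram": 2200,
    "email_subject": 60,
}

def fits_platform(text: str, platform: str) -> bool:
    """True if text fits the platform's limit (unknown platforms are unbounded)."""
    limit = PLATFORM_LIMITS.get(platform)
    return limit is None or len(text) <= limit

variant = "Stop blaming the capacitor."
print(f"({len(variant)} characters, twitter ok: {fits_platform(variant, 'twitter')})")
```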
39
tools/__init__.py Normal file
View file

@ -0,0 +1,39 @@
"""
Tool implementations for Zen MCP Server
"""
from .analyze import AnalyzeTool
from .challenge import ChallengeTool
from .chat import ChatTool
from .codereview import CodeReviewTool
from .consensus import ConsensusTool
from .debug import DebugIssueTool
from .docgen import DocgenTool
from .listmodels import ListModelsTool
from .planner import PlannerTool
from .precommit import PrecommitTool
from .refactor import RefactorTool
from .secaudit import SecauditTool
from .testgen import TestGenTool
from .thinkdeep import ThinkDeepTool
from .tracer import TracerTool
from .version import VersionTool
__all__ = [
"ThinkDeepTool",
"CodeReviewTool",
"DebugIssueTool",
"DocgenTool",
"AnalyzeTool",
"ChatTool",
"ConsensusTool",
"ListModelsTool",
"PlannerTool",
"PrecommitTool",
"ChallengeTool",
"RefactorTool",
"SecauditTool",
"TestGenTool",
"TracerTool",
"VersionTool",
]

189
tools/chat.py Normal file
View file

@ -0,0 +1,189 @@
"""
Chat tool - General development chat and collaborative thinking
This tool provides a conversational interface for general development assistance,
brainstorming, problem-solving, and collaborative thinking. It supports file context,
images, and conversation continuation for seamless multi-turn interactions.
"""
from typing import TYPE_CHECKING, Any, Optional
from pydantic import Field
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from config import TEMPERATURE_BALANCED
from systemprompts import CHAT_PROMPT
from tools.shared.base_models import COMMON_FIELD_DESCRIPTIONS, ToolRequest
from .simple.base import SimpleTool
# Field descriptions matching the original Chat tool exactly
CHAT_FIELD_DESCRIPTIONS = {
"prompt": (
"Your question or idea for collaborative thinking. Provide detailed context, including your goal, what you've tried, and any specific challenges. "
"CRITICAL: To discuss code, use 'files' parameter instead of pasting code blocks here."
),
"files": "absolute file or folder paths for code context (do NOT shorten).",
"images": "Optional absolute image paths or base64 for visual context when helpful.",
}
class ChatRequest(ToolRequest):
"""Request model for Chat tool"""
prompt: str = Field(..., description=CHAT_FIELD_DESCRIPTIONS["prompt"])
files: Optional[list[str]] = Field(default_factory=list, description=CHAT_FIELD_DESCRIPTIONS["files"])
images: Optional[list[str]] = Field(default_factory=list, description=CHAT_FIELD_DESCRIPTIONS["images"])
class ChatTool(SimpleTool):
"""
    Marketing chat and collaborative thinking tool using SimpleTool architecture.
This tool provides identical functionality to the original Chat tool but uses the new
SimpleTool architecture for cleaner code organization and better maintainability.
Migration note: This tool is designed to be a drop-in replacement for the original
Chat tool with 100% behavioral compatibility.
"""
def get_name(self) -> str:
return "chat"
    def get_description(self) -> str:
        return (
            "Chat and collaborative thinking partner for brainstorming, marketing strategy discussion, "
            "getting second opinions, and exploring ideas. Use for ideas, validations, questions, and thoughtful explanations."
        )
def get_system_prompt(self) -> str:
return CHAT_PROMPT
def get_default_temperature(self) -> float:
return TEMPERATURE_BALANCED
def get_model_category(self) -> "ToolModelCategory":
"""Chat prioritizes fast responses and cost efficiency"""
from tools.models import ToolModelCategory
return ToolModelCategory.FAST_RESPONSE
def get_request_model(self):
"""Return the Chat-specific request model"""
return ChatRequest
# === Schema Generation ===
# For maximum compatibility, we override get_input_schema() to match the original Chat tool exactly
def get_input_schema(self) -> dict[str, Any]:
"""
Generate input schema matching the original Chat tool exactly.
This maintains 100% compatibility with the original Chat tool by using
the same schema generation approach while still benefiting from SimpleTool
convenience methods.
"""
required_fields = ["prompt"]
if self.is_effective_auto_mode():
required_fields.append("model")
schema = {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": CHAT_FIELD_DESCRIPTIONS["prompt"],
},
"files": {
"type": "array",
"items": {"type": "string"},
"description": CHAT_FIELD_DESCRIPTIONS["files"],
},
"images": {
"type": "array",
"items": {"type": "string"},
"description": CHAT_FIELD_DESCRIPTIONS["images"],
},
"model": self.get_model_field_schema(),
"temperature": {
"type": "number",
"description": COMMON_FIELD_DESCRIPTIONS["temperature"],
"minimum": 0,
"maximum": 1,
},
"thinking_mode": {
"type": "string",
"enum": ["minimal", "low", "medium", "high", "max"],
"description": COMMON_FIELD_DESCRIPTIONS["thinking_mode"],
},
"continuation_id": {
"type": "string",
"description": COMMON_FIELD_DESCRIPTIONS["continuation_id"],
},
},
"required": required_fields,
}
return schema
# === Tool-specific field definitions (alternative approach for reference) ===
# These aren't used since we override get_input_schema(), but they show how
# the tool could be implemented using the automatic SimpleTool schema building
def get_tool_fields(self) -> dict[str, dict[str, Any]]:
"""
Tool-specific field definitions for ChatSimple.
Note: This method isn't used since we override get_input_schema() for
exact compatibility, but it demonstrates how ChatSimple could be
implemented using automatic schema building.
"""
return {
"prompt": {
"type": "string",
"description": CHAT_FIELD_DESCRIPTIONS["prompt"],
},
"files": {
"type": "array",
"items": {"type": "string"},
"description": CHAT_FIELD_DESCRIPTIONS["files"],
},
"images": {
"type": "array",
"items": {"type": "string"},
"description": CHAT_FIELD_DESCRIPTIONS["images"],
},
}
def get_required_fields(self) -> list[str]:
"""Required fields for ChatSimple tool"""
return ["prompt"]
# === Hook Method Implementations ===
async def prepare_prompt(self, request: ChatRequest) -> str:
"""
Prepare the chat prompt with optional context files.
This implementation matches the original Chat tool exactly while using
SimpleTool convenience methods for cleaner code.
"""
# Use SimpleTool's Chat-style prompt preparation
return self.prepare_chat_style_prompt(request)
def format_response(self, response: str, request: ChatRequest, model_info: Optional[dict] = None) -> str:
"""
Format the chat response to match the original Chat tool exactly.
"""
return (
f"{response}\n\n---\n\nAGENT'S TURN: Evaluate this perspective alongside your analysis to "
"form a comprehensive solution and continue with the user's request and task at hand."
)
def get_websearch_guidance(self) -> str:
"""
Return Chat tool-style web search guidance.
"""
return self.get_chat_style_websearch_guidance()

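The required-fields rule in `get_input_schema()` above ("model" becomes mandatory only in auto mode) can be checked with a tiny validator. An illustrative sketch, separate from the actual Pydantic validation:

```python
def required_fields(auto_mode: bool) -> list[str]:
    # Mirrors ChatTool.get_input_schema(): "model" is only required in auto mode.
    fields = ["prompt"]
    if auto_mode:
        fields.append("model")
    return fields

def validate(arguments: dict, auto_mode: bool) -> list[str]:
    """Return the names of missing required fields."""
    return [f for f in required_fields(auto_mode) if f not in arguments]

print(validate({"prompt": "Brainstorm launch angles"}, auto_mode=False))  # []
print(validate({"prompt": "Brainstorm launch angles"}, auto_mode=True))   # ['model']
```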
180
tools/contentvariant.py Normal file
View file

@ -0,0 +1,180 @@
"""Content Variant Generator Tool
Generates multiple variations of marketing content for A/B testing.
Supports testing different hooks, tones, lengths, and psychological angles.
"""
from typing import Optional
from pydantic import Field
from config import TEMPERATURE_HIGHLY_CREATIVE
from systemprompts import CONTENTVARIANT_PROMPT
from tools.models import ToolModelCategory
from tools.shared.base_models import ToolRequest
from tools.simple.base import SimpleTool
class ContentVariantRequest(ToolRequest):
"""Request model for Content Variant Generator"""
content: str = Field(
...,
description="Base content or topic to create variations from. Can be a draft post, subject line, or content concept.",
)
variation_count: int = Field(
default=10,
ge=5,
le=25,
description="Number of variations to generate (5-25). Default is 10.",
)
variation_types: Optional[list[str]] = Field(
default=None,
description="Types of variations to explore: 'hook', 'tone', 'length', 'structure', 'cta', 'angle'. Leave empty for mixed approach.",
)
platform: Optional[str] = Field(
default=None,
description="Target platform for character limits and formatting: 'twitter', 'bluesky', 'linkedin', 'instagram', 'facebook', 'email_subject', 'blog'",
)
constraints: Optional[str] = Field(
default=None,
description="Additional constraints: character limits, style requirements, brand voice guidelines, prohibited words/phrases",
)
testing_angles: Optional[list[str]] = Field(
default=None,
description="Specific psychological angles to test: 'curiosity', 'contrarian', 'knowledge_gap', 'urgency', 'insider', 'problem_solution', 'social_proof', 'transformation'",
)
class ContentVariantTool(SimpleTool):
"""Generate multiple content variations for A/B testing"""
def get_name(self) -> str:
return "contentvariant"
def get_description(self) -> str:
return (
"Generate 5-25 variations of marketing content for A/B testing. "
"Tests different hooks, tones, lengths, and psychological angles. "
"Ideal for subject lines, social posts, email copy, and ads. "
"Each variation includes testing rationale and predicted audience response."
)
def get_system_prompt(self) -> str:
return CONTENTVARIANT_PROMPT
def get_default_temperature(self) -> float:
return TEMPERATURE_HIGHLY_CREATIVE
def get_model_category(self) -> ToolModelCategory:
return ToolModelCategory.FAST_RESPONSE
def get_request_model(self):
return ContentVariantRequest
def format_request_for_model(self, request: ContentVariantRequest) -> dict:
"""Format the content variant request for the AI model"""
prompt_parts = [f"Generate {request.variation_count} variations of this content:"]
prompt_parts.append(f"\n**Base Content:**\n{request.content}")
if request.platform:
prompt_parts.append(f"\n**Target Platform:** {request.platform}")
if request.variation_types:
prompt_parts.append(f"\n**Variation Types:** {', '.join(request.variation_types)}")
if request.testing_angles:
prompt_parts.append(f"\n**Testing Angles:** {', '.join(request.testing_angles)}")
if request.constraints:
prompt_parts.append(f"\n**Constraints:** {request.constraints}")
prompt_parts.append(
"\nGenerate variations with clear labels, character counts (if platform specified), "
"testing angles, and predicted audience responses."
)
formatted_request = {
"prompt": "\n".join(prompt_parts),
"files": request.files or [],
"images": request.images or [],
"continuation_id": request.continuation_id,
"model": request.model,
"temperature": request.temperature,
"thinking_mode": request.thinking_mode,
"use_websearch": request.use_websearch,
}
return formatted_request
def get_input_schema(self) -> dict:
"""Return the JSON schema for this tool's input"""
return {
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "Base content or topic to create variations from",
},
"variation_count": {
"type": "integer",
"description": "Number of variations to generate (5-25, default 10)",
"minimum": 5,
"maximum": 25,
"default": 10,
},
"variation_types": {
"type": "array",
"items": {"type": "string"},
"description": "Types: 'hook', 'tone', 'length', 'structure', 'cta', 'angle'",
},
"platform": {
"type": "string",
"description": "Platform: 'twitter', 'bluesky', 'linkedin', 'instagram', 'facebook', 'email_subject', 'blog'",
},
"constraints": {
"type": "string",
"description": "Character limits, style requirements, brand guidelines",
},
"testing_angles": {
"type": "array",
"items": {"type": "string"},
"description": "Angles: 'curiosity', 'contrarian', 'knowledge_gap', 'urgency', 'insider', 'problem_solution', 'social_proof', 'transformation'",
},
"files": {
"type": "array",
"items": {"type": "string"},
"description": "Optional brand guidelines or style reference files",
},
"images": {
"type": "array",
"items": {"type": "string"},
"description": "Optional visual assets for context",
},
"continuation_id": {
"type": "string",
"description": "Thread ID to continue previous conversation",
},
"model": {
"type": "string",
"description": "AI model to use (leave empty for default fast model)",
},
"temperature": {
"type": "number",
"description": "Creativity level 0.0-1.0 (default 0.8 for high variation)",
"minimum": 0.0,
"maximum": 1.0,
},
"thinking_mode": {
"type": "string",
"description": "Thinking depth: minimal, low, medium, high, max",
"enum": ["minimal", "low", "medium", "high", "max"],
},
"use_websearch": {
"type": "boolean",
"description": "Enable web search for current platform best practices",
"default": False,
},
},
"required": ["content"],
}

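`format_request_for_model()` above is straight string assembly; a standalone mirror showing the prompt a given request produces (parameter names match the request model fields):

```python
def build_variant_prompt(content, variation_count=10, platform=None,
                         variation_types=None, testing_angles=None, constraints=None):
    """Assemble the model prompt the same way format_request_for_model() does."""
    parts = [f"Generate {variation_count} variations of this content:"]
    parts.append(f"\n**Base Content:**\n{content}")
    if platform:
        parts.append(f"\n**Target Platform:** {platform}")
    if variation_types:
        parts.append(f"\n**Variation Types:** {', '.join(variation_types)}")
    if testing_angles:
        parts.append(f"\n**Testing Angles:** {', '.join(testing_angles)}")
    if constraints:
        parts.append(f"\n**Constraints:** {constraints}")
    parts.append(
        "\nGenerate variations with clear labels, character counts (if platform specified), "
        "testing angles, and predicted audience responses."
    )
    return "\n".join(parts)

print(build_variant_prompt("HVAC diagnostics tip", platform="twitter", testing_angles=["curiosity"]))
```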
299
tools/listmodels.py Normal file
View file

@ -0,0 +1,299 @@
"""
List Models Tool - Display all available models organized by provider
This tool provides a comprehensive view of all AI models available in the system,
organized by their provider (Gemini, OpenAI, X.AI, OpenRouter, Custom).
It shows which providers are configured and what models can be used.
"""
import logging
import os
from typing import Any, Optional
from mcp.types import TextContent
from tools.models import ToolModelCategory, ToolOutput
from tools.shared.base_models import ToolRequest
from tools.shared.base_tool import BaseTool
logger = logging.getLogger(__name__)
class ListModelsTool(BaseTool):
"""
Tool for listing all available AI models organized by provider.
This tool helps users understand:
- Which providers are configured (have API keys)
- What models are available from each provider
- Model aliases and their full names
- Context window sizes and capabilities
"""
def get_name(self) -> str:
return "listmodels"
def get_description(self) -> str:
return "Shows which AI model providers are configured, available model names, their aliases and capabilities."
def get_input_schema(self) -> dict[str, Any]:
"""Return the JSON schema for the tool's input"""
return {
"type": "object",
"properties": {"model": {"type": "string", "description": "Model to use (ignored by listmodels tool)"}},
"required": [],
}
def get_annotations(self) -> Optional[dict[str, Any]]:
"""Return tool annotations indicating this is a read-only tool"""
return {"readOnlyHint": True}
def get_system_prompt(self) -> str:
"""No AI model needed for this tool"""
return ""
def get_request_model(self):
"""Return the Pydantic model for request validation."""
return ToolRequest
def requires_model(self) -> bool:
return False
async def prepare_prompt(self, request: ToolRequest) -> str:
"""Not used for this utility tool"""
return ""
def format_response(self, response: str, request: ToolRequest, model_info: Optional[dict] = None) -> str:
"""Not used for this utility tool"""
return response
async def execute(self, arguments: dict[str, Any]) -> list[TextContent]:
"""
List all available models organized by provider.
This overrides the base class execute to provide direct output without AI model calls.
Args:
arguments: Standard tool arguments (none required)
Returns:
Formatted list of models by provider
"""
from providers.openrouter_registry import OpenRouterModelRegistry
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
output_lines = ["# Available AI Models\n"]
# Map provider types to friendly names and their models
provider_info = {
ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
ProviderType.OPENAI: {"name": "OpenAI", "env_key": "OPENAI_API_KEY"},
ProviderType.XAI: {"name": "X.AI (Grok)", "env_key": "XAI_API_KEY"},
ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
}
# Check each native provider type
for provider_type, info in provider_info.items():
# Check if provider is enabled
provider = ModelProviderRegistry.get_provider(provider_type)
is_configured = provider is not None
        output_lines.append(f"## {info['name']} {'✅' if is_configured else '❌'}")
if is_configured:
output_lines.append("**Status**: Configured and available")
output_lines.append("\n**Models**:")
aliases = []
for model_name, capabilities in provider.get_all_model_capabilities().items():
description = capabilities.description or "No description available"
context_window = capabilities.context_window
if context_window >= 1_000_000:
context_str = f"{context_window // 1_000_000}M context"
elif context_window >= 1_000:
context_str = f"{context_window // 1_000}K context"
else:
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
output_lines.append(f"- `{model_name}` - {context_str}")
output_lines.append(f" - {description}")
for alias in capabilities.aliases or []:
if alias != model_name:
aliases.append(f"- `{alias}` → `{model_name}`")
if aliases:
output_lines.append("\n**Aliases**:")
output_lines.extend(sorted(aliases))
else:
output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
output_lines.append("")
# Check OpenRouter
    openrouter_key = os.getenv("OPENROUTER_API_KEY")
    is_openrouter_configured = openrouter_key and openrouter_key != "your-openrouter-api-key-here"

    output_lines.append(f"## OpenRouter {'✅' if is_openrouter_configured else '❌'}")
if is_openrouter_configured:
output_lines.append("**Status**: Configured and available")
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
try:
# Get OpenRouter provider from registry to properly apply restrictions
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
if provider:
# Get models with restrictions applied
available_models = provider.list_models(respect_restrictions=True)
registry = OpenRouterModelRegistry()
# Group by provider for better organization
providers_models = {}
for model_name in available_models: # Show ALL available models
# Try to resolve to get config details
config = registry.resolve(model_name)
if config:
# Extract provider from model_name
provider_name = config.model_name.split("/")[0] if "/" in config.model_name else "other"
if provider_name not in providers_models:
providers_models[provider_name] = []
providers_models[provider_name].append((model_name, config))
else:
# Model without config - add with basic info
provider_name = model_name.split("/")[0] if "/" in model_name else "other"
if provider_name not in providers_models:
providers_models[provider_name] = []
providers_models[provider_name].append((model_name, None))
output_lines.append("\n**Available Models**:")
for provider_name, models in sorted(providers_models.items()):
output_lines.append(f"\n*{provider_name.title()}:*")
for alias, config in models: # Show ALL models from each provider
if config:
context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
else:
output_lines.append(f"- `{alias}`")
total_models = len(available_models)
# Show all models - no truncation message needed
# Check if restrictions are applied
restriction_service = None
try:
from utils.model_restrictions import get_restriction_service
restriction_service = get_restriction_service()
if restriction_service.has_restrictions(ProviderType.OPENROUTER):
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER)
output_lines.append(
f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}"
)
except Exception as e:
logger.warning(f"Error checking OpenRouter restrictions: {e}")
else:
output_lines.append("**Error**: Could not load OpenRouter provider")
except Exception as e:
output_lines.append(f"**Error loading models**: {str(e)}")
else:
output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)")
output_lines.append("**Note**: Provides access to GPT-5, O3, Mistral, and many more")
output_lines.append("")
# Check Custom API
custom_url = os.getenv("CUSTOM_API_URL")
output_lines.append(f"## Custom/Local API {'✅' if custom_url else '❌'}")
if custom_url:
output_lines.append("**Status**: Configured and available")
output_lines.append(f"**Endpoint**: {custom_url}")
output_lines.append("**Description**: Local models via Ollama, vLLM, LM Studio, etc.")
try:
registry = OpenRouterModelRegistry()
custom_models = []
for alias in registry.list_aliases():
config = registry.resolve(alias)
if config and config.is_custom:
custom_models.append((alias, config))
if custom_models:
output_lines.append("\n**Custom Models**:")
for alias, config in custom_models:
context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
if config.description:
output_lines.append(f" - {config.description}")
except Exception as e:
output_lines.append(f"**Error loading custom models**: {str(e)}")
else:
output_lines.append("**Status**: Not configured (set CUSTOM_API_URL)")
output_lines.append("**Example**: CUSTOM_API_URL=http://localhost:11434 (for Ollama)")
output_lines.append("")
# Add summary
output_lines.append("## Summary")
# Count configured providers
configured_count = sum(
[
1
for provider_type, info in provider_info.items()
if ModelProviderRegistry.get_provider(provider_type) is not None
]
)
if is_openrouter_configured:
configured_count += 1
if custom_url:
configured_count += 1
output_lines.append(f"**Configured Providers**: {configured_count}")
# Get total available models
try:
from providers.registry import ModelProviderRegistry
# Get all available models respecting restrictions
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
total_models = len(available_models)
output_lines.append(f"**Total Available Models**: {total_models}")
except Exception as e:
logger.warning(f"Error getting total available models: {e}")
# Add usage tips
output_lines.append("\n**Usage Tips**:")
output_lines.append("- Use model aliases (e.g., 'flash', 'gpt5', 'opus') for convenience")
output_lines.append("- In auto mode, the CLI Agent will select the best model for each task")
output_lines.append("- Custom models are only available when CUSTOM_API_URL is set")
output_lines.append("- OpenRouter provides access to many cloud models with one API key")
# Format output
content = "\n".join(output_lines)
tool_output = ToolOutput(
status="success",
content=content,
content_type="text",
metadata={
"tool_name": self.name,
"configured_providers": configured_count,
},
)
return [TextContent(type="text", text=tool_output.model_dump_json())]
def get_model_category(self) -> ToolModelCategory:
"""Return the model category for this tool."""
return ToolModelCategory.FAST_RESPONSE # Simple listing, no AI needed
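The context-window labels built inline in `listmodels` above can be factored into one helper. A minimal standalone sketch of that formatting logic (the helper name `format_context` is illustrative, not part of the codebase):

```python
def format_context(context_window: int) -> str:
    # Mirrors the humanized context-size labels used by the listmodels tool.
    if context_window >= 1_000_000:
        return f"{context_window // 1_000_000}M context"
    if context_window >= 1_000:
        return f"{context_window // 1_000}K context"
    return f"{context_window} context" if context_window > 0 else "unknown context"

print(format_context(1_048_576), "|", format_context(32_768), "|", format_context(0))
# → 1M context | 32K context | unknown context
```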

373
tools/models.py Normal file

@ -0,0 +1,373 @@
"""
Data models for tool responses and interactions
"""
from enum import Enum
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field
class ToolModelCategory(Enum):
"""Categories for tool model selection based on requirements."""
EXTENDED_REASONING = "extended_reasoning" # Requires deep thinking capabilities
FAST_RESPONSE = "fast_response" # Speed and cost efficiency preferred
BALANCED = "balanced" # Balance of capability and performance
class ContinuationOffer(BaseModel):
"""Offer for the CLI agent to continue the conversation when Gemini doesn't ask a follow-up question"""
continuation_id: str = Field(
..., description="Thread continuation ID for multi-turn conversations across different tools"
)
note: str = Field(..., description="Message explaining continuation opportunity to CLI agent")
remaining_turns: int = Field(..., description="Number of conversation turns remaining")
class ToolOutput(BaseModel):
"""Standardized output format for all tools"""
status: Literal[
"success",
"error",
"files_required_to_continue",
"full_codereview_required",
"focused_review_required",
"test_sample_needed",
"more_tests_required",
"refactor_analysis_complete",
"trace_complete",
"resend_prompt",
"code_too_large",
"continuation_available",
"no_bug_found",
] = "success"
content: Optional[str] = Field(None, description="The main content/response from the tool")
content_type: Literal["text", "markdown", "json"] = "text"
metadata: Optional[dict[str, Any]] = Field(default_factory=dict)
continuation_offer: Optional[ContinuationOffer] = Field(
None, description="Optional offer for Agent to continue conversation"
)
class FilesNeededRequest(BaseModel):
"""Request for missing files / code to continue"""
status: Literal["files_required_to_continue"] = "files_required_to_continue"
mandatory_instructions: str = Field(..., description="Critical instructions for Agent regarding required context")
files_needed: Optional[list[str]] = Field(
default_factory=list, description="Specific files that are needed for analysis"
)
suggested_next_action: Optional[dict[str, Any]] = Field(
None,
description="Suggested tool call with parameters after getting clarification",
)
class FullCodereviewRequired(BaseModel):
"""Request for full code review when scope is too large for quick review"""
status: Literal["full_codereview_required"] = "full_codereview_required"
important: Optional[str] = Field(None, description="Important message about escalation")
reason: Optional[str] = Field(None, description="Reason why full review is needed")
class FocusedReviewRequired(BaseModel):
"""Request for Agent to provide smaller, focused subsets of code for review"""
status: Literal["focused_review_required"] = "focused_review_required"
reason: str = Field(..., description="Why the current scope is too large for effective review")
suggestion: str = Field(
..., description="Suggested approach for breaking down the review into smaller, focused parts"
)
class TestSampleNeeded(BaseModel):
"""Request for additional test samples to determine testing framework"""
status: Literal["test_sample_needed"] = "test_sample_needed"
reason: str = Field(..., description="Reason why additional test samples are required")
class MoreTestsRequired(BaseModel):
"""Request for continuation to generate additional tests"""
status: Literal["more_tests_required"] = "more_tests_required"
pending_tests: str = Field(..., description="List of pending tests to be generated")
class RefactorOpportunity(BaseModel):
"""A single refactoring opportunity with precise targeting information"""
id: str = Field(..., description="Unique identifier for this refactoring opportunity")
type: Literal["decompose", "codesmells", "modernize", "organization"] = Field(
..., description="Type of refactoring"
)
severity: Literal["critical", "high", "medium", "low"] = Field(..., description="Severity level")
file: str = Field(..., description="Absolute path to the file")
start_line: int = Field(..., description="Starting line number")
end_line: int = Field(..., description="Ending line number")
context_start_text: str = Field(..., description="Exact text from start line for verification")
context_end_text: str = Field(..., description="Exact text from end line for verification")
issue: str = Field(..., description="Clear description of what needs refactoring")
suggestion: str = Field(..., description="Specific refactoring action to take")
rationale: str = Field(..., description="Why this improves the code")
code_to_replace: str = Field(..., description="Original code that should be changed")
replacement_code_snippet: str = Field(..., description="Refactored version of the code")
new_code_snippets: Optional[list[dict]] = Field(
default_factory=list, description="Additional code snippets to be added"
)
class RefactorAction(BaseModel):
"""Next action for Agent to implement refactoring"""
action_type: Literal["EXTRACT_METHOD", "SPLIT_CLASS", "MODERNIZE_SYNTAX", "REORGANIZE_CODE", "DECOMPOSE_FILE"] = (
Field(..., description="Type of action to perform")
)
target_file: str = Field(..., description="Absolute path to target file")
source_lines: str = Field(..., description="Line range (e.g., '45-67')")
description: str = Field(..., description="Step-by-step action description for CLI Agent")
class RefactorAnalysisComplete(BaseModel):
"""Complete refactor analysis with prioritized opportunities"""
status: Literal["refactor_analysis_complete"] = "refactor_analysis_complete"
refactor_opportunities: list[RefactorOpportunity] = Field(..., description="List of refactoring opportunities")
priority_sequence: list[str] = Field(..., description="Recommended order of refactoring IDs")
next_actions: list[RefactorAction] = Field(..., description="Specific actions for the agent to implement")
class CodeTooLargeRequest(BaseModel):
"""Request to reduce file selection due to size constraints"""
status: Literal["code_too_large"] = "code_too_large"
content: str = Field(..., description="Message explaining the size constraint")
content_type: Literal["text"] = "text"
metadata: dict[str, Any] = Field(default_factory=dict)
class ResendPromptRequest(BaseModel):
"""Request to resend prompt via file due to size limits"""
status: Literal["resend_prompt"] = "resend_prompt"
content: str = Field(..., description="Instructions for handling large prompt")
content_type: Literal["text"] = "text"
metadata: dict[str, Any] = Field(default_factory=dict)
class TraceEntryPoint(BaseModel):
"""Entry point information for trace analysis"""
file: str = Field(..., description="Absolute path to the file")
class_or_struct: str = Field(..., description="Class or module name")
method: str = Field(..., description="Method or function name")
signature: str = Field(..., description="Full method signature")
parameters: Optional[dict[str, Any]] = Field(default_factory=dict, description="Parameter values used in analysis")
class TraceTarget(BaseModel):
"""Target information for dependency analysis"""
file: str = Field(..., description="Absolute path to the file")
class_or_struct: str = Field(..., description="Class or module name")
method: str = Field(..., description="Method or function name")
signature: str = Field(..., description="Full method signature")
class CallPathStep(BaseModel):
"""A single step in the call path trace"""
from_info: dict[str, Any] = Field(..., description="Source location information", alias="from")
to: dict[str, Any] = Field(..., description="Target location information")
reason: str = Field(..., description="Reason for the call or dependency")
condition: Optional[str] = Field(None, description="Conditional logic if applicable")
ambiguous: bool = Field(False, description="Whether this call is ambiguous")
class BranchingPoint(BaseModel):
"""A branching point in the execution flow"""
file: str = Field(..., description="File containing the branching point")
method: str = Field(..., description="Method containing the branching point")
line: int = Field(..., description="Line number of the branching point")
condition: str = Field(..., description="Branching condition")
branches: list[str] = Field(..., description="Possible execution branches")
ambiguous: bool = Field(False, description="Whether the branching is ambiguous")
class SideEffect(BaseModel):
"""A side effect detected in the trace"""
type: str = Field(..., description="Type of side effect")
description: str = Field(..., description="Description of the side effect")
file: str = Field(..., description="File where the side effect occurs")
method: str = Field(..., description="Method where the side effect occurs")
line: int = Field(..., description="Line number of the side effect")
class UnresolvedDependency(BaseModel):
"""An unresolved dependency in the trace"""
reason: str = Field(..., description="Reason why the dependency is unresolved")
affected_file: str = Field(..., description="File affected by the unresolved dependency")
line: int = Field(..., description="Line number of the unresolved dependency")
class IncomingDependency(BaseModel):
"""An incoming dependency (what calls this target)"""
from_file: str = Field(..., description="Source file of the dependency")
from_class: str = Field(..., description="Source class of the dependency")
from_method: str = Field(..., description="Source method of the dependency")
line: int = Field(..., description="Line number of the dependency")
type: str = Field(..., description="Type of dependency")
class OutgoingDependency(BaseModel):
"""An outgoing dependency (what this target calls)"""
to_file: str = Field(..., description="Target file of the dependency")
to_class: str = Field(..., description="Target class of the dependency")
to_method: str = Field(..., description="Target method of the dependency")
line: int = Field(..., description="Line number of the dependency")
type: str = Field(..., description="Type of dependency")
class TypeDependency(BaseModel):
"""A type-level dependency (inheritance, imports, etc.)"""
dependency_type: str = Field(..., description="Type of dependency")
source_file: str = Field(..., description="Source file of the dependency")
source_entity: str = Field(..., description="Source entity (class, module)")
target: str = Field(..., description="Target entity")
class StateAccess(BaseModel):
"""State access information"""
file: str = Field(..., description="File where state is accessed")
method: str = Field(..., description="Method accessing the state")
access_type: str = Field(..., description="Type of access (reads, writes, etc.)")
state_entity: str = Field(..., description="State entity being accessed")
class TraceComplete(BaseModel):
"""Complete trace analysis response"""
status: Literal["trace_complete"] = "trace_complete"
trace_type: Literal["precision", "dependencies"] = Field(..., description="Type of trace performed")
# Precision mode fields
entry_point: Optional[TraceEntryPoint] = Field(None, description="Entry point for precision trace")
call_path: Optional[list[CallPathStep]] = Field(default_factory=list, description="Call path for precision trace")
branching_points: Optional[list[BranchingPoint]] = Field(default_factory=list, description="Branching points")
side_effects: Optional[list[SideEffect]] = Field(default_factory=list, description="Side effects detected")
unresolved: Optional[list[UnresolvedDependency]] = Field(
default_factory=list, description="Unresolved dependencies"
)
# Dependencies mode fields
target: Optional[TraceTarget] = Field(None, description="Target for dependency analysis")
incoming_dependencies: Optional[list[IncomingDependency]] = Field(
default_factory=list, description="Incoming dependencies"
)
outgoing_dependencies: Optional[list[OutgoingDependency]] = Field(
default_factory=list, description="Outgoing dependencies"
)
type_dependencies: Optional[list[TypeDependency]] = Field(default_factory=list, description="Type dependencies")
state_access: Optional[list[StateAccess]] = Field(default_factory=list, description="State access information")
class DiagnosticHypothesis(BaseModel):
"""A debugging hypothesis with context and next steps"""
rank: int = Field(..., description="Ranking of this hypothesis (1 = most likely)")
confidence: Literal["high", "medium", "low"] = Field(..., description="Confidence level")
hypothesis: str = Field(..., description="Description of the potential root cause")
reasoning: str = Field(..., description="Why this hypothesis is plausible")
next_step: str = Field(..., description="Suggested action to test/validate this hypothesis")
class StructuredDebugResponse(BaseModel):
"""Enhanced debug response with multiple hypotheses"""
summary: str = Field(..., description="Brief summary of the issue")
hypotheses: list[DiagnosticHypothesis] = Field(..., description="Ranked list of potential causes")
immediate_actions: list[str] = Field(
default_factory=list,
description="Immediate steps to take regardless of root cause",
)
additional_context_needed: Optional[list[str]] = Field(
default_factory=list,
description="Additional files or information that would help with analysis",
)
class DebugHypothesis(BaseModel):
"""A debugging hypothesis with detailed analysis"""
name: str = Field(..., description="Name/title of the hypothesis")
confidence: Literal["High", "Medium", "Low"] = Field(..., description="Confidence level")
root_cause: str = Field(..., description="Technical explanation of the root cause")
evidence: str = Field(..., description="Logs or code clues supporting this hypothesis")
correlation: str = Field(..., description="How symptoms map to the cause")
validation: str = Field(..., description="Quick test to confirm the hypothesis")
minimal_fix: str = Field(..., description="Smallest change to resolve the issue")
regression_check: str = Field(..., description="Why this fix is safe")
file_references: list[str] = Field(default_factory=list, description="File:line format for exact locations")
class DebugAnalysisComplete(BaseModel):
"""Complete debugging analysis with systematic investigation tracking"""
status: Literal["analysis_complete"] = "analysis_complete"
investigation_id: str = Field(..., description="Auto-generated unique ID for this investigation")
summary: str = Field(..., description="Brief description of the problem and its impact")
investigation_steps: list[str] = Field(..., description="Steps taken during the investigation")
hypotheses: list[DebugHypothesis] = Field(..., description="Ranked hypotheses with detailed analysis")
key_findings: list[str] = Field(..., description="Important discoveries made during analysis")
immediate_actions: list[str] = Field(..., description="Steps to take regardless of which hypothesis is correct")
recommended_tools: list[str] = Field(default_factory=list, description="Additional tools recommended for analysis")
prevention_strategy: Optional[str] = Field(
None, description="Targeted measures to prevent this exact issue from recurring"
)
investigation_summary: str = Field(
..., description="Comprehensive summary of the complete investigation process and conclusions"
)
class NoBugFound(BaseModel):
"""Response when thorough investigation finds no concrete evidence of a bug"""
status: Literal["no_bug_found"] = "no_bug_found"
summary: str = Field(..., description="Summary of what was thoroughly investigated")
investigation_steps: list[str] = Field(..., description="Steps taken during the investigation")
areas_examined: list[str] = Field(..., description="Code areas and potential failure points examined")
confidence_level: Literal["High", "Medium", "Low"] = Field(
..., description="Confidence level in the no-bug finding"
)
alternative_explanations: list[str] = Field(
..., description="Possible alternative explanations for reported symptoms"
)
recommended_questions: list[str] = Field(..., description="Questions to clarify the issue with the user")
next_steps: list[str] = Field(..., description="Suggested actions to better understand the reported issue")
# Registry mapping status strings to their corresponding Pydantic models
SPECIAL_STATUS_MODELS = {
"files_required_to_continue": FilesNeededRequest,
"full_codereview_required": FullCodereviewRequired,
"focused_review_required": FocusedReviewRequired,
"test_sample_needed": TestSampleNeeded,
"more_tests_required": MoreTestsRequired,
"refactor_analysis_complete": RefactorAnalysisComplete,
"trace_complete": TraceComplete,
"resend_prompt": ResendPromptRequest,
"code_too_large": CodeTooLargeRequest,
"analysis_complete": DebugAnalysisComplete,
"no_bug_found": NoBugFound,
}
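The `SPECIAL_STATUS_MODELS` registry above lets a caller route a raw tool response to its status-specific model. A self-contained sketch of that dispatch pattern, with a plain dataclass standing in for the Pydantic models and a hypothetical `parse_special_status` helper:

```python
import json
from dataclasses import dataclass

@dataclass
class TestSampleNeeded:
    # Stand-in for the Pydantic model of the same name above.
    reason: str

# Subset of the registry, keyed by status string.
SPECIAL_STATUS_MODELS = {"test_sample_needed": TestSampleNeeded}

def parse_special_status(raw: str):
    """Return a status-specific model instance, or None for ordinary output."""
    data = json.loads(raw)
    model_cls = SPECIAL_STATUS_MODELS.get(data.pop("status", None))
    return model_cls(**data) if model_cls else None

result = parse_special_status('{"status": "test_sample_needed", "reason": "need one sample test"}')
print(type(result).__name__, result.reason)
# → TestSampleNeeded need one sample test
```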

19
tools/shared/__init__.py Normal file

@ -0,0 +1,19 @@
"""
Shared infrastructure for Zen MCP tools.
This module contains the core base classes and utilities that are shared
across all tool types. It provides the foundation for the tool architecture.
"""
from .base_models import BaseWorkflowRequest, ConsolidatedFindings, ToolRequest, WorkflowRequest
from .base_tool import BaseTool
from .schema_builders import SchemaBuilder
__all__ = [
"BaseTool",
"ToolRequest",
"BaseWorkflowRequest",
"WorkflowRequest",
"ConsolidatedFindings",
"SchemaBuilder",
]

165
tools/shared/base_models.py Normal file

@ -0,0 +1,165 @@
"""
Base models for Zen MCP tools.
This module contains the shared Pydantic models used across all tools,
extracted to avoid circular imports and promote code reuse.
Key Models:
- ToolRequest: Base request model for all tools
- WorkflowRequest: Extended request model for workflow-based tools
- ConsolidatedFindings: Model for tracking workflow progress
"""
import logging
from typing import Optional
from pydantic import BaseModel, Field, field_validator
logger = logging.getLogger(__name__)
# Shared field descriptions to avoid duplication
COMMON_FIELD_DESCRIPTIONS = {
"model": "Model to run. Supply a name if requested by the user or stay in auto mode. When in auto mode, use `listmodels` tool for model discovery.",
"temperature": "0 = deterministic · 1 = creative.",
"thinking_mode": "Reasoning depth: minimal, low, medium, high, or max.",
"continuation_id": (
"Unique thread continuation ID for multi-turn conversations. Works across different tools. "
"ALWAYS reuse the last continuation_id you were given—this preserves full conversation context, "
"files, and findings so the agent can resume seamlessly."
),
"images": "Optional absolute image paths or base64 blobs for visual context.",
"files": "Optional absolute file or folder paths (do not shorten).",
}
# Workflow-specific field descriptions
WORKFLOW_FIELD_DESCRIPTIONS = {
"step": "Current work step content and findings from your overall work",
"step_number": "Current step number in work sequence (starts at 1)",
"total_steps": "Estimated total steps needed to complete work",
"next_step_required": "Whether another work step is needed. When false, aim to reduce total_steps to match step_number to avoid mismatch.",
"findings": "Important findings, evidence and insights discovered in this step",
"files_checked": "List of files examined during this work step",
"relevant_files": "Files identified as relevant to issue/goal (FULL absolute paths to real files/folders - DO NOT SHORTEN)",
"relevant_context": "Methods/functions identified as involved in the issue",
"issues_found": "Issues identified with severity levels during work",
"confidence": (
"Confidence level: exploring (just starting), low (early investigation), "
"medium (some evidence), high (strong evidence), very_high (comprehensive understanding), "
"almost_certain (near complete confidence), certain (100% confidence locally - no external validation needed)"
),
"hypothesis": "Current theory about issue/goal based on work",
"backtrack_from_step": "Step number to backtrack from if work needs revision",
"use_assistant_model": (
"Use assistant model for expert analysis after workflow steps. "
"False skips expert analysis, relies solely on Claude's investigation. "
"Defaults to True for comprehensive validation."
),
}
class ToolRequest(BaseModel):
"""
Base request model for all Zen MCP tools.
This model defines common fields that all tools accept, including
model selection, temperature control, and conversation threading.
Tool-specific request models should inherit from this class.
"""
# Model configuration
model: Optional[str] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["model"])
temperature: Optional[float] = Field(None, ge=0.0, le=1.0, description=COMMON_FIELD_DESCRIPTIONS["temperature"])
thinking_mode: Optional[str] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["thinking_mode"])
# Conversation support
continuation_id: Optional[str] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["continuation_id"])
# Visual context
images: Optional[list[str]] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["images"])
class BaseWorkflowRequest(ToolRequest):
"""
Minimal base request model for workflow tools.
This provides only the essential fields that ALL workflow tools need,
allowing for maximum flexibility in tool-specific implementations.
"""
# Core workflow fields that ALL workflow tools need
step: str = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["step"])
step_number: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
total_steps: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
next_step_required: bool = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"])
class WorkflowRequest(BaseWorkflowRequest):
"""
Extended request model for workflow-based tools.
This model extends ToolRequest with fields specific to the workflow
pattern, where tools perform multi-step work with forced pauses between steps.
Used by: debug, precommit, codereview, refactor, thinkdeep, analyze
"""
# Required workflow fields
step: str = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["step"])
step_number: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
total_steps: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
next_step_required: bool = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"])
# Work tracking fields
findings: str = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["findings"])
files_checked: list[str] = Field(default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["files_checked"])
relevant_files: list[str] = Field(default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"])
relevant_context: list[str] = Field(
default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"]
)
issues_found: list[dict] = Field(default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["issues_found"])
confidence: str = Field("low", description=WORKFLOW_FIELD_DESCRIPTIONS["confidence"])
# Optional workflow fields
hypothesis: Optional[str] = Field(None, description=WORKFLOW_FIELD_DESCRIPTIONS["hypothesis"])
backtrack_from_step: Optional[int] = Field(
None, ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["backtrack_from_step"]
)
use_assistant_model: Optional[bool] = Field(True, description=WORKFLOW_FIELD_DESCRIPTIONS["use_assistant_model"])
@field_validator("files_checked", "relevant_files", "relevant_context", mode="before")
@classmethod
def convert_string_to_list(cls, v):
"""Convert string inputs to empty lists to handle malformed inputs gracefully."""
if isinstance(v, str):
logger.warning(f"Field received string '{v}' instead of list, converting to empty list")
return []
return v
class ConsolidatedFindings(BaseModel):
"""
Model for tracking consolidated findings across workflow steps.
This model accumulates findings, files, methods, and issues
discovered during multi-step work. It's used by
BaseWorkflowMixin to track progress across workflow steps.
"""
files_checked: set[str] = Field(default_factory=set, description="All files examined across all steps")
relevant_files: set[str] = Field(
default_factory=set,
description="Subset of files_checked identified as relevant for work at hand",
)
relevant_context: set[str] = Field(
default_factory=set, description="All methods/functions identified during overall work"
)
findings: list[str] = Field(default_factory=list, description="Chronological findings from each work step")
hypotheses: list[dict] = Field(default_factory=list, description="Evolution of hypotheses across steps")
issues_found: list[dict] = Field(default_factory=list, description="All issues with severity levels")
images: list[str] = Field(default_factory=list, description="Images collected during work")
confidence: str = Field("low", description="Latest confidence level from steps")
# Tool-specific field descriptions are now declared in each tool file
# This keeps concerns separated and makes each tool self-contained
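`ConsolidatedFindings` is built to accumulate across workflow steps. A hypothetical sketch of the merge a workflow mixin might perform per step, with plain dicts and a made-up `consolidate` helper standing in for the real Pydantic models and mixin:

```python
def consolidate(acc: dict, step: dict) -> dict:
    # Merge one WorkflowRequest-style step into the running accumulator.
    acc["files_checked"] |= set(step.get("files_checked", []))
    acc["relevant_files"] |= set(step.get("relevant_files", []))
    acc["findings"].append(f"Step {step['step_number']}: {step['findings']}")
    acc["confidence"] = step.get("confidence", acc["confidence"])
    return acc

acc = {"files_checked": set(), "relevant_files": set(), "findings": [], "confidence": "low"}
consolidate(acc, {"step_number": 1, "findings": "reproduced the crash",
                  "files_checked": ["/repo/app.py"], "confidence": "medium"})
print(sorted(acc["files_checked"]), acc["confidence"])
# → ['/repo/app.py'] medium
```

Sets deduplicate files revisited in later steps, while findings stay chronological, matching the field types declared above.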

1399
tools/shared/base_tool.py Normal file

File diff suppressed because it is too large

159
tools/shared/schema_builders.py Normal file

@ -0,0 +1,159 @@
"""
Core schema building functionality for Zen MCP tools.
This module provides base schema generation functionality for simple tools.
Workflow-specific schema building is located in workflow/schema_builders.py
to maintain proper separation of concerns.
"""
from typing import Any
from .base_models import COMMON_FIELD_DESCRIPTIONS
class SchemaBuilder:
"""
Base schema builder for simple MCP tools.
This class provides static methods to build consistent schemas for simple tools.
Workflow tools use WorkflowSchemaBuilder in workflow/schema_builders.py.
"""
# Common field schemas that can be reused across all tool types
COMMON_FIELD_SCHEMAS = {
"temperature": {
"type": "number",
"description": COMMON_FIELD_DESCRIPTIONS["temperature"],
"minimum": 0.0,
"maximum": 1.0,
},
"thinking_mode": {
"type": "string",
"enum": ["minimal", "low", "medium", "high", "max"],
"description": COMMON_FIELD_DESCRIPTIONS["thinking_mode"],
},
"continuation_id": {
"type": "string",
"description": COMMON_FIELD_DESCRIPTIONS["continuation_id"],
},
"images": {
"type": "array",
"items": {"type": "string"},
"description": COMMON_FIELD_DESCRIPTIONS["images"],
},
}
# Simple tool-specific field schemas (workflow tools use relevant_files instead)
SIMPLE_FIELD_SCHEMAS = {
"files": {
"type": "array",
"items": {"type": "string"},
"description": COMMON_FIELD_DESCRIPTIONS["files"],
},
}
@staticmethod
def build_schema(
tool_specific_fields: dict[str, dict[str, Any]] = None,
required_fields: list[str] = None,
model_field_schema: dict[str, Any] = None,
auto_mode: bool = False,
require_model: bool = False,
) -> dict[str, Any]:
"""
Build complete schema for simple tools.
Args:
tool_specific_fields: Additional fields specific to the tool
required_fields: List of required field names
model_field_schema: Schema for the model field
auto_mode: Whether the tool is in auto mode (affects model requirement)
require_model: Whether the model field should always be required
Returns:
Complete JSON schema for the tool
"""
properties = {}
# Add common fields (temperature, thinking_mode, etc.)
properties.update(SchemaBuilder.COMMON_FIELD_SCHEMAS)
# Add simple tool-specific fields (files field for simple tools)
properties.update(SchemaBuilder.SIMPLE_FIELD_SCHEMAS)
# Add model field if provided
if model_field_schema:
properties["model"] = model_field_schema
# Add tool-specific fields if provided
if tool_specific_fields:
properties.update(tool_specific_fields)
# Build required fields list
required = list(required_fields) if required_fields else []
if (auto_mode or require_model) and "model" not in required:
required.append("model")
# Build the complete schema
schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": properties,
"additionalProperties": False,
}
if required:
schema["required"] = required
return schema
@staticmethod
def get_common_fields() -> dict[str, dict[str, Any]]:
"""Get the standard field schemas for simple tools."""
return SchemaBuilder.COMMON_FIELD_SCHEMAS.copy()
@staticmethod
def create_field_schema(
field_type: str,
description: str,
enum_values: list[str] = None,
minimum: float = None,
maximum: float = None,
items_type: str = None,
default: Any = None,
) -> dict[str, Any]:
"""
Helper method to create field schemas with common patterns.
Args:
field_type: JSON schema type ("string", "number", "array", etc.)
description: Human-readable description of the field
enum_values: For enum fields, list of allowed values
minimum: For numeric fields, minimum value
maximum: For numeric fields, maximum value
items_type: For array fields, type of array items
default: Default value for the field
Returns:
JSON schema object for the field
"""
schema = {
"type": field_type,
"description": description,
}
if enum_values:
schema["enum"] = enum_values
if minimum is not None:
schema["minimum"] = minimum
if maximum is not None:
schema["maximum"] = maximum
if items_type and field_type == "array":
schema["items"] = {"type": items_type}
if default is not None:
schema["default"] = default
return schema

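Assuming the module layout above, a simple tool calls `SchemaBuilder.build_schema()` with its own fields. The following standalone sketch mirrors that merge logic; the `COMMON` dict here is a trimmed stand-in for `COMMON_FIELD_SCHEMAS`, not the real module:

```python
# Standalone sketch of SchemaBuilder.build_schema's merge logic; COMMON is a
# trimmed stand-in for COMMON_FIELD_SCHEMAS, not an import of the real module.
COMMON = {
    "temperature": {"type": "number", "minimum": 0.0, "maximum": 1.0},
    "continuation_id": {"type": "string"},
}

def build_schema(tool_fields, required=None, require_model=False):
    # Merge common fields with tool-specific fields (tool fields win on clash)
    properties = dict(COMMON)
    properties.update(tool_fields)
    required = list(required or [])
    if require_model and "model" not in required:
        required.append("model")
    schema = {
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "object",
        "properties": properties,
        "additionalProperties": False,
    }
    if required:
        schema["required"] = required
    return schema

schema = build_schema({"prompt": {"type": "string"}}, required=["prompt"])
print(sorted(schema["properties"]))  # ['continuation_id', 'prompt', 'temperature']
```

Because `additionalProperties` is `False`, any argument not declared in `properties` is rejected at validation time, which is why every reusable field must be merged in up front.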
18
tools/simple/__init__.py Normal file

@ -0,0 +1,18 @@
"""
Simple tools for Zen MCP.
Simple tools follow a basic request → AI model → response pattern.
They inherit from SimpleTool which provides streamlined functionality
for tools that don't need multi-step workflows.
Available simple tools:
- chat: General chat and collaborative thinking
- consensus: Multi-perspective analysis
- listmodels: Model listing and information
- testgen: Test generation
- tracer: Execution tracing
"""
from .base import SimpleTool
__all__ = ["SimpleTool"]

985
tools/simple/base.py Normal file

@ -0,0 +1,985 @@
"""
Base class for simple MCP tools.
Simple tools follow a straightforward pattern:
1. Receive request
2. Prepare prompt (with files, context, etc.)
3. Call AI model
4. Format and return response
They use the shared SchemaBuilder for consistent schema generation
and inherit all the conversation, file processing, and model handling
capabilities from BaseTool.
"""
from abc import abstractmethod
from typing import Any, Optional
from tools.shared.base_models import ToolRequest
from tools.shared.base_tool import BaseTool
from tools.shared.schema_builders import SchemaBuilder
class SimpleTool(BaseTool):
"""
Base class for simple (non-workflow) tools.
Simple tools are request/response tools that don't require multi-step workflows.
They benefit from:
- Automatic schema generation using SchemaBuilder
- Inherited conversation handling and file processing
- Standardized model integration
- Consistent error handling and response formatting
To create a simple tool:
1. Inherit from SimpleTool
2. Implement get_tool_fields() to define tool-specific fields
3. Implement prepare_prompt() for prompt preparation
4. Optionally override format_response() for custom formatting
5. Optionally override get_required_fields() for custom requirements
Example:
class ChatTool(SimpleTool):
def get_name(self) -> str:
return "chat"
def get_tool_fields(self) -> dict[str, dict[str, Any]]:
return {
"prompt": {
"type": "string",
"description": "Your question or idea...",
},
"files": SimpleTool.FILES_FIELD,
}
def get_required_fields(self) -> list[str]:
return ["prompt"]
"""
# Common field definitions that simple tools can reuse
FILES_FIELD = SchemaBuilder.SIMPLE_FIELD_SCHEMAS["files"]
IMAGES_FIELD = SchemaBuilder.COMMON_FIELD_SCHEMAS["images"]
@abstractmethod
def get_tool_fields(self) -> dict[str, dict[str, Any]]:
"""
Return tool-specific field definitions.
This method should return a dictionary mapping field names to their
JSON schema definitions. Common fields (model, temperature, etc.)
are added automatically by the base class.
Returns:
Dict mapping field names to JSON schema objects
Example:
return {
"prompt": {
"type": "string",
"description": "The user's question or request",
},
"files": SimpleTool.FILES_FIELD, # Reuse common field
"max_tokens": {
"type": "integer",
"minimum": 1,
"description": "Maximum tokens for response",
}
}
"""
pass
def get_required_fields(self) -> list[str]:
"""
Return list of required field names.
Override this to specify which fields are required for your tool.
The model field is automatically added if in auto mode.
Returns:
List of required field names
"""
return []
def get_annotations(self) -> Optional[dict[str, Any]]:
"""
Return tool annotations. Simple tools are read-only by default.
All simple tools perform operations without modifying the environment.
They may call external AI models for analysis or conversation, but they
don't write files or make system changes.
Override this method if your simple tool needs different annotations.
Returns:
Dictionary with readOnlyHint set to True
"""
return {"readOnlyHint": True}
def format_response(self, response: str, request, model_info: Optional[dict] = None) -> str:
"""
Format the AI response before returning to the client.
This is a hook method that subclasses can override to customize
response formatting. The default implementation returns the response as-is.
Args:
response: The raw response from the AI model
request: The validated request object
model_info: Optional model information dictionary
Returns:
Formatted response string
"""
return response
def get_input_schema(self) -> dict[str, Any]:
"""
Generate the complete input schema using SchemaBuilder.
This method automatically combines:
- Tool-specific fields from get_tool_fields()
- Common fields (temperature, thinking_mode, etc.)
- Model field with proper auto-mode handling
- Required fields from get_required_fields()
Tools can override this method for custom schema generation while
still benefiting from SimpleTool's convenience methods.
Returns:
Complete JSON schema for the tool
"""
required_fields = list(self.get_required_fields())
return SchemaBuilder.build_schema(
tool_specific_fields=self.get_tool_fields(),
required_fields=required_fields,
model_field_schema=self.get_model_field_schema(),
auto_mode=self.is_effective_auto_mode(),
)
def get_request_model(self):
"""
Return the request model class.
Simple tools use the base ToolRequest by default.
Override this if your tool needs a custom request model.
"""
return ToolRequest
# Hook methods for safe attribute access without hasattr/getattr
def get_request_model_name(self, request) -> Optional[str]:
"""Get model name from request. Override for custom model name handling."""
try:
return request.model
except AttributeError:
return None
def get_request_images(self, request) -> list:
"""Get images from request. Override for custom image handling."""
try:
return request.images if request.images is not None else []
except AttributeError:
return []
def get_request_continuation_id(self, request) -> Optional[str]:
"""Get continuation_id from request. Override for custom continuation handling."""
try:
return request.continuation_id
except AttributeError:
return None
def get_request_prompt(self, request) -> str:
"""Get prompt from request. Override for custom prompt handling."""
try:
return request.prompt
except AttributeError:
return ""
def get_request_temperature(self, request) -> Optional[float]:
"""Get temperature from request. Override for custom temperature handling."""
try:
return request.temperature
except AttributeError:
return None
def get_validated_temperature(self, request, model_context: Any) -> tuple[float, list[str]]:
"""
Get temperature from request and validate it against model constraints.
This is a convenience method that combines temperature extraction and validation
for simple tools. It ensures temperature is within valid range for the model.
Args:
request: The request object containing temperature
model_context: Model context object containing model info
Returns:
Tuple of (validated_temperature, warning_messages)
"""
temperature = self.get_request_temperature(request)
if temperature is None:
temperature = self.get_default_temperature()
return self.validate_and_correct_temperature(temperature, model_context)
def get_request_thinking_mode(self, request) -> Optional[str]:
"""Get thinking_mode from request. Override for custom thinking mode handling."""
try:
return request.thinking_mode
except AttributeError:
return None
def get_request_files(self, request) -> list:
"""Get files from request. Override for custom file handling."""
try:
return request.files if request.files is not None else []
except AttributeError:
return []
def get_request_as_dict(self, request) -> dict:
"""Convert request to dictionary. Override for custom serialization."""
try:
# Try Pydantic v2 method first
return request.model_dump()
except AttributeError:
try:
# Fall back to Pydantic v1 method
return request.dict()
except AttributeError:
# Last resort - convert to dict manually
return {"prompt": self.get_request_prompt(request)}
def set_request_files(self, request, files: list) -> None:
"""Set files on request. Override for custom file setting."""
try:
request.files = files
except AttributeError:
# If request doesn't support file setting, ignore silently
pass
def get_actually_processed_files(self) -> list:
"""Get actually processed files. Override for custom file tracking."""
try:
return self._actually_processed_files
except AttributeError:
return []
async def execute(self, arguments: dict[str, Any]) -> list:
"""
Execute the simple tool using the comprehensive flow from old base.py.
This method replicates the proven execution pattern while using SimpleTool hooks.
"""
import json
import logging
from mcp.types import TextContent
from tools.models import ToolOutput
logger = logging.getLogger(f"tools.{self.get_name()}")
try:
# Store arguments for access by helper methods
self._current_arguments = arguments
logger.info(f"🔧 {self.get_name()} tool called with arguments: {list(arguments.keys())}")
# Validate request using the tool's Pydantic model
request_model = self.get_request_model()
request = request_model(**arguments)
logger.debug(f"Request validation successful for {self.get_name()}")
# Validate file paths for security
# This prevents path traversal attacks and ensures proper access control
path_error = self._validate_file_paths(request)
if path_error:
error_output = ToolOutput(
status="error",
content=path_error,
content_type="text",
)
return [TextContent(type="text", text=error_output.model_dump_json())]
# Handle model resolution like old base.py
model_name = self.get_request_model_name(request)
if not model_name:
from config import DEFAULT_MODEL
model_name = DEFAULT_MODEL
# Store the current model name for later use
self._current_model_name = model_name
# Handle model context from arguments (for in-process testing)
if "_model_context" in arguments:
self._model_context = arguments["_model_context"]
logger.debug(f"{self.get_name()}: Using model context from arguments")
else:
# Create model context if not provided
from utils.model_context import ModelContext
self._model_context = ModelContext(model_name)
logger.debug(f"{self.get_name()}: Created model context for {model_name}")
# Get images if present
images = self.get_request_images(request)
continuation_id = self.get_request_continuation_id(request)
# Handle conversation history and prompt preparation
if continuation_id:
# Check if conversation history is already embedded
field_value = self.get_request_prompt(request)
if "=== CONVERSATION HISTORY ===" in field_value:
# Use pre-embedded history
prompt = field_value
logger.debug(f"{self.get_name()}: Using pre-embedded conversation history")
else:
# No embedded history - reconstruct it (for in-process calls)
logger.debug(f"{self.get_name()}: No embedded history found, reconstructing conversation")
# Get thread context
from utils.conversation_memory import add_turn, build_conversation_history, get_thread
thread_context = get_thread(continuation_id)
if thread_context:
# Add user's new input to conversation
user_prompt = self.get_request_prompt(request)
user_files = self.get_request_files(request)
if user_prompt:
add_turn(continuation_id, "user", user_prompt, files=user_files)
# Get updated thread context after adding the turn
thread_context = get_thread(continuation_id)
logger.debug(
f"{self.get_name()}: Retrieved updated thread with {len(thread_context.turns)} turns"
)
# Build conversation history with updated thread context
conversation_history, conversation_tokens = build_conversation_history(
thread_context, self._model_context
)
# Get the base prompt from the tool
base_prompt = await self.prepare_prompt(request)
# Combine with conversation history
if conversation_history:
prompt = f"{conversation_history}\n\n=== NEW USER INPUT ===\n{base_prompt}"
else:
prompt = base_prompt
else:
# Thread not found, prepare normally
logger.warning(f"Thread {continuation_id} not found, preparing prompt normally")
prompt = await self.prepare_prompt(request)
else:
# New conversation, prepare prompt normally
prompt = await self.prepare_prompt(request)
# Add follow-up instructions for new conversations
from server import get_follow_up_instructions
follow_up_instructions = get_follow_up_instructions(0)
prompt = f"{prompt}\n\n{follow_up_instructions}"
logger.debug(
f"Added follow-up instructions for new {self.get_name()} conversation"
)
# Validate images if any were provided
if images:
image_validation_error = self._validate_image_limits(
images, model_context=self._model_context, continuation_id=continuation_id
)
if image_validation_error:
return [TextContent(type="text", text=json.dumps(image_validation_error, ensure_ascii=False))]
# Get and validate temperature against model constraints
temperature, temp_warnings = self.get_validated_temperature(request, self._model_context)
# Log any temperature corrections
for warning in temp_warnings:
logger.warning(warning)
# Get thinking mode with defaults
thinking_mode = self.get_request_thinking_mode(request)
if thinking_mode is None:
thinking_mode = self.get_default_thinking_mode()
# Get the provider from model context (clean OOP - no re-fetching)
provider = self._model_context.provider
# Get system prompt for this tool
base_system_prompt = self.get_system_prompt()
language_instruction = self.get_language_instruction()
system_prompt = language_instruction + base_system_prompt
# Generate AI response using the provider
logger.info(f"Sending request to {provider.get_provider_type().value} API for {self.get_name()}")
logger.info(
f"Using model: {self._model_context.model_name} via {provider.get_provider_type().value} provider"
)
# Estimate tokens for logging
from utils.token_utils import estimate_tokens
estimated_tokens = estimate_tokens(prompt)
logger.debug(f"Prompt length: {len(prompt)} characters (~{estimated_tokens:,} tokens)")
# Resolve model capabilities for feature gating
capabilities = self._model_context.capabilities
supports_thinking = capabilities.supports_extended_thinking
# Generate content with provider abstraction
model_response = provider.generate_content(
prompt=prompt,
model_name=self._current_model_name,
system_prompt=system_prompt,
temperature=temperature,
thinking_mode=thinking_mode if supports_thinking else None,
images=images if images else None,
)
logger.info(f"Received response from {provider.get_provider_type().value} API for {self.get_name()}")
# Process the model's response
if model_response.content:
raw_text = model_response.content
# Create model info for conversation tracking
model_info = {
"provider": provider,
"model_name": self._current_model_name,
"model_response": model_response,
}
# Parse response using the same logic as old base.py
tool_output = self._parse_response(raw_text, request, model_info)
logger.info(f"{self.get_name()} tool completed successfully")
else:
# Handle cases where the model couldn't generate a response
metadata = model_response.metadata or {}
finish_reason = metadata.get("finish_reason", "Unknown")
if metadata.get("is_blocked_by_safety"):
# Specific handling for content safety blocks
safety_details = metadata.get("safety_feedback") or "details not provided"
logger.warning(
f"Response blocked by content safety policy for {self.get_name()}. "
f"Reason: {finish_reason}, Details: {safety_details}"
)
tool_output = ToolOutput(
status="error",
content="Your request was blocked by the content safety policy. "
"Please try modifying your prompt.",
content_type="text",
)
else:
# Handle other empty responses - could be legitimate completion or unclear blocking
if finish_reason == "STOP":
# Model completed normally but returned empty content - retry with clarification
logger.info(
f"Model completed with empty response for {self.get_name()}, retrying with clarification"
)
# Retry the same request with modified prompt asking for explicit response
original_prompt = prompt
retry_prompt = f"{original_prompt}\n\nIMPORTANT: Please provide a substantive response. If you cannot respond to the above request, please explain why and suggest alternatives."
try:
retry_response = provider.generate_content(
prompt=retry_prompt,
model_name=self._current_model_name,
system_prompt=system_prompt,
temperature=temperature,
thinking_mode=thinking_mode if supports_thinking else None,
images=images if images else None,
)
if retry_response.content:
# Successful retry - use the retry response
logger.info(f"Retry successful for {self.get_name()}")
raw_text = retry_response.content
# Update model info for the successful retry
model_info = {
"provider": provider,
"model_name": self._current_model_name,
"model_response": retry_response,
}
# Parse the retry response
tool_output = self._parse_response(raw_text, request, model_info)
logger.info(f"{self.get_name()} tool completed successfully after retry")
else:
# Retry also failed - inspect metadata to find out why
retry_metadata = retry_response.metadata or {}
if retry_metadata.get("is_blocked_by_safety"):
# The retry was blocked by safety filters
safety_details = retry_metadata.get("safety_feedback") or "details not provided"
logger.warning(
f"Retry for {self.get_name()} was blocked by content safety policy. "
f"Details: {safety_details}"
)
tool_output = ToolOutput(
status="error",
content="Your request was also blocked by the content safety policy after a retry. "
"Please try rephrasing your prompt significantly.",
content_type="text",
)
else:
# Retry failed for other reasons (e.g., another STOP)
tool_output = ToolOutput(
status="error",
content="The model repeatedly returned empty responses. This may indicate content filtering or a model issue.",
content_type="text",
)
except Exception as retry_error:
logger.warning(f"Retry failed for {self.get_name()}: {retry_error}")
tool_output = ToolOutput(
status="error",
content=f"Model returned empty response and retry failed: {str(retry_error)}",
content_type="text",
)
else:
# Non-STOP finish reasons are likely actual errors
logger.warning(
f"Response blocked or incomplete for {self.get_name()}. Finish reason: {finish_reason}"
)
tool_output = ToolOutput(
status="error",
content=f"Response blocked or incomplete. Finish reason: {finish_reason}",
content_type="text",
)
# Return the tool output as TextContent
return [TextContent(type="text", text=tool_output.model_dump_json())]
except Exception as e:
# Special handling for MCP size check errors
if str(e).startswith("MCP_SIZE_CHECK:"):
# Extract the JSON content after the prefix
json_content = str(e)[len("MCP_SIZE_CHECK:") :]
return [TextContent(type="text", text=json_content)]
logger.error(f"Error in {self.get_name()}: {str(e)}")
error_output = ToolOutput(
status="error",
content=f"Error in {self.get_name()}: {str(e)}",
content_type="text",
)
return [TextContent(type="text", text=error_output.model_dump_json())]
def _parse_response(self, raw_text: str, request, model_info: Optional[dict] = None):
"""
Parse the raw response and format it using the hook method.
This simplified version focuses on the SimpleTool pattern: format the response
using the format_response hook, then handle conversation continuation.
"""
from tools.models import ToolOutput
# Format the response using the hook method
formatted_response = self.format_response(raw_text, request, model_info)
# Handle conversation continuation like old base.py
continuation_id = self.get_request_continuation_id(request)
if continuation_id:
self._record_assistant_turn(continuation_id, raw_text, request, model_info)
# Create continuation offer like old base.py
continuation_data = self._create_continuation_offer(request, model_info)
if continuation_data:
return self._create_continuation_offer_response(formatted_response, continuation_data, request, model_info)
else:
# Build metadata with model and provider info for success response
metadata = {}
if model_info:
model_name = model_info.get("model_name")
if model_name:
metadata["model_used"] = model_name
provider = model_info.get("provider")
if provider:
# Handle both provider objects and string values
if isinstance(provider, str):
metadata["provider_used"] = provider
else:
try:
metadata["provider_used"] = provider.get_provider_type().value
except AttributeError:
# Fallback if provider doesn't have get_provider_type method
metadata["provider_used"] = str(provider)
return ToolOutput(
status="success",
content=formatted_response,
content_type="text",
metadata=metadata if metadata else None,
)
def _create_continuation_offer(self, request, model_info: Optional[dict] = None):
"""Create continuation offer following old base.py pattern"""
continuation_id = self.get_request_continuation_id(request)
try:
from utils.conversation_memory import create_thread, get_thread
if continuation_id:
# Existing conversation
thread_context = get_thread(continuation_id)
if thread_context and thread_context.turns:
turn_count = len(thread_context.turns)
from utils.conversation_memory import MAX_CONVERSATION_TURNS
if turn_count >= MAX_CONVERSATION_TURNS - 1:
return None # No more turns allowed
remaining_turns = MAX_CONVERSATION_TURNS - turn_count - 1
return {
"continuation_id": continuation_id,
"remaining_turns": remaining_turns,
"note": f"Claude can continue this conversation for {remaining_turns} more exchanges.",
}
else:
# New conversation - create thread and offer continuation
# Convert request to dict for initial_context
initial_request_dict = self.get_request_as_dict(request)
new_thread_id = create_thread(tool_name=self.get_name(), initial_request=initial_request_dict)
# Add the initial user turn to the new thread
from utils.conversation_memory import MAX_CONVERSATION_TURNS, add_turn
user_prompt = self.get_request_prompt(request)
user_files = self.get_request_files(request)
user_images = self.get_request_images(request)
# Add user's initial turn
add_turn(
new_thread_id, "user", user_prompt, files=user_files, images=user_images, tool_name=self.get_name()
)
return {
"continuation_id": new_thread_id,
"remaining_turns": MAX_CONVERSATION_TURNS - 1,
"note": f"Claude can continue this conversation for {MAX_CONVERSATION_TURNS - 1} more exchanges.",
}
except Exception:
return None
def _create_continuation_offer_response(
self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None
):
"""Create response with continuation offer following old base.py pattern"""
from tools.models import ContinuationOffer, ToolOutput
try:
if not self.get_request_continuation_id(request):
self._record_assistant_turn(
continuation_data["continuation_id"],
content,
request,
model_info,
)
continuation_offer = ContinuationOffer(
continuation_id=continuation_data["continuation_id"],
note=continuation_data["note"],
remaining_turns=continuation_data["remaining_turns"],
)
# Build metadata with model and provider info
metadata = {"tool_name": self.get_name(), "conversation_ready": True}
if model_info:
model_name = model_info.get("model_name")
if model_name:
metadata["model_used"] = model_name
provider = model_info.get("provider")
if provider:
# Handle both provider objects and string values
if isinstance(provider, str):
metadata["provider_used"] = provider
else:
try:
metadata["provider_used"] = provider.get_provider_type().value
except AttributeError:
# Fallback if provider doesn't have get_provider_type method
metadata["provider_used"] = str(provider)
return ToolOutput(
status="continuation_available",
content=content,
content_type="text",
continuation_offer=continuation_offer,
metadata=metadata,
)
except Exception:
# Fallback to simple success if continuation offer fails
return ToolOutput(status="success", content=content, content_type="text")
def _record_assistant_turn(
self, continuation_id: str, response_text: str, request, model_info: Optional[dict]
) -> None:
"""Persist an assistant response in conversation memory."""
if not continuation_id:
return
from utils.conversation_memory import add_turn
model_provider = None
model_name = None
model_metadata = None
if model_info:
provider = model_info.get("provider")
if provider:
if isinstance(provider, str):
model_provider = provider
else:
try:
model_provider = provider.get_provider_type().value
except AttributeError:
model_provider = str(provider)
model_name = model_info.get("model_name")
model_response = model_info.get("model_response")
if model_response:
model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
add_turn(
continuation_id,
"assistant",
response_text,
files=self.get_request_files(request),
images=self.get_request_images(request),
tool_name=self.get_name(),
model_provider=model_provider,
model_name=model_name,
model_metadata=model_metadata,
)
# Convenience methods for common tool patterns
def build_standard_prompt(
self, system_prompt: str, user_content: str, request, file_context_title: str = "CONTEXT FILES"
) -> str:
"""
Build a standard prompt with system prompt, user content, and optional files.
This is a convenience method that handles the common pattern of:
1. Adding file content if present
2. Checking token limits
3. Adding web search instructions
4. Combining everything into a well-formatted prompt
Args:
system_prompt: The system prompt for the tool
user_content: The main user request/content
request: The validated request object
file_context_title: Title for the file context section
Returns:
Complete formatted prompt ready for the AI model
"""
# Add context files if provided
files = self.get_request_files(request)
if files:
file_content, processed_files = self._prepare_file_content_for_prompt(
files,
self.get_request_continuation_id(request),
"Context files",
model_context=getattr(self, "_model_context", None),
)
self._actually_processed_files = processed_files
if file_content:
user_content = f"{user_content}\n\n=== {file_context_title} ===\n{file_content}\n=== END CONTEXT ==="
# Check token limits - only validate original user prompt, not conversation history
content_to_validate = self.get_prompt_content_for_size_validation(user_content)
self._validate_token_limit(content_to_validate, "Content")
# Add standardized web search guidance
websearch_instruction = self.get_websearch_instruction(True, self.get_websearch_guidance())
# Combine system prompt with user content
full_prompt = f"""{system_prompt}{websearch_instruction}
=== USER REQUEST ===
{user_content}
=== END REQUEST ===
Please provide a thoughtful, comprehensive response:"""
return full_prompt
def get_prompt_content_for_size_validation(self, user_content: str) -> str:
"""
Override to use original user prompt for size validation when conversation history is embedded.
When server.py embeds conversation history into the prompt field, it also stores
the original user prompt in _original_user_prompt. We use that for size validation
to avoid incorrectly triggering size limits due to conversation history.
Args:
user_content: The user content (may include conversation history)
Returns:
The original user prompt if available, otherwise the full user content
"""
# Check if we have the current arguments from execute() method
current_args = getattr(self, "_current_arguments", None)
if current_args:
# If server.py embedded conversation history, it stores original prompt separately
original_user_prompt = current_args.get("_original_user_prompt")
if original_user_prompt is not None:
# Use original user prompt for size validation (excludes conversation history)
return original_user_prompt
# Fallback to default behavior (validate full user content)
return user_content
def get_websearch_guidance(self) -> Optional[str]:
"""
Return tool-specific web search guidance.
Override this to provide tool-specific guidance for when web searches
would be helpful. Return None to use the default guidance.
Returns:
Tool-specific web search guidance or None for default
"""
return None
def handle_prompt_file_with_fallback(self, request) -> str:
"""
Handle prompt.txt files with fallback to request field.
This is a convenience method for tools that accept prompts either
as a field or as a prompt.txt file. It handles the extraction
and validation automatically.
Args:
request: The validated request object
Returns:
The effective prompt content
Raises:
ValueError: If prompt is too large for MCP transport
"""
# Check for prompt.txt in files
files = self.get_request_files(request)
if files:
prompt_content, updated_files = self.handle_prompt_file(files)
# Update request files list if needed
if updated_files is not None:
self.set_request_files(request, updated_files)
else:
prompt_content = None
# Use prompt.txt content if available, otherwise use the prompt field
user_content = prompt_content if prompt_content else self.get_request_prompt(request)
# Check user input size at MCP transport boundary (excluding conversation history)
validation_content = self.get_prompt_content_for_size_validation(user_content)
size_check = self.check_prompt_size(validation_content)
if size_check:
from tools.models import ToolOutput
raise ValueError(f"MCP_SIZE_CHECK:{ToolOutput(**size_check).model_dump_json()}")
return user_content
def get_chat_style_websearch_guidance(self) -> str:
"""
Get Chat tool-style web search guidance.
Returns web search guidance that matches the original Chat tool pattern.
This is useful for tools that want to maintain the same search behavior.
Returns:
Web search guidance text
"""
return """When discussing topics, consider if searches for these would help:
- Documentation for any technologies or concepts mentioned
- Current best practices and patterns
- Recent developments or updates
- Community discussions and solutions"""
def supports_custom_request_model(self) -> bool:
"""
Indicate whether this tool supports custom request models.
Simple tools support custom request models by default. Tools that override
get_request_model() to return something other than ToolRequest should
return True here.
Returns:
True if the tool uses a custom request model
"""
return self.get_request_model() != ToolRequest
def _validate_file_paths(self, request) -> Optional[str]:
"""
Validate that all file paths in the request are absolute paths.
This is a security measure to prevent path traversal attacks and ensure
proper access control. All file paths must be absolute (starting with '/').
Args:
request: The validated request object
Returns:
Optional[str]: Error message if validation fails, None if all paths are valid
"""
import os
# Check if request has 'files' attribute (used by most tools)
files = self.get_request_files(request)
if files:
for file_path in files:
if not os.path.isabs(file_path):
return (
f"Error: All file paths must be FULL absolute paths to real files / folders - DO NOT SHORTEN. "
f"Received relative path: {file_path}\n"
f"Please provide the full absolute path starting with '/'."
)
return None
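The check above reduces to a single stdlib call. A standalone sketch of the absolute-path test performed by `_validate_file_paths` (the paths here are illustrative, and the `True`/`False` results assume a POSIX filesystem):

```python
import os

# os.path.isabs() is the underlying test used by _validate_file_paths:
# absolute paths pass, relative paths (including "./"-prefixed) fail.
candidates = ["/home/user/project/app.py", "src/app.py", "./notes.txt"]
results = {path: os.path.isabs(path) for path in candidates}
print(results)
```

On POSIX only a leading `/` counts as absolute, which is why the error message insists on full paths starting with `/`.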
def prepare_chat_style_prompt(self, request, system_prompt: str = None) -> str:
"""
Prepare a prompt using Chat tool-style patterns.
This convenience method replicates the Chat tool's prompt preparation logic:
1. Handle prompt.txt file if present
2. Add file context with specific formatting
3. Add web search guidance
4. Format with system prompt
Args:
request: The validated request object
system_prompt: System prompt to use (uses get_system_prompt() if None)
Returns:
Complete formatted prompt
"""
# Use provided system prompt or get from tool
if system_prompt is None:
system_prompt = self.get_system_prompt()
# Get user content (handles prompt.txt files)
user_content = self.handle_prompt_file_with_fallback(request)
# Build standard prompt with Chat-style web search guidance
websearch_guidance = self.get_chat_style_websearch_guidance()
# Override the websearch guidance temporarily
original_guidance = self.get_websearch_guidance
self.get_websearch_guidance = lambda: websearch_guidance
try:
full_prompt = self.build_standard_prompt(system_prompt, user_content, request, "CONTEXT FILES")
finally:
# Restore original guidance method
self.get_websearch_guidance = original_guidance
return full_prompt
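The override-and-restore dance above is a general pattern: swap a method on the instance for the duration of one call, restoring it in `finally` so the object is never left in a patched state. A minimal standalone sketch (the `Tool` class and method names here are illustrative stand-ins, not the real base tool API):

```python
# Temporarily override an instance method, mirroring how
# prepare_chat_style_prompt swaps get_websearch_guidance.
class Tool:
    def guidance(self):
        return "default guidance"

    def build(self):
        return f"PROMPT + {self.guidance()}"

tool = Tool()
original = tool.guidance
tool.guidance = lambda: "chat-style guidance"
try:
    result = tool.build()
finally:
    tool.guidance = original  # always restored, even if build() raises

print(result)           # uses the temporary guidance
print(tool.guidance())  # back to the default afterwards
```

The `finally` block is the important part: without it, an exception inside the call would leave the instance permanently monkey-patched.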

368
tools/version.py Normal file

@@ -0,0 +1,368 @@
"""
Version Tool - Display Zen MCP Server version and system information
This tool provides version information about the Zen MCP Server including
version number, last update date, author, and basic system information.
It also checks for updates from the GitHub repository.
"""
import logging
import platform
import re
import sys
from pathlib import Path
from typing import Any, Optional
try:
from urllib.error import HTTPError, URLError
from urllib.request import urlopen
HAS_URLLIB = True
except ImportError:
HAS_URLLIB = False
from mcp.types import TextContent
from config import __author__, __updated__, __version__
from tools.models import ToolModelCategory, ToolOutput
from tools.shared.base_models import ToolRequest
from tools.shared.base_tool import BaseTool
logger = logging.getLogger(__name__)
def parse_version(version_str: str) -> tuple[int, int, int]:
"""
Parse version string to tuple of integers for comparison.
Args:
version_str: Version string like "5.5.5"
Returns:
Tuple of (major, minor, patch) as integers
"""
try:
parts = version_str.strip().split(".")
if len(parts) >= 3:
return (int(parts[0]), int(parts[1]), int(parts[2]))
elif len(parts) == 2:
return (int(parts[0]), int(parts[1]), 0)
elif len(parts) == 1:
return (int(parts[0]), 0, 0)
else:
return (0, 0, 0)
except (ValueError, IndexError):
return (0, 0, 0)
def compare_versions(current: str, remote: str) -> int:
"""
Compare two version strings.
Args:
current: Current version string
remote: Remote version string
Returns:
-1 if current < remote (update available)
0 if current == remote (up to date)
1 if current > remote (ahead of remote)
"""
current_tuple = parse_version(current)
remote_tuple = parse_version(remote)
if current_tuple < remote_tuple:
return -1
elif current_tuple > remote_tuple:
return 1
else:
return 0
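The two helpers above reduce version comparison to Python's built-in tuple ordering. A standalone sketch of the same idea (function names here are illustrative, not the module's own; missing components pad with zeros, and malformed strings collapse to `(0, 0, 0)` as in `parse_version`):

```python
# Tuple-based semantic version comparison, mirroring
# parse_version/compare_versions above.
def to_tuple(version_str: str) -> tuple[int, int, int]:
    try:
        parts = [int(p) for p in version_str.strip().split(".")]
    except ValueError:
        return (0, 0, 0)  # malformed input compares as the lowest version
    parts = (parts + [0, 0, 0])[:3]  # pad "5.8" out to (5, 8, 0)
    return (parts[0], parts[1], parts[2])

def cmp_versions(current: str, remote: str) -> int:
    c, r = to_tuple(current), to_tuple(remote)
    return (c > r) - (c < r)  # -1, 0, or 1

print(cmp_versions("5.8", "5.8.0"))     # 0: missing parts pad to zero
print(cmp_versions("1.2.3", "1.10.0"))  # -1: numeric, not lexicographic
```

Comparing tuples of ints is what makes `"1.10.0"` correctly rank above `"1.2.3"`; a plain string comparison would get that backwards.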
def fetch_github_version() -> Optional[tuple[str, str]]:
"""
Fetch the latest version information from GitHub repository.
Returns:
Tuple of (version, last_updated) if successful, None if failed
"""
if not HAS_URLLIB:
logger.warning("urllib not available, cannot check for updates")
return None
github_url = "https://raw.githubusercontent.com/BeehiveInnovations/zen-mcp-server/main/config.py"
try:
# Set a 10-second timeout
with urlopen(github_url, timeout=10) as response:
if response.status != 200:
logger.warning(f"HTTP error while checking GitHub: {response.status}")
return None
content = response.read().decode("utf-8")
# Extract version using regex
version_match = re.search(r'__version__\s*=\s*["\']([^"\']+)["\']', content)
updated_match = re.search(r'__updated__\s*=\s*["\']([^"\']+)["\']', content)
if version_match:
remote_version = version_match.group(1)
remote_updated = updated_match.group(1) if updated_match else "Unknown"
return (remote_version, remote_updated)
else:
logger.warning("Could not parse version from GitHub config.py")
return None
except HTTPError as e:
logger.warning(f"HTTP error while checking GitHub: {e.code}")
return None
except URLError as e:
logger.warning(f"URL error while checking GitHub: {e.reason}")
return None
except Exception as e:
logger.warning(f"Error checking GitHub for updates: {e}")
return None
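The version check works by fetching the raw `config.py` from GitHub and pulling `__version__`/`__updated__` out with regexes. A sketch of that extraction step, run against an inline sample instead of fetched content (the sample values are made up for illustration):

```python
import re

# Same patterns as fetch_github_version; tolerate either quote style
# and optional whitespace around "=".
sample = "__version__ = \"5.8.2\"\n__updated__ = '2025-01-15'\n"
version_match = re.search(r'__version__\s*=\s*["\']([^"\']+)["\']', sample)
updated_match = re.search(r'__updated__\s*=\s*["\']([^"\']+)["\']', sample)
print(version_match.group(1))  # 5.8.2
print(updated_match.group(1))  # 2025-01-15
```

The character class `["\']` is what lets the pattern match both `"5.8.2"` and `'2025-01-15'` without two separate regexes.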
class VersionTool(BaseTool):
"""
Tool for displaying Zen MCP Server version and system information.
This tool provides:
- Current server version
- Last update date
- Author information
- Python version
- Platform information
"""
def get_name(self) -> str:
return "version"
def get_description(self) -> str:
return "Get server version, configuration details, and list of available tools."
def get_input_schema(self) -> dict[str, Any]:
"""Return the JSON schema for the tool's input"""
return {
"type": "object",
"properties": {"model": {"type": "string", "description": "Model to use (ignored by version tool)"}},
"required": [],
}
def get_annotations(self) -> Optional[dict[str, Any]]:
"""Return tool annotations indicating this is a read-only tool"""
return {"readOnlyHint": True}
def get_system_prompt(self) -> str:
"""No AI model needed for this tool"""
return ""
def get_request_model(self):
"""Return the Pydantic model for request validation."""
return ToolRequest
def requires_model(self) -> bool:
return False
async def prepare_prompt(self, request: ToolRequest) -> str:
"""Not used for this utility tool"""
return ""
def format_response(self, response: str, request: ToolRequest, model_info: dict = None) -> str:
"""Not used for this utility tool"""
return response
async def execute(self, arguments: dict[str, Any]) -> list[TextContent]:
"""
Display Zen MCP Server version and system information.
This overrides the base class execute to provide direct output without AI model calls.
Args:
arguments: Standard tool arguments (none required)
Returns:
Formatted version and system information
"""
output_lines = ["# Zen MCP Server Version\n"]
# Server version information
output_lines.append("## Server Information")
output_lines.append(f"**Current Version**: {__version__}")
output_lines.append(f"**Last Updated**: {__updated__}")
output_lines.append(f"**Author**: {__author__}")
model_selection_metadata = {"mode": "unknown", "default_model": None}
model_selection_display = "Model selection status unavailable"
# Model selection configuration
try:
from config import DEFAULT_MODEL
from tools.shared.base_tool import BaseTool
auto_mode = BaseTool.is_effective_auto_mode(self)
if auto_mode:
output_lines.append(
"**Model Selection**: Auto model selection mode (call `listmodels` to inspect options)"
)
model_selection_metadata = {"mode": "auto", "default_model": DEFAULT_MODEL}
model_selection_display = "Auto model selection (use `listmodels` for options)"
else:
output_lines.append(f"**Model Selection**: Default model set to `{DEFAULT_MODEL}`")
model_selection_metadata = {"mode": "default", "default_model": DEFAULT_MODEL}
model_selection_display = f"Default model: `{DEFAULT_MODEL}`"
except Exception as exc:
logger.debug(f"Could not determine model selection mode: {exc}")
output_lines.append("")
output_lines.append("## Quick Summary — relay everything below")
output_lines.append(f"- Version `{__version__}` (updated {__updated__})")
output_lines.append(f"- {model_selection_display}")
output_lines.append("- Run `listmodels` for the complete model catalog and capabilities")
output_lines.append("")
# Try to get client information
try:
# We need access to the server instance
# This is a bit hacky but works for now
import server as server_module
from utils.client_info import format_client_info, get_client_info_from_context
client_info = get_client_info_from_context(server_module.server)
if client_info:
formatted = format_client_info(client_info)
output_lines.append(f"**Connected Client**: {formatted}")
except Exception as e:
logger.debug(f"Could not get client info: {e}")
# Get the current working directory (MCP server location)
current_path = Path.cwd()
output_lines.append(f"**Installation Path**: `{current_path}`")
output_lines.append("")
output_lines.append("## Agent Reporting Guidance")
output_lines.append(
"Agents MUST report: version, model-selection status, configured providers, and available-model count."
)
output_lines.append("Repeat the quick-summary bullets verbatim in your reply.")
output_lines.append("Reference `listmodels` when users ask about model availability or capabilities.")
output_lines.append("")
# Check for updates from GitHub
output_lines.append("## Update Status")
try:
github_info = fetch_github_version()
if github_info:
remote_version, remote_updated = github_info
comparison = compare_versions(__version__, remote_version)
output_lines.append(f"**Latest Version (GitHub)**: {remote_version}")
output_lines.append(f"**Latest Updated**: {remote_updated}")
if comparison < 0:
# Update available
output_lines.append("")
output_lines.append("🚀 **UPDATE AVAILABLE!**")
output_lines.append(
f"Your version `{__version__}` is older than the latest version `{remote_version}`"
)
output_lines.append("")
output_lines.append("**To update:**")
output_lines.append("```bash")
output_lines.append(f"cd {current_path}")
output_lines.append("git pull")
output_lines.append("```")
output_lines.append("")
output_lines.append("*Note: Restart your session after updating to use the new version.*")
elif comparison == 0:
# Up to date
output_lines.append("")
output_lines.append("✅ **UP TO DATE**")
output_lines.append("You are running the latest version.")
else:
# Ahead of remote (development version)
output_lines.append("")
output_lines.append("🔬 **DEVELOPMENT VERSION**")
output_lines.append(
f"Your version `{__version__}` is ahead of the published version `{remote_version}`"
)
output_lines.append("You may be running a development or custom build.")
else:
output_lines.append("❌ **Could not check for updates**")
output_lines.append("Unable to connect to GitHub or parse version information.")
output_lines.append("Check your internet connection or try again later.")
except Exception as e:
logger.error(f"Error during version check: {e}")
output_lines.append("❌ **Error checking for updates**")
output_lines.append(f"Error: {str(e)}")
output_lines.append("")
# Configuration information
output_lines.append("## Configuration")
# Check for configured providers
try:
from providers.registry import ModelProviderRegistry
from providers.shared import ProviderType
provider_status = []
# Check each provider type
provider_types = [
ProviderType.GOOGLE,
ProviderType.OPENAI,
ProviderType.XAI,
ProviderType.DIAL,
ProviderType.OPENROUTER,
ProviderType.CUSTOM,
]
provider_names = ["Google Gemini", "OpenAI", "X.AI", "DIAL", "OpenRouter", "Custom/Local"]
for provider_type, provider_name in zip(provider_types, provider_names):
provider = ModelProviderRegistry.get_provider(provider_type)
status = "✅ Configured" if provider is not None else "❌ Not configured"
provider_status.append(f"- **{provider_name}**: {status}")
output_lines.append("**Providers**:")
output_lines.extend(provider_status)
# Get total available models
try:
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
output_lines.append(f"\n\n**Available Models**: {len(available_models)}")
except Exception:
output_lines.append("\n\n**Available Models**: Unknown")
except Exception as e:
logger.warning(f"Error checking provider configuration: {e}")
output_lines.append("\n\n**Providers**: Error checking configuration")
output_lines.append("")
# Format output
content = "\n".join(output_lines)
tool_output = ToolOutput(
status="success",
content=content,
content_type="text",
metadata={
"tool_name": self.name,
"server_version": __version__,
"last_updated": __updated__,
"python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
"platform": f"{platform.system()} {platform.release()}",
"model_selection_mode": model_selection_metadata["mode"],
"default_model": model_selection_metadata["default_model"],
},
)
return [TextContent(type="text", text=tool_output.model_dump_json())]
def get_model_category(self) -> ToolModelCategory:
"""Return the model category for this tool."""
return ToolModelCategory.FAST_RESPONSE # Simple version info, no AI needed


@@ -0,0 +1,22 @@
"""
Workflow tools for Zen MCP.
Workflow tools follow a multi-step pattern with forced pauses between steps
to encourage thorough investigation and analysis. They inherit from WorkflowTool
which combines BaseTool with BaseWorkflowMixin.
Available workflow tools:
- debug: Systematic investigation and root cause analysis
- planner: Sequential planning (special case - no AI calls)
- analyze: Code analysis workflow
- codereview: Code review workflow
- precommit: Pre-commit validation workflow
- refactor: Refactoring analysis workflow
- thinkdeep: Deep thinking workflow
"""
from .base import WorkflowTool
from .schema_builders import WorkflowSchemaBuilder
from .workflow_mixin import BaseWorkflowMixin
__all__ = ["WorkflowTool", "WorkflowSchemaBuilder", "BaseWorkflowMixin"]

444
tools/workflow/base.py Normal file

@@ -0,0 +1,444 @@
"""
Base class for workflow MCP tools.
Workflow tools follow a multi-step pattern:
1. Claude calls tool with work step data
2. Tool tracks findings and progress
3. Tool forces Claude to pause and investigate between steps
4. Once work is complete, tool calls external AI model for expert analysis
5. Tool returns structured response combining investigation + expert analysis
They combine BaseTool's capabilities with BaseWorkflowMixin's workflow functionality
and use SchemaBuilder for consistent schema generation.
"""
from abc import abstractmethod
from typing import Any, Optional
from tools.shared.base_models import WorkflowRequest
from tools.shared.base_tool import BaseTool
from .schema_builders import WorkflowSchemaBuilder
from .workflow_mixin import BaseWorkflowMixin
class WorkflowTool(BaseTool, BaseWorkflowMixin):
"""
Base class for workflow (multi-step) tools.
Workflow tools perform systematic multi-step work with expert analysis.
They benefit from:
- Automatic workflow orchestration from BaseWorkflowMixin
- Automatic schema generation using SchemaBuilder
- Inherited conversation handling and file processing from BaseTool
- Progress tracking with ConsolidatedFindings
- Expert analysis integration
To create a workflow tool:
1. Inherit from WorkflowTool
2. Tool name is automatically provided by get_name() method
3. Implement get_required_actions() for step guidance
4. Implement should_call_expert_analysis() for completion criteria
5. Implement prepare_expert_analysis_context() for expert prompts
6. Optionally implement get_tool_fields() for additional fields
7. Optionally override workflow behavior methods
Example:
class DebugTool(WorkflowTool):
# get_name() is inherited from BaseTool
def get_tool_fields(self) -> Dict[str, Dict[str, Any]]:
return {
"hypothesis": {
"type": "string",
"description": "Current theory about the issue",
}
}
def get_required_actions(
self, step_number: int, confidence: str, findings: str, total_steps: int
) -> List[str]:
return ["Examine relevant code files", "Trace execution flow", "Check error logs"]
def should_call_expert_analysis(self, consolidated_findings) -> bool:
return len(consolidated_findings.relevant_files) > 0
"""
def __init__(self):
"""Initialize WorkflowTool with proper multiple inheritance."""
BaseTool.__init__(self)
BaseWorkflowMixin.__init__(self)
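Calling each base's `__init__` explicitly, as above, is deliberate: when the bases don't cooperate via `super().__init__()`, a single `super()` call would initialize only the first base in the MRO. A standalone sketch of the pattern (these stand-in classes are illustrative, not the real `BaseTool`/`BaseWorkflowMixin`):

```python
# Explicit two-base initialization, mirroring WorkflowTool.__init__.
class FakeBaseTool:
    def __init__(self):
        self.conversation_state = {}

class FakeWorkflowMixin:
    def __init__(self):
        self.work_history = []

class FakeWorkflowTool(FakeBaseTool, FakeWorkflowMixin):
    def __init__(self):
        # Call each base explicitly: neither base chains via super(),
        # so plain super().__init__() would skip the mixin's state.
        FakeBaseTool.__init__(self)
        FakeWorkflowMixin.__init__(self)

tool = FakeWorkflowTool()
print(hasattr(tool, "conversation_state"), hasattr(tool, "work_history"))
```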
def get_tool_fields(self) -> dict[str, dict[str, Any]]:
"""
Return tool-specific field definitions beyond the standard workflow fields.
Workflow tools automatically get all standard workflow fields:
- step, step_number, total_steps, next_step_required
- findings, files_checked, relevant_files, relevant_context
- issues_found, confidence, hypothesis, backtrack_from_step
- plus common fields (model, temperature, etc.)
Override this method to add additional tool-specific fields.
Returns:
Dict mapping field names to JSON schema objects
Example:
return {
"severity_filter": {
"type": "string",
"enum": ["low", "medium", "high"],
"description": "Minimum severity level to report",
}
}
"""
return {}
def get_required_fields(self) -> list[str]:
"""
Return additional required fields beyond the standard workflow requirements.
Workflow tools automatically require:
- step, step_number, total_steps, next_step_required, findings
- model (if in auto mode)
Override this to add additional required fields.
Returns:
List of additional required field names
"""
return []
def get_annotations(self) -> Optional[dict[str, Any]]:
"""
Return tool annotations. Workflow tools are read-only by default.
All workflow tools perform analysis and investigation without modifying
the environment. They may call external AI models for expert analysis,
but they don't write files or make system changes.
Override this method if your workflow tool needs different annotations.
Returns:
Dictionary with readOnlyHint set to True
"""
return {"readOnlyHint": True}
def get_input_schema(self) -> dict[str, Any]:
"""
Generate the complete input schema using SchemaBuilder.
This method automatically combines:
- Standard workflow fields (step, findings, etc.)
- Common fields (temperature, thinking_mode, etc.)
- Model field with proper auto-mode handling
- Tool-specific fields from get_tool_fields()
- Required fields from get_required_fields()
Returns:
Complete JSON schema for the workflow tool
"""
return WorkflowSchemaBuilder.build_schema(
tool_specific_fields=self.get_tool_fields(),
required_fields=self.get_required_fields(),
model_field_schema=self.get_model_field_schema(),
auto_mode=self.is_effective_auto_mode(),
tool_name=self.get_name(),
)
def get_workflow_request_model(self):
"""
Return the workflow request model class.
Workflow tools use WorkflowRequest by default, which includes
all the standard workflow fields. Override this if your tool
needs a custom request model.
"""
return WorkflowRequest
# Implement the abstract method from BaseWorkflowMixin
def get_work_steps(self, request) -> list[str]:
"""
Default implementation - workflow tools typically don't need predefined steps.
The workflow is driven by Claude's investigation process rather than
predefined steps. Override this if your tool needs specific step guidance.
"""
return []
# Default implementations for common workflow patterns
def get_standard_required_actions(self, step_number: int, confidence: str, base_actions: list[str]) -> list[str]:
"""
Helper method to generate standard required actions based on confidence and step.
This provides common patterns that most workflow tools can use:
- Early steps: broad exploration
- Low confidence: deeper investigation
- Medium/high confidence: verification and confirmation
Args:
step_number: Current step number
confidence: Current confidence level
base_actions: Tool-specific base actions
Returns:
List of required actions appropriate for the current state
"""
if step_number == 1:
# Initial investigation
return [
"Search for code related to the reported issue or symptoms",
"Examine relevant files and understand the current implementation",
"Understand the project structure and locate relevant modules",
"Identify how the affected functionality is supposed to work",
]
elif confidence in ["exploring", "low"]:
# Need deeper investigation
return base_actions + [
"Trace method calls and data flow through the system",
"Check for edge cases, boundary conditions, and assumptions in the code",
"Look for related configuration, dependencies, or external factors",
]
elif confidence in ["medium", "high"]:
# Close to solution - need confirmation
return base_actions + [
"Examine the exact code sections where you believe the issue occurs",
"Trace the execution path that leads to the failure",
"Verify your hypothesis with concrete code evidence",
"Check for any similar patterns elsewhere in the codebase",
]
else:
# General continued investigation
return base_actions + [
"Continue examining the code paths identified in your hypothesis",
"Gather more evidence using appropriate investigation tools",
"Test edge cases and boundary conditions",
"Look for patterns that confirm or refute your theory",
]
def should_call_expert_analysis_default(self, consolidated_findings) -> bool:
"""
Default implementation for expert analysis decision.
This provides a reasonable default that most workflow tools can use:
- Call expert analysis if we have relevant files or significant findings
- Skip if confidence is "certain" (handled by the workflow mixin)
Override this for tool-specific logic.
Args:
consolidated_findings: The consolidated findings from all work steps
Returns:
True if expert analysis should be called
"""
# Call expert analysis if we have relevant files or substantial findings
return (
len(consolidated_findings.relevant_files) > 0
or len(consolidated_findings.findings) >= 2
or len(consolidated_findings.issues_found) > 0
)
def prepare_standard_expert_context(
self, consolidated_findings, initial_description: str, context_sections: dict[str, str] = None
) -> str:
"""
Helper method to prepare standard expert analysis context.
This provides a common structure that most workflow tools can use,
with the ability to add tool-specific sections.
Args:
consolidated_findings: The consolidated findings from all work steps
initial_description: Description of the initial request/issue
context_sections: Optional additional sections to include
Returns:
Formatted context string for expert analysis
"""
context_parts = [f"=== ISSUE DESCRIPTION ===\n{initial_description}\n=== END DESCRIPTION ==="]
# Add work progression
if consolidated_findings.findings:
findings_text = "\n".join(consolidated_findings.findings)
context_parts.append(f"\n=== INVESTIGATION FINDINGS ===\n{findings_text}\n=== END FINDINGS ===")
# Add relevant methods if available
if consolidated_findings.relevant_context:
methods_text = "\n".join(f"- {method}" for method in consolidated_findings.relevant_context)
context_parts.append(f"\n=== RELEVANT METHODS/FUNCTIONS ===\n{methods_text}\n=== END METHODS ===")
# Add hypothesis evolution if available
if consolidated_findings.hypotheses:
hypotheses_text = "\n".join(
f"Step {h['step']} ({h['confidence']} confidence): {h['hypothesis']}"
for h in consolidated_findings.hypotheses
)
context_parts.append(f"\n=== HYPOTHESIS EVOLUTION ===\n{hypotheses_text}\n=== END HYPOTHESES ===")
# Add issues found if available
if consolidated_findings.issues_found:
issues_text = "\n".join(
f"[{issue.get('severity', 'unknown').upper()}] {issue.get('description', 'No description')}"
for issue in consolidated_findings.issues_found
)
context_parts.append(f"\n=== ISSUES IDENTIFIED ===\n{issues_text}\n=== END ISSUES ===")
# Add tool-specific sections
if context_sections:
for section_title, section_content in context_sections.items():
context_parts.append(
f"\n=== {section_title.upper()} ===\n{section_content}\n=== END {section_title.upper()} ==="
)
return "\n".join(context_parts)
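The method above emits a flat string of clearly delimited sections so the expert model can tell investigation findings apart from hypotheses and issues. A simplified sketch of that sectioned format, with findings passed as a plain list in place of the real `ConsolidatedFindings` object (the sample strings are made up):

```python
# Reduced version of the section-delimiter format used by
# prepare_standard_expert_context: each optional section is wrapped
# in === BEGIN/END === markers and appended in a fixed order.
def build_context(description: str, findings: list[str]) -> str:
    parts = [f"=== ISSUE DESCRIPTION ===\n{description}\n=== END DESCRIPTION ==="]
    if findings:
        body = "\n".join(findings)
        parts.append(f"\n=== INVESTIGATION FINDINGS ===\n{body}\n=== END FINDINGS ===")
    return "\n".join(parts)

context = build_context("Login fails after upgrade", ["Step 1: reproduced locally"])
print(context)
```

Keeping every section between explicit `===` markers means downstream prompt consumers can parse or truncate sections without guessing at boundaries.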
def handle_completion_without_expert_analysis(
self, request, consolidated_findings, initial_description: str = None
) -> dict[str, Any]:
"""
Generic handler for completion when expert analysis is not needed.
This provides a standard response format for when the tool determines
that external expert analysis is not required. All workflow tools
can use this generic implementation or override for custom behavior.
Args:
request: The workflow request object
consolidated_findings: The consolidated findings from all work steps
initial_description: Optional initial description (defaults to request.step)
Returns:
Dictionary with completion response data
"""
# Prepare work summary using inheritance hook
work_summary = self.prepare_work_summary()
return {
"status": self.get_completion_status(),
self.get_completion_data_key(): {
"initial_request": initial_description or request.step,
"steps_taken": len(consolidated_findings.findings),
"files_examined": list(consolidated_findings.files_checked),
"relevant_files": list(consolidated_findings.relevant_files),
"relevant_context": list(consolidated_findings.relevant_context),
"work_summary": work_summary,
"final_analysis": self.get_final_analysis_from_request(request),
"confidence_level": self.get_confidence_level(request),
},
"next_steps": self.get_completion_message(),
"skip_expert_analysis": True,
"expert_analysis": {
"status": self.get_skip_expert_analysis_status(),
"reason": self.get_skip_reason(),
},
}
# Inheritance hooks for customization
def prepare_work_summary(self) -> str:
"""
Prepare a summary of the work performed. Override for custom summaries.
Default implementation provides a basic summary.
"""
try:
return self._prepare_work_summary()
except AttributeError:
try:
return f"Completed {len(self.work_history)} work steps"
except AttributeError:
return "Completed 0 work steps"
def get_completion_status(self) -> str:
"""Get the status to use when completing without expert analysis."""
return "high_confidence_completion"
def get_completion_data_key(self) -> str:
"""Get the key name for completion data in the response."""
return f"complete_{self.get_name()}"
def get_final_analysis_from_request(self, request) -> Optional[str]:
"""Extract final analysis from request. Override for tool-specific extraction."""
try:
return request.hypothesis
except AttributeError:
return None
def get_confidence_level(self, request) -> str:
"""Get confidence level from request. Override for tool-specific logic."""
try:
return request.confidence or "high"
except AttributeError:
return "high"
def get_completion_message(self) -> str:
"""Get completion message. Override for tool-specific messaging."""
return (
f"{self.get_name().capitalize()} complete with high confidence. The analysis and "
"solution have been identified. MANDATORY: Present the user with the results "
"and proceed with implementing the solution without requiring further "
"consultation. Focus on the precise, actionable steps needed."
)
def get_skip_reason(self) -> str:
"""Get reason for skipping expert analysis. Override for tool-specific reasons."""
return f"{self.get_name()} completed with sufficient confidence"
def get_skip_expert_analysis_status(self) -> str:
"""Get status for skipped expert analysis. Override for tool-specific status."""
return "skipped_by_tool_design"
def is_continuation_workflow(self, request) -> bool:
"""
Check if this is a continuation workflow that should skip multi-step investigation.
When continuation_id is provided, the workflow typically continues from a previous
conversation and should go directly to expert analysis rather than starting a new
multi-step investigation.
Args:
request: The workflow request object
Returns:
True if this is a continuation that should skip multi-step workflow
"""
continuation_id = self.get_request_continuation_id(request)
return bool(continuation_id)
# Abstract methods that must be implemented by specific workflow tools
# (These are inherited from BaseWorkflowMixin and must be implemented)
@abstractmethod
def get_required_actions(
self, step_number: int, confidence: str, findings: str, total_steps: int, request=None
) -> list[str]:
"""Define required actions for each work phase.
Args:
step_number: Current step number
confidence: Current confidence level
findings: Current findings text
total_steps: Total estimated steps
request: Optional request object for continuation-aware decisions
Returns:
List of required actions for the current step
"""
pass
@abstractmethod
def should_call_expert_analysis(self, consolidated_findings) -> bool:
"""Decide when to call external model based on tool-specific criteria"""
pass
@abstractmethod
def prepare_expert_analysis_context(self, consolidated_findings) -> str:
"""Prepare context for external model call"""
pass
# Default execute method - delegates to workflow
async def execute(self, arguments: dict[str, Any]) -> list:
"""Execute the workflow tool - delegates to BaseWorkflowMixin."""
return await self.execute_workflow(arguments)


@@ -0,0 +1,174 @@
"""
Schema builders for workflow MCP tools.
This module provides workflow-specific schema generation functionality,
keeping workflow concerns separated from simple tool concerns.
"""
from typing import Any
from ..shared.base_models import WORKFLOW_FIELD_DESCRIPTIONS
from ..shared.schema_builders import SchemaBuilder
class WorkflowSchemaBuilder:
"""
Schema builder for workflow MCP tools.
This class extends the base SchemaBuilder with workflow-specific fields
and schema generation logic, maintaining separation of concerns.
"""
# Workflow-specific field schemas
WORKFLOW_FIELD_SCHEMAS = {
"step": {
"type": "string",
"description": WORKFLOW_FIELD_DESCRIPTIONS["step"],
},
"step_number": {
"type": "integer",
"minimum": 1,
"description": WORKFLOW_FIELD_DESCRIPTIONS["step_number"],
},
"total_steps": {
"type": "integer",
"minimum": 1,
"description": WORKFLOW_FIELD_DESCRIPTIONS["total_steps"],
},
"next_step_required": {
"type": "boolean",
"description": WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"],
},
"findings": {
"type": "string",
"description": WORKFLOW_FIELD_DESCRIPTIONS["findings"],
},
"files_checked": {
"type": "array",
"items": {"type": "string"},
"description": WORKFLOW_FIELD_DESCRIPTIONS["files_checked"],
},
"relevant_files": {
"type": "array",
"items": {"type": "string"},
"description": WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"],
},
"relevant_context": {
"type": "array",
"items": {"type": "string"},
"description": WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"],
},
"issues_found": {
"type": "array",
"items": {"type": "object"},
"description": WORKFLOW_FIELD_DESCRIPTIONS["issues_found"],
},
"confidence": {
"type": "string",
"enum": ["exploring", "low", "medium", "high", "very_high", "almost_certain", "certain"],
"description": WORKFLOW_FIELD_DESCRIPTIONS["confidence"],
},
"hypothesis": {
"type": "string",
"description": WORKFLOW_FIELD_DESCRIPTIONS["hypothesis"],
},
"backtrack_from_step": {
"type": "integer",
"minimum": 1,
"description": WORKFLOW_FIELD_DESCRIPTIONS["backtrack_from_step"],
},
"use_assistant_model": {
"type": "boolean",
"default": True,
"description": WORKFLOW_FIELD_DESCRIPTIONS["use_assistant_model"],
},
}
@staticmethod
def build_schema(
tool_specific_fields: dict[str, dict[str, Any]] = None,
required_fields: list[str] = None,
model_field_schema: dict[str, Any] = None,
auto_mode: bool = False,
tool_name: str = None,
excluded_workflow_fields: list[str] = None,
excluded_common_fields: list[str] = None,
require_model: bool = False,
) -> dict[str, Any]:
"""
Build complete schema for workflow tools.
Args:
tool_specific_fields: Additional fields specific to the tool
required_fields: List of required field names (beyond workflow defaults)
model_field_schema: Schema for the model field
auto_mode: Whether the tool is in auto mode (affects model requirement)
tool_name: Name of the tool (for schema title)
excluded_workflow_fields: Workflow fields to exclude from schema (e.g., for planning tools)
excluded_common_fields: Common fields to exclude from schema
require_model: Whether to always require the "model" field, regardless of auto mode
Returns:
Complete JSON schema for the workflow tool
"""
properties = {}
# Add workflow fields first, excluding any specified fields
workflow_fields = WorkflowSchemaBuilder.WORKFLOW_FIELD_SCHEMAS.copy()
if excluded_workflow_fields:
for field in excluded_workflow_fields:
workflow_fields.pop(field, None)
properties.update(workflow_fields)
# Add common fields (temperature, thinking_mode, etc.) from base builder, excluding any specified fields
common_fields = SchemaBuilder.COMMON_FIELD_SCHEMAS.copy()
if excluded_common_fields:
for field in excluded_common_fields:
common_fields.pop(field, None)
properties.update(common_fields)
# Add model field if provided
if model_field_schema:
properties["model"] = model_field_schema
# Add tool-specific fields if provided
if tool_specific_fields:
properties.update(tool_specific_fields)
# Build required fields list - workflow tools have standard required fields
standard_required = ["step", "step_number", "total_steps", "next_step_required", "findings"]
# Filter out excluded fields from required fields
if excluded_workflow_fields:
standard_required = [field for field in standard_required if field not in excluded_workflow_fields]
required = standard_required + (required_fields or [])
if (auto_mode or require_model) and "model" not in required:
required.append("model")
# Build the complete schema
schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False,
}
if tool_name:
schema["title"] = f"{tool_name.capitalize()}Request"
return schema
@staticmethod
def get_workflow_fields() -> dict[str, dict[str, Any]]:
"""Get the standard field schemas for workflow tools."""
combined = {}
combined.update(WorkflowSchemaBuilder.WORKFLOW_FIELD_SCHEMAS)
combined.update(SchemaBuilder.COMMON_FIELD_SCHEMAS)
return combined
@staticmethod
def get_workflow_only_fields() -> dict[str, dict[str, Any]]:
"""Get only the workflow-specific field schemas."""
return WorkflowSchemaBuilder.WORKFLOW_FIELD_SCHEMAS.copy()
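As a note between files: the required-field assembly inside `build_schema` can be sketched in isolation. `build_required` below is a hypothetical standalone helper mirroring that logic, not part of the server:

```python
# Hypothetical sketch of build_schema's required-field assembly: standard
# workflow fields, minus exclusions, plus tool-specific extras, plus "model"
# when auto_mode or require_model is set.
def build_required(extra=None, excluded=None, auto_mode=False, require_model=False):
    standard = ["step", "step_number", "total_steps", "next_step_required", "findings"]
    required = [f for f in standard if f not in (excluded or [])]
    required += extra or []
    if (auto_mode or require_model) and "model" not in required:
        required.append("model")
    return required

print(build_required(extra=["hypothesis"], auto_mode=True))
```

The `"model" not in required` guard matters: it keeps the field from being appended twice when a tool lists it among its own required fields.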

File diff suppressed because it is too large

21
utils/__init__.py Normal file

@@ -0,0 +1,21 @@
"""
Utility functions for Zen MCP Server
"""
from .file_types import CODE_EXTENSIONS, FILE_CATEGORIES, PROGRAMMING_EXTENSIONS, TEXT_EXTENSIONS
from .file_utils import expand_paths, read_file_content, read_files
from .security_config import EXCLUDED_DIRS
from .token_utils import check_token_limit, estimate_tokens
__all__ = [
"read_files",
"read_file_content",
"expand_paths",
"CODE_EXTENSIONS",
"PROGRAMMING_EXTENSIONS",
"TEXT_EXTENSIONS",
"FILE_CATEGORIES",
"EXCLUDED_DIRS",
"estimate_tokens",
"check_token_limit",
]

293
utils/client_info.py Normal file

@@ -0,0 +1,293 @@
"""
Client Information Utility for MCP Server
This module provides utilities to extract and format client information
from the MCP protocol's clientInfo sent during initialization.
It also provides friendly name mapping and caching for consistent client
identification across the application.
"""
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
# Global cache for client information
_client_info_cache: Optional[dict[str, Any]] = None
# Mapping of known client names to friendly names
# This is case-insensitive and checks if the key is contained in the client name
CLIENT_NAME_MAPPINGS = {
# Claude variants
"claude-ai": "Claude",
"claude": "Claude",
"claude-desktop": "Claude",
"claude-code": "Claude",
"anthropic": "Claude",
# Gemini variants
"gemini-cli-mcp-client": "Gemini",
"gemini-cli": "Gemini",
"gemini": "Gemini",
"google": "Gemini",
# Other known clients
"cursor": "Cursor",
"vscode": "VS Code",
"codeium": "Codeium",
"copilot": "GitHub Copilot",
# Generic MCP clients
"mcp-client": "MCP Client",
"test-client": "Test Client",
}
# Default friendly name when no match is found
DEFAULT_FRIENDLY_NAME = "Claude"
def get_friendly_name(client_name: str) -> str:
"""
Map a client name to a friendly name.
Args:
client_name: The raw client name from clientInfo
Returns:
A friendly name for display (e.g., "Claude", "Gemini")
"""
if not client_name:
return DEFAULT_FRIENDLY_NAME
# Convert to lowercase for case-insensitive matching
client_name_lower = client_name.lower()
# Check each mapping - using 'in' to handle partial matches
for key, friendly_name in CLIENT_NAME_MAPPINGS.items():
if key.lower() in client_name_lower:
return friendly_name
# If no match found, return the default
return DEFAULT_FRIENDLY_NAME
def get_cached_client_info() -> Optional[dict[str, Any]]:
"""
Get cached client information if available.
Returns:
Cached client info dictionary or None
"""
global _client_info_cache
return _client_info_cache
def get_client_info_from_context(server: Any) -> Optional[dict[str, Any]]:
"""
Extract client information from the MCP server's request context.
The MCP protocol sends clientInfo during initialization containing:
- name: The client application name (e.g., "Claude Code", "Claude Desktop")
- version: The client version string
This function also adds a friendly_name field and caches the result.
Args:
server: The MCP server instance
Returns:
Dictionary with client info or None if not available:
{
"name": "claude-ai",
"version": "1.0.0",
"friendly_name": "Claude"
}
"""
global _client_info_cache
# Return cached info if available
if _client_info_cache is not None:
return _client_info_cache
try:
# Try to access the request context and session
if not server:
return None
# Check if server has request_context property
request_context = None
try:
request_context = server.request_context
except AttributeError:
logger.debug("Server does not have request_context property")
return None
if not request_context:
logger.debug("Request context is None")
return None
# Try to access session from request context
session = None
try:
session = request_context.session
except AttributeError:
logger.debug("Request context does not have session property")
return None
if not session:
logger.debug("Session is None")
return None
# Try to access client params from session
client_params = None
try:
# The clientInfo is stored in _client_params.clientInfo
client_params = session._client_params
except AttributeError:
logger.debug("Session does not have _client_params property")
return None
if not client_params:
logger.debug("Client params is None")
return None
# Try to extract clientInfo
client_info = None
try:
client_info = client_params.clientInfo
except AttributeError:
logger.debug("Client params does not have clientInfo property")
return None
if not client_info:
logger.debug("Client info is None")
return None
# Extract name and version
result = {}
try:
result["name"] = client_info.name
except AttributeError:
logger.debug("Client info does not have name property")
try:
result["version"] = client_info.version
except AttributeError:
logger.debug("Client info does not have version property")
if not result:
return None
# Add friendly name
raw_name = result.get("name", "")
result["friendly_name"] = get_friendly_name(raw_name)
# Cache the result
_client_info_cache = result
logger.debug(f"Cached client info: {result}")
return result
except Exception as e:
logger.debug(f"Error extracting client info: {e}")
return None
def format_client_info(client_info: Optional[dict[str, Any]], use_friendly_name: bool = True) -> str:
"""
Format client information for display.
Args:
client_info: Dictionary with client info or None
use_friendly_name: If True, use the friendly name instead of raw name
Returns:
Formatted string like "Claude v1.0.0" or "Claude"
"""
if not client_info:
return DEFAULT_FRIENDLY_NAME
if use_friendly_name:
name = client_info.get("friendly_name", client_info.get("name", DEFAULT_FRIENDLY_NAME))
else:
name = client_info.get("name", "Unknown")
version = client_info.get("version", "")
if version and not use_friendly_name:
return f"{name} v{version}"
else:
# For friendly names, we just return the name without version
return name
def get_client_friendly_name() -> str:
"""
Get the cached client's friendly name.
This is a convenience function that returns just the friendly name
from the cached client info, defaulting to "Claude" if not available.
Returns:
The friendly name (e.g., "Claude", "Gemini")
"""
cached_info = get_cached_client_info()
if cached_info:
return cached_info.get("friendly_name", DEFAULT_FRIENDLY_NAME)
return DEFAULT_FRIENDLY_NAME
def log_client_info(server: Any, logger_instance: Optional[logging.Logger] = None) -> None:
"""
Log client information extracted from the server.
Args:
server: The MCP server instance
logger_instance: Optional logger to use (defaults to module logger)
"""
log = logger_instance or logger
client_info = get_client_info_from_context(server)
if client_info:
# Log with both raw and friendly names for debugging
raw_name = client_info.get("name", "Unknown")
friendly_name = client_info.get("friendly_name", DEFAULT_FRIENDLY_NAME)
version = client_info.get("version", "")
if raw_name != friendly_name:
log.info(f"MCP Client Connected: {friendly_name} (raw: {raw_name} v{version})")
else:
log.info(f"MCP Client Connected: {friendly_name} v{version}")
# Log to activity logger as well
try:
activity_logger = logging.getLogger("mcp_activity")
activity_logger.info(f"CLIENT_IDENTIFIED: {friendly_name} (name={raw_name}, version={version})")
except Exception:
pass
else:
log.debug("Could not extract client info from MCP protocol")
# Example usage in tools:
#
# from utils.client_info import get_client_friendly_name, get_cached_client_info
#
# # In a tool's execute method:
# def execute(self, arguments: dict[str, Any]) -> list[TextContent]:
# # Get the friendly name of the connected client
# client_name = get_client_friendly_name() # Returns "Claude" or "Gemini" etc.
#
# # Or get full cached info if needed
# client_info = get_cached_client_info()
# if client_info:
# raw_name = client_info['name'] # e.g., "claude-ai"
# version = client_info['version'] # e.g., "1.0.0"
# friendly = client_info['friendly_name'] # e.g., "Claude"
#
# # Customize response based on client
# if client_name == "Claude":
# response = f"Hello from Zen MCP Server to {client_name}!"
# elif client_name == "Gemini":
# response = f"Greetings {client_name}, welcome to Zen MCP Server!"
# else:
# response = f"Welcome {client_name}!"
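The substring lookup in `get_friendly_name` can be exercised standalone. `MAPPINGS` below is an abbreviated, hypothetical stand-in for `CLIENT_NAME_MAPPINGS`; dict insertion order decides ties between overlapping keys:

```python
# Abbreviated stand-in for the client-name mapping: case-insensitive
# substring matching, first key wins, default when nothing matches.
MAPPINGS = {"claude-ai": "Claude", "claude": "Claude", "gemini": "Gemini", "cursor": "Cursor"}

def friendly(client_name: str, default: str = "Claude") -> str:
    if not client_name:
        return default
    lowered = client_name.lower()
    for key, name in MAPPINGS.items():
        if key in lowered:
            return name
    return default

print(friendly("Claude-Desktop"))
```

Because matching is by containment, a raw name like "claude-desktop" resolves via the shorter "claude" key even though no exact entry exists for it.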

1095
utils/conversation_memory.py Normal file

File diff suppressed because it is too large

271
utils/file_types.py Normal file

@@ -0,0 +1,271 @@
"""
File type definitions and constants for file processing
This module centralizes all file type and extension definitions used
throughout the MCP server for consistent file handling.
"""
# Programming language file extensions - core code files
PROGRAMMING_LANGUAGES = {
".py", # Python
".js", # JavaScript
".ts", # TypeScript
".jsx", # React JavaScript
".tsx", # React TypeScript
".java", # Java
".cpp", # C++
".c", # C
".h", # C/C++ Header
".hpp", # C++ Header
".cs", # C#
".go", # Go
".rs", # Rust
".rb", # Ruby
".php", # PHP
".swift", # Swift
".kt", # Kotlin
".scala", # Scala
".r", # R
".m", # Objective-C
".mm", # Objective-C++
}
# Script and shell file extensions
SCRIPTS = {
".sql", # SQL
".sh", # Shell
".bash", # Bash
".zsh", # Zsh
".fish", # Fish shell
".ps1", # PowerShell
".bat", # Batch
".cmd", # Command
}
# Configuration and data file extensions
CONFIGS = {
".yml", # YAML
".yaml", # YAML
".json", # JSON
".xml", # XML
".toml", # TOML
".ini", # INI
".cfg", # Config
".conf", # Config
".properties", # Properties
".env", # Environment
}
# Documentation and markup file extensions
DOCS = {
".txt", # Text
".md", # Markdown
".rst", # reStructuredText
".tex", # LaTeX
}
# Web development file extensions
WEB = {
".html", # HTML
".css", # CSS
".scss", # Sass
".sass", # Sass
".less", # Less
}
# Additional text file extensions for logs and data
TEXT_DATA = {
".log", # Log files
".csv", # CSV
".tsv", # TSV
".gitignore", # Git ignore
".dockerfile", # Dockerfile
".makefile", # Make
".cmake", # CMake
".gradle", # Gradle
".sbt", # SBT
".pom", # Maven POM
".lock", # Lock files
".changeset", # Precommit changeset
}
# Image file extensions - limited to what AI models actually support
# Based on OpenAI and Gemini supported formats: PNG, JPEG, GIF, WebP
IMAGES = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
# Binary executable and library extensions
BINARIES = {
".exe", # Windows executable
".dll", # Windows library
".so", # Linux shared object
".dylib", # macOS dynamic library
".bin", # Binary
".class", # Java class
}
# Archive and package file extensions
ARCHIVES = {
".jar",
".war",
".ear", # Java archives
".zip",
".tar",
".gz", # General archives
".7z",
".rar", # Compression
".deb",
".rpm", # Linux packages
".dmg",
".pkg", # macOS packages
}
# Derived sets for different use cases
CODE_EXTENSIONS = PROGRAMMING_LANGUAGES | SCRIPTS | CONFIGS | DOCS | WEB
PROGRAMMING_EXTENSIONS = PROGRAMMING_LANGUAGES # For line numbering
TEXT_EXTENSIONS = CODE_EXTENSIONS | TEXT_DATA
IMAGE_EXTENSIONS = IMAGES
BINARY_EXTENSIONS = BINARIES | ARCHIVES
# All extensions by category for easy access
FILE_CATEGORIES = {
"programming": PROGRAMMING_LANGUAGES,
"scripts": SCRIPTS,
"configs": CONFIGS,
"docs": DOCS,
"web": WEB,
"text_data": TEXT_DATA,
"images": IMAGES,
"binaries": BINARIES,
"archives": ARCHIVES,
}
def get_file_category(file_path: str) -> str:
"""
Determine the category of a file based on its extension.
Args:
file_path: Path to the file
Returns:
Category name or "unknown" if not recognized
"""
from pathlib import Path
extension = Path(file_path).suffix.lower()
for category, extensions in FILE_CATEGORIES.items():
if extension in extensions:
return category
return "unknown"
def is_code_file(file_path: str) -> bool:
"""Check if a file is a code file (programming language)."""
from pathlib import Path
return Path(file_path).suffix.lower() in PROGRAMMING_LANGUAGES
def is_text_file(file_path: str) -> bool:
"""Check if a file is a text file."""
from pathlib import Path
return Path(file_path).suffix.lower() in TEXT_EXTENSIONS
def is_binary_file(file_path: str) -> bool:
"""Check if a file is a binary file."""
from pathlib import Path
return Path(file_path).suffix.lower() in BINARY_EXTENSIONS
# File-type specific bytes-per-token ratios for accurate token estimation
# Based on empirical analysis of file compression characteristics and tokenization patterns
TOKEN_ESTIMATION_RATIOS = {
# Programming languages
".py": 3.5, # Python - moderate verbosity
".js": 3.2, # JavaScript - compact syntax
".ts": 3.3, # TypeScript - type annotations add tokens
".jsx": 3.1, # React JSX - JSX tags are tokenized efficiently
".tsx": 3.0, # React TSX - combination of TypeScript + JSX
".java": 3.6, # Java - verbose syntax, long identifiers
".cpp": 3.7, # C++ - preprocessor directives, templates
".c": 3.8, # C - function definitions, struct declarations
".go": 3.9, # Go - explicit error handling, package names
".rs": 3.5, # Rust - similar to Python in verbosity
".php": 3.3, # PHP - mixed HTML/code, variable prefixes
".rb": 3.6, # Ruby - descriptive method names
".swift": 3.4, # Swift - modern syntax, type inference
".kt": 3.5, # Kotlin - similar to modern languages
".scala": 3.2, # Scala - functional programming, concise
# Scripts and configuration
".sh": 4.1, # Shell scripts - commands and paths
".bat": 4.0, # Batch files - similar to shell
".ps1": 3.8, # PowerShell - more structured than bash
".sql": 3.8, # SQL - keywords and table/column names
# Data and configuration formats
".json": 2.5, # JSON - lots of punctuation and quotes
".yaml": 3.0, # YAML - structured but readable
".yml": 3.0, # YAML (alternative extension)
".xml": 2.8, # XML - tags and attributes
".toml": 3.2, # TOML - similar to config files
# Documentation and text
".md": 4.2, # Markdown - natural language with formatting
".txt": 4.0, # Plain text - mostly natural language
".rst": 4.1, # reStructuredText - documentation format
# Web technologies
".html": 2.9, # HTML - tags and attributes
".css": 3.4, # CSS - properties and selectors
# Logs and data
".log": 4.5, # Log files - timestamps, messages, stack traces
".csv": 3.1, # CSV - data with delimiters
# Infrastructure files
".dockerfile": 3.7, # Dockerfile - commands and paths
".tf": 3.5, # Terraform - infrastructure as code
}
def get_token_estimation_ratio(file_path: str) -> float:
"""
Get the token estimation ratio for a file based on its extension.
Args:
file_path: Path to the file
Returns:
Bytes-per-token ratio for the file type (default: 3.5 for unknown types)
"""
from pathlib import Path
extension = Path(file_path).suffix.lower()
return TOKEN_ESTIMATION_RATIOS.get(extension, 3.5) # Conservative default
# MIME type mappings for image files - limited to what AI models actually support
# Based on OpenAI and Gemini supported formats: PNG, JPEG, GIF, WebP
IMAGE_MIME_TYPES = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".gif": "image/gif",
".webp": "image/webp",
}
def get_image_mime_type(extension: str) -> str:
"""
Get the MIME type for an image file extension.
Args:
extension: File extension (with or without leading dot)
Returns:
MIME type string (default: image/jpeg for unknown extensions)
"""
if not extension.startswith("."):
extension = "." + extension
extension = extension.lower()
return IMAGE_MIME_TYPES.get(extension, "image/jpeg")
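How the ratio table gets used can be sketched from file size alone, assuming the table's values are bytes-per-token ratios (JSON's dense punctuation yields fewer bytes per token than Markdown prose). `RATIOS` and `estimate_tokens_from_size` are abbreviated, hypothetical stand-ins:

```python
from pathlib import Path

# Abbreviated ratio table; real values live in TOKEN_ESTIMATION_RATIOS.
RATIOS = {".py": 3.5, ".json": 2.5, ".md": 4.2}

def estimate_tokens_from_size(path: str, size_bytes: int) -> int:
    # Divide size by the bytes-per-token ratio; fall back to a
    # conservative 3.5 for unknown extensions.
    ratio = RATIOS.get(Path(path).suffix.lower(), 3.5)
    return int(size_bytes / ratio)

print(estimate_tokens_from_size("config.json", 10_000))  # → 4000
```

This kind of pre-read estimate is what lets conversation-aware tools budget tokens before deciding which files to include.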

864
utils/file_utils.py Normal file

@@ -0,0 +1,864 @@
"""
File reading utilities with directory support and token management
This module provides secure file access functionality for the MCP server.
It implements critical security measures to prevent unauthorized file access
and manages token limits to ensure efficient API usage.
Key Features:
- Path validation and sandboxing to prevent directory traversal attacks
- Support for both individual files and recursive directory reading
- Token counting and management to stay within API limits
- Automatic file type detection and filtering
- Comprehensive error handling with informative messages
Security Model:
- All file access is restricted to PROJECT_ROOT and its subdirectories
- Absolute paths are required to prevent ambiguity
- Symbolic links are resolved to ensure they stay within bounds
CONVERSATION MEMORY INTEGRATION:
This module works with the conversation memory system to support efficient
multi-turn file handling:
1. DEDUPLICATION SUPPORT:
- File reading functions are called by conversation-aware tools
- Supports newest-first file prioritization by providing accurate token estimation
- Enables efficient file content caching and token budget management
2. TOKEN BUDGET OPTIMIZATION:
- Provides accurate token estimation for file content before reading
- Supports the dual prioritization strategy by enabling precise budget calculations
- Enables tools to make informed decisions about which files to include
3. CROSS-TOOL FILE PERSISTENCE:
- File reading results are used across different tools in conversation chains
- Consistent file access patterns support conversation continuation scenarios
- Error handling preserves conversation flow when files become unavailable
"""
import json
import logging
import os
from pathlib import Path
from typing import Optional
from .file_types import BINARY_EXTENSIONS, CODE_EXTENSIONS, IMAGE_EXTENSIONS, TEXT_EXTENSIONS
from .security_config import EXCLUDED_DIRS, is_dangerous_path
from .token_utils import DEFAULT_CONTEXT_WINDOW, estimate_tokens
def _is_builtin_custom_models_config(path_str: str) -> bool:
"""
Check if path points to the server's built-in custom_models.json config file.
This only matches the server's internal config, not user-specified CUSTOM_MODELS_CONFIG_PATH.
We identify the built-in config by checking if it resolves to the server's conf directory.
Args:
path_str: Path to check
Returns:
True if this is the server's built-in custom_models.json config file
"""
try:
path = Path(path_str)
# Get the server root by going up from this file: utils/file_utils.py -> server_root
server_root = Path(__file__).parent.parent
builtin_config = server_root / "conf" / "custom_models.json"
# Check if the path resolves to the same file as our built-in config
# This handles both relative and absolute paths to the same file
return path.resolve() == builtin_config.resolve()
except Exception:
# If path resolution fails, it's not our built-in config
return False
logger = logging.getLogger(__name__)
def is_mcp_directory(path: Path) -> bool:
"""
Check if a directory is the MCP server's own directory.
This prevents the MCP from including its own code when scanning projects
where the MCP has been cloned as a subdirectory.
Args:
path: Directory path to check
Returns:
True if this is the MCP server directory or a subdirectory
"""
if not path.is_dir():
return False
# Get the directory where the MCP server is running from
# __file__ is utils/file_utils.py, so parent.parent is the MCP root
mcp_server_dir = Path(__file__).parent.parent.resolve()
# Check if the given path is the MCP server directory or a subdirectory
try:
path.resolve().relative_to(mcp_server_dir)
logger.info(f"Detected MCP server directory at {path}, will exclude from scanning")
return True
except ValueError:
# Not a subdirectory of MCP server
return False
def get_user_home_directory() -> Optional[Path]:
"""
Get the user's home directory.
Returns:
User's home directory path
"""
return Path.home()
def is_home_directory_root(path: Path) -> bool:
"""
Check if the given path is the user's home directory root.
This prevents scanning the entire home directory which could include
sensitive data and non-project files.
Args:
path: Directory path to check
Returns:
True if this is the home directory root
"""
user_home = get_user_home_directory()
if not user_home:
return False
try:
resolved_path = path.resolve()
resolved_home = user_home.resolve()
# Check if this is exactly the home directory
if resolved_path == resolved_home:
logger.warning(
f"Attempted to scan user home directory root: {path}. Please specify a subdirectory instead."
)
return True
# Also check common home directory patterns
path_str = str(resolved_path).lower()
home_patterns = [
"/users/", # macOS
"/home/", # Linux
"c:\\users\\", # Windows
"c:/users/", # Windows with forward slashes
]
for pattern in home_patterns:
if pattern in path_str:
# Extract the user directory path
# e.g., /Users/fahad or /home/username
parts = path_str.split(pattern)
if len(parts) > 1:
# Get the part after the pattern
after_pattern = parts[1]
# Check if we're at the user's root (no subdirectories)
if "/" not in after_pattern and "\\" not in after_pattern:
logger.warning(
f"Attempted to scan user home directory root: {path}. "
f"Please specify a subdirectory instead."
)
return True
except Exception as e:
logger.debug(f"Error checking if path is home directory: {e}")
return False
def detect_file_type(file_path: str) -> str:
"""
Detect file type for appropriate processing strategy.
This function is intended for specific file type handling (e.g., image processing,
binary file analysis, or enhanced file filtering).
Args:
file_path: Path to the file to analyze
Returns:
str: "text", "binary", or "image"
"""
path = Path(file_path)
# Check extension first (fast)
extension = path.suffix.lower()
if extension in TEXT_EXTENSIONS:
return "text"
elif extension in IMAGE_EXTENSIONS:
return "image"
elif extension in BINARY_EXTENSIONS:
return "binary"
# Fallback: check magic bytes for text vs binary
# This is helpful for files without extensions or unknown extensions
try:
with open(path, "rb") as f:
chunk = f.read(1024)
# Simple heuristic: if we can decode as UTF-8, likely text
chunk.decode("utf-8")
return "text"
except UnicodeDecodeError:
return "binary"
except (FileNotFoundError, PermissionError) as e:
logger.warning(f"Could not access file {file_path} for type detection: {e}")
return "unknown"
def should_add_line_numbers(file_path: str, include_line_numbers: Optional[bool] = None) -> bool:
"""
Determine if line numbers should be added to a file.
Args:
file_path: Path to the file
include_line_numbers: Explicit preference, or None for auto-detection
Returns:
bool: True if line numbers should be added
"""
if include_line_numbers is not None:
return include_line_numbers
# Default: DO NOT add line numbers
# Tools that want line numbers must explicitly request them
return False
def _normalize_line_endings(content: str) -> str:
"""
Normalize line endings for consistent line numbering.
Args:
content: File content with potentially mixed line endings
Returns:
str: Content with normalized LF line endings
"""
# Normalize all line endings to LF for consistent counting
return content.replace("\r\n", "\n").replace("\r", "\n")
def _add_line_numbers(content: str) -> str:
"""
Add line numbers to text content for precise referencing.
Args:
content: Text content to number
Returns:
str: Content with line numbers in format " 45│ actual code line"
Supports files up to 99,999 lines with dynamic width allocation
"""
# Normalize line endings first
normalized_content = _normalize_line_endings(content)
lines = normalized_content.split("\n")
# Dynamic width allocation based on total line count
# This supports files of any size by computing required width
total_lines = len(lines)
width = len(str(total_lines))
width = max(width, 4) # Minimum padding for readability
# Format with dynamic width and clear separator
numbered_lines = [f"{i + 1:{width}d}│ {line}" for i, line in enumerate(lines)]
return "\n".join(numbered_lines)
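A condensed standalone sketch of the numbering format described in the docstring (" 45│ line"), where the column width tracks the total line count with a minimum of 4:

```python
# number_lines is a hypothetical mirror of _add_line_numbers: normalize
# line endings, compute a dynamic width, then prefix each line.
def number_lines(content: str) -> str:
    lines = content.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    width = max(len(str(len(lines))), 4)
    return "\n".join(f"{i + 1:{width}d}│ {line}" for i, line in enumerate(lines))

print(number_lines("alpha\nbeta"))
```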
def resolve_and_validate_path(path_str: str) -> Path:
"""
Resolves and validates a path against security policies.
This function ensures safe file access by:
1. Requiring absolute paths (no ambiguity)
2. Resolving symlinks to prevent deception
3. Blocking access to dangerous system directories
Args:
path_str: Path string (must be absolute)
Returns:
Resolved Path object that is safe to access
Raises:
ValueError: If path is not absolute or otherwise invalid
PermissionError: If path is in a dangerous location
"""
# Step 1: Create a Path object
user_path = Path(path_str)
# Step 2: Security Policy - Require absolute paths
# Relative paths could be interpreted differently depending on working directory
if not user_path.is_absolute():
raise ValueError(f"Relative paths are not supported. Please provide an absolute path.\nReceived: {path_str}")
# Step 3: Resolve the absolute path (follows symlinks, removes .. and .)
# This is critical for security as it reveals the true destination of symlinks
resolved_path = user_path.resolve()
# Step 4: Check against dangerous paths
if is_dangerous_path(resolved_path):
logger.warning(f"Access denied - dangerous path: {resolved_path}")
raise PermissionError(f"Access to system directory denied: {path_str}")
# Step 5: Check if it's the home directory root
if is_home_directory_root(resolved_path):
raise PermissionError(
f"Cannot scan entire home directory: {path_str}\n" f"Please specify a subdirectory within your home folder."
)
return resolved_path
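The absolute-path policy above can be demonstrated in a simplified standalone form; the real `resolve_and_validate_path` additionally blocks dangerous system paths and the home directory root:

```python
from pathlib import Path

def require_absolute(path_str: str) -> Path:
    # Mirrors the first two steps: reject relative paths outright, then
    # resolve symlinks so the true destination is what gets checked next.
    p = Path(path_str)
    if not p.is_absolute():
        raise ValueError(f"Relative paths are not supported. Received: {path_str}")
    return p.resolve()

try:
    require_absolute("src/main.py")
except ValueError as e:
    print("rejected:", e)
```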
def expand_paths(paths: list[str], extensions: Optional[set[str]] = None) -> list[str]:
"""
Expand paths to individual files, handling both files and directories.
This function recursively walks directories to find all matching files.
It automatically filters out hidden files and common non-code directories
like __pycache__ to avoid including generated or system files.
Args:
paths: List of file or directory paths (must be absolute)
extensions: Optional set of file extensions to include (defaults to CODE_EXTENSIONS)
Returns:
List of individual file paths, sorted for consistent ordering
"""
if extensions is None:
extensions = CODE_EXTENSIONS
expanded_files = []
seen = set()
for path in paths:
try:
# Validate each path for security before processing
path_obj = resolve_and_validate_path(path)
except (ValueError, PermissionError):
# Skip invalid paths silently to allow partial success
continue
if not path_obj.exists():
continue
# Safety checks for directory scanning
if path_obj.is_dir():
# Check 1: Prevent scanning user's home directory root
if is_home_directory_root(path_obj):
logger.warning(f"Skipping home directory root: {path}. Please specify a project subdirectory instead.")
continue
# Check 2: Skip if this is the MCP's own directory
if is_mcp_directory(path_obj):
logger.info(
f"Skipping MCP server directory: {path}. The MCP server code is excluded from project scans."
)
continue
if path_obj.is_file():
# Add file directly
if str(path_obj) not in seen:
expanded_files.append(str(path_obj))
seen.add(str(path_obj))
elif path_obj.is_dir():
# Walk directory recursively to find all files
for root, dirs, files in os.walk(path_obj):
# Filter directories in-place to skip hidden and excluded directories
# This prevents descending into .git, .venv, __pycache__, node_modules, etc.
original_dirs = dirs[:]
dirs[:] = []
for d in original_dirs:
# Skip hidden directories
if d.startswith("."):
continue
# Skip excluded directories
if d in EXCLUDED_DIRS:
continue
# Skip MCP directories found during traversal
dir_path = Path(root) / d
if is_mcp_directory(dir_path):
logger.debug(f"Skipping MCP directory during traversal: {dir_path}")
continue
dirs.append(d)
for file in files:
# Skip hidden files (e.g., .DS_Store, .gitignore)
if file.startswith("."):
continue
file_path = Path(root) / file
# Filter by extension if specified
if not extensions or file_path.suffix.lower() in extensions:
full_path = str(file_path)
# Use set to prevent duplicates
if full_path not in seen:
expanded_files.append(full_path)
seen.add(full_path)
# Sort for consistent ordering across different runs
# This makes output predictable and easier to debug
expanded_files.sort()
return expanded_files
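The in-place `dirs[:]` rebuild above is the documented `os.walk` pruning idiom: mutating the yielded list prevents descent into skipped directories. A hypothetical minimal walker showing just that trick:

```python
import os

# walk_visible is a simplified stand-in for expand_paths' traversal:
# reassigning dirs[:] stops os.walk from descending into hidden or
# excluded directories entirely.
def walk_visible(root, excluded=frozenset({".git", "node_modules", "__pycache__"})):
    found = []
    for current, dirs, files in os.walk(root):
        dirs[:] = [d for d in dirs if not d.startswith(".") and d not in excluded]
        found.extend(os.path.join(current, f) for f in files if not f.startswith("."))
    return sorted(found)
```

Appending to a fresh list (instead of `dirs[:] = [...]`) would also work, but the slice assignment is what signals the pruning intent to `os.walk`.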
def read_file_content(
file_path: str, max_size: int = 1_000_000, *, include_line_numbers: Optional[bool] = None
) -> tuple[str, int]:
"""
Read a single file and format it for inclusion in AI prompts.
This function handles various error conditions gracefully and always
returns formatted content, even for errors. This ensures the AI model
gets context about what files were attempted but couldn't be read.
Args:
file_path: Path to file (must be absolute)
max_size: Maximum file size to read (default 1MB to prevent memory issues)
include_line_numbers: Whether to add line numbers. If None, auto-detects based on file type
Returns:
Tuple of (formatted_content, estimated_tokens)
Content is wrapped with clear delimiters for AI parsing
"""
logger.debug(f"[FILES] read_file_content called for: {file_path}")
try:
# Validate path security before any file operations
path = resolve_and_validate_path(file_path)
logger.debug(f"[FILES] Path validated and resolved: {path}")
except (ValueError, PermissionError) as e:
# Return error in a format that provides context to the AI
logger.debug(f"[FILES] Path validation failed for {file_path}: {type(e).__name__}: {e}")
error_msg = str(e)
content = f"\n--- ERROR ACCESSING FILE: {file_path} ---\nError: {error_msg}\n--- END FILE ---\n"
tokens = estimate_tokens(content)
logger.debug(f"[FILES] Returning error content for {file_path}: {tokens} tokens")
return content, tokens
try:
# Validate file existence and type
if not path.exists():
logger.debug(f"[FILES] File does not exist: {file_path}")
content = f"\n--- FILE NOT FOUND: {file_path} ---\nError: File does not exist\n--- END FILE ---\n"
return content, estimate_tokens(content)
if not path.is_file():
logger.debug(f"[FILES] Path is not a file: {file_path}")
content = f"\n--- NOT A FILE: {file_path} ---\nError: Path is not a file\n--- END FILE ---\n"
return content, estimate_tokens(content)
# Check file size to prevent memory exhaustion
file_size = path.stat().st_size
logger.debug(f"[FILES] File size for {file_path}: {file_size:,} bytes")
if file_size > max_size:
logger.debug(f"[FILES] File too large: {file_path} ({file_size:,} > {max_size:,} bytes)")
content = f"\n--- FILE TOO LARGE: {file_path} ---\nFile size: {file_size:,} bytes (max: {max_size:,})\n--- END FILE ---\n"
return content, estimate_tokens(content)
# Determine if we should add line numbers
add_line_numbers = should_add_line_numbers(file_path, include_line_numbers)
logger.debug(f"[FILES] Line numbers for {file_path}: {'enabled' if add_line_numbers else 'disabled'}")
# Read the file with UTF-8 encoding, replacing invalid characters
# This ensures we can handle files with mixed encodings
logger.debug(f"[FILES] Reading file content for {file_path}")
with open(path, encoding="utf-8", errors="replace") as f:
file_content = f.read()
logger.debug(f"[FILES] Successfully read {len(file_content)} characters from {file_path}")
# Add line numbers if requested or auto-detected
if add_line_numbers:
file_content = _add_line_numbers(file_content)
logger.debug(f"[FILES] Added line numbers to {file_path}")
else:
# Still normalize line endings for consistency
file_content = _normalize_line_endings(file_content)
# Format with clear delimiters that help the AI understand file boundaries
# Using consistent markers makes it easier for the model to parse
# NOTE: These markers ("--- BEGIN FILE: ... ---") are distinct from git diff markers
# ("--- BEGIN DIFF: ... ---") to allow AI to distinguish between complete file content
# vs. partial diff content when files appear in both sections
formatted = f"\n--- BEGIN FILE: {file_path} ---\n{file_content}\n--- END FILE: {file_path} ---\n"
tokens = estimate_tokens(formatted)
logger.debug(f"[FILES] Formatted content for {file_path}: {len(formatted)} chars, {tokens} tokens")
return formatted, tokens
except Exception as e:
logger.debug(f"[FILES] Exception reading file {file_path}: {type(e).__name__}: {e}")
content = f"\n--- ERROR READING FILE: {file_path} ---\nError: {str(e)}\n--- END FILE ---\n"
tokens = estimate_tokens(content)
logger.debug(f"[FILES] Returning error content for {file_path}: {tokens} tokens")
return content, tokens
def read_files(
file_paths: list[str],
code: Optional[str] = None,
max_tokens: Optional[int] = None,
reserve_tokens: int = 50_000,
*,
include_line_numbers: bool = False,
) -> str:
"""
Read multiple files and optional direct code with smart token management.
This function implements intelligent token budgeting to maximize the amount
of relevant content that can be included in an AI prompt while staying
within token limits. It prioritizes direct code and reads files until
the token budget is exhausted.
Args:
file_paths: List of file or directory paths (absolute paths required)
code: Optional direct code to include (prioritized over files)
max_tokens: Maximum tokens to use (defaults to DEFAULT_CONTEXT_WINDOW)
reserve_tokens: Tokens to reserve for prompt and response (default 50K)
include_line_numbers: Whether to add line numbers to file content
Returns:
str: All file contents formatted for AI consumption
"""
if max_tokens is None:
max_tokens = DEFAULT_CONTEXT_WINDOW
logger.debug(f"[FILES] read_files called with {len(file_paths)} paths")
logger.debug(
f"[FILES] Token budget: max={max_tokens:,}, reserve={reserve_tokens:,}, available={max_tokens - reserve_tokens:,}"
)
content_parts = []
total_tokens = 0
available_tokens = max_tokens - reserve_tokens
files_skipped = []
# Priority 1: Handle direct code if provided
# Direct code is prioritized because it's explicitly provided by the user
if code:
formatted_code = f"\n--- BEGIN DIRECT CODE ---\n{code}\n--- END DIRECT CODE ---\n"
code_tokens = estimate_tokens(formatted_code)
if code_tokens <= available_tokens:
content_parts.append(formatted_code)
total_tokens += code_tokens
available_tokens -= code_tokens
# Priority 2: Process file paths
if file_paths:
# Expand directories to get all individual files
logger.debug(f"[FILES] Expanding {len(file_paths)} file paths")
all_files = expand_paths(file_paths)
logger.debug(f"[FILES] After expansion: {len(all_files)} individual files")
if not all_files and file_paths:
# No files found but paths were provided
logger.debug("[FILES] No files found from provided paths")
content_parts.append(f"\n--- NO FILES FOUND ---\nProvided paths: {', '.join(file_paths)}\n--- END ---\n")
else:
# Read files sequentially until token limit is reached
logger.debug(f"[FILES] Reading {len(all_files)} files with token budget {available_tokens:,}")
for i, file_path in enumerate(all_files):
if total_tokens >= available_tokens:
logger.debug(f"[FILES] Token budget exhausted, skipping remaining {len(all_files) - i} files")
files_skipped.extend(all_files[i:])
break
file_content, file_tokens = read_file_content(file_path, include_line_numbers=include_line_numbers)
logger.debug(f"[FILES] File {file_path}: {file_tokens:,} tokens")
# Check if adding this file would exceed limit
if total_tokens + file_tokens <= available_tokens:
content_parts.append(file_content)
total_tokens += file_tokens
logger.debug(f"[FILES] Added file {file_path}, total tokens: {total_tokens:,}")
else:
# File too large for remaining budget
logger.debug(
f"[FILES] File {file_path} too large for remaining budget ({file_tokens:,} tokens, {available_tokens - total_tokens:,} remaining)"
)
files_skipped.append(file_path)
# Add informative note about skipped files to help users understand
# what was omitted and why
if files_skipped:
logger.debug(f"[FILES] {len(files_skipped)} files skipped due to token limits")
skip_note = "\n\n--- SKIPPED FILES (TOKEN LIMIT) ---\n"
skip_note += f"Total skipped: {len(files_skipped)}\n"
# Show first 10 skipped files as examples
for file_path in files_skipped[:10]:
skip_note += f" - {file_path}\n"
if len(files_skipped) > 10:
skip_note += f" ... and {len(files_skipped) - 10} more\n"
skip_note += "--- END SKIPPED FILES ---\n"
content_parts.append(skip_note)
result = "\n\n".join(content_parts) if content_parts else ""
logger.debug(f"[FILES] read_files complete: {len(result)} chars, {total_tokens:,} tokens used")
return result
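The budgeting loop in `read_files` is greedy: files are visited in order, a file that fits the remaining budget is included, and one that does not is recorded as skipped without aborting the walk. A standalone sketch of that behavior (the function name and token counts are illustrative, not part of this module):

```python
def pack_files(file_tokens: dict[str, int], budget: int) -> tuple[list[str], list[str]]:
    """Greedily include files in order until the token budget is exhausted."""
    included: list[str] = []
    skipped: list[str] = []
    used = 0
    for path, tokens in file_tokens.items():
        # A file is included only if it fits the remaining budget; otherwise it
        # is skipped and the walk continues, mirroring files_skipped above.
        if used + tokens <= budget:
            included.append(path)
            used += tokens
        else:
            skipped.append(path)
    return included, skipped


files = {"a.py": 30, "b.py": 50, "c.py": 40, "d.py": 10}
included, skipped = pack_files(files, budget=90)
print(included)  # ['a.py', 'b.py', 'd.py']
print(skipped)   # ['c.py']
```

Note that `c.py` is skipped but the smaller `d.py` after it still fits, so skipping is per-file rather than a hard stop.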
def estimate_file_tokens(file_path: str) -> int:
"""
Estimate tokens for a file using file-type aware ratios.
Args:
file_path: Path to the file
Returns:
Estimated token count for the file
"""
try:
if not os.path.exists(file_path) or not os.path.isfile(file_path):
return 0
file_size = os.path.getsize(file_path)
# Get the appropriate ratio for this file type
from .file_types import get_token_estimation_ratio
ratio = get_token_estimation_ratio(file_path)
return int(file_size / ratio)
except Exception:
return 0
def check_files_size_limit(files: list[str], max_tokens: int, threshold_percent: float = 1.0) -> tuple[bool, int, int]:
"""
Check if a list of files would exceed token limits.
Args:
files: List of file paths to check
max_tokens: Maximum allowed tokens
threshold_percent: Percentage of max_tokens to use as threshold (0.0-1.0)
Returns:
Tuple of (within_limit, total_estimated_tokens, file_count)
"""
if not files:
return True, 0, 0
total_estimated_tokens = 0
file_count = 0
threshold = int(max_tokens * threshold_percent)
for file_path in files:
try:
estimated_tokens = estimate_file_tokens(file_path)
total_estimated_tokens += estimated_tokens
if estimated_tokens > 0: # Only count accessible files
file_count += 1
except Exception:
# Skip files that can't be accessed for size check
continue
within_limit = total_estimated_tokens <= threshold
return within_limit, total_estimated_tokens, file_count
def read_json_file(file_path: str) -> Optional[dict]:
"""
Read and parse a JSON file with proper error handling.
Args:
file_path: Path to the JSON file
Returns:
Parsed JSON data as dict, or None if file doesn't exist or invalid
"""
try:
if not os.path.exists(file_path):
return None
with open(file_path, encoding="utf-8") as f:
return json.load(f)
except (json.JSONDecodeError, OSError):
return None
def write_json_file(file_path: str, data: dict, indent: int = 2) -> bool:
"""
Write data to a JSON file with proper formatting.
Args:
file_path: Path to write the JSON file
data: Dictionary data to serialize
indent: JSON indentation level
Returns:
True if successful, False otherwise
"""
try:
directory = os.path.dirname(file_path)
if directory:
os.makedirs(directory, exist_ok=True)
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=indent, ensure_ascii=False)
return True
except (OSError, TypeError):
return False
def get_file_size(file_path: str) -> int:
"""
Get file size in bytes with proper error handling.
Args:
file_path: Path to the file
Returns:
File size in bytes, or 0 if file doesn't exist or error
"""
try:
if os.path.exists(file_path) and os.path.isfile(file_path):
return os.path.getsize(file_path)
return 0
except OSError:
return 0
def ensure_directory_exists(file_path: str) -> bool:
"""
Ensure the parent directory of a file path exists.
Args:
file_path: Path to file (directory will be created for parent)
Returns:
True if directory exists or was created, False on error
"""
try:
directory = os.path.dirname(file_path)
if directory:
os.makedirs(directory, exist_ok=True)
return True
except OSError:
return False
def is_text_file(file_path: str) -> bool:
"""
Check if a file is likely a text file based on extension and content.
Args:
file_path: Path to the file
Returns:
True if file appears to be text, False otherwise
"""
from .file_types import is_text_file as check_text_type
return check_text_type(file_path)
def read_file_safely(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
"""
Read a file with size limits and encoding handling.
Args:
file_path: Path to the file
max_size: Maximum file size in bytes (default 10MB)
Returns:
File content as string, or None if file too large or unreadable
"""
try:
if not os.path.exists(file_path) or not os.path.isfile(file_path):
return None
file_size = os.path.getsize(file_path)
if file_size > max_size:
return None
with open(file_path, encoding="utf-8", errors="ignore") as f:
return f.read()
except OSError:
return None
def check_total_file_size(files: list[str], model_name: str) -> Optional[dict]:
"""
Check if total file sizes would exceed token threshold before embedding.
IMPORTANT: This performs STRICT REJECTION at MCP boundary.
No partial inclusion - either all files fit or request is rejected.
This forces Claude to make better file selection decisions.
This function MUST be called with the effective model name (after resolution).
It should never receive 'auto' or None - model resolution happens earlier.
Args:
files: List of file paths to check
model_name: The resolved model name for context-aware thresholds (required)
Returns:
Dict with `code_too_large` response if too large, None if acceptable
"""
if not files:
return None
# Validate we have a proper model name (not auto or None)
if not model_name or model_name.lower() == "auto":
raise ValueError(
f"check_total_file_size called with unresolved model: '{model_name}'. "
"Model must be resolved before file size checking."
)
logger.info(f"File size check: Using model '{model_name}' for token limit calculation")
from utils.model_context import ModelContext
model_context = ModelContext(model_name)
token_allocation = model_context.calculate_token_allocation()
# Dynamic threshold based on model capacity
context_window = token_allocation.total_tokens
if context_window >= 1_000_000: # Gemini-class models
threshold_percent = 0.8 # Can be more generous
elif context_window >= 500_000: # Mid-range models
threshold_percent = 0.7 # Moderate
else: # OpenAI-class models (200K)
threshold_percent = 0.6 # Conservative
max_file_tokens = int(token_allocation.file_tokens * threshold_percent)
# Use centralized file size checking (threshold already applied to max_file_tokens)
within_limit, total_estimated_tokens, file_count = check_files_size_limit(files, max_file_tokens)
if not within_limit:
return {
"status": "code_too_large",
"content": (
f"The selected files are too large for analysis "
f"(estimated {total_estimated_tokens:,} tokens, limit {max_file_tokens:,}). "
f"Please select fewer, more specific files that are most relevant "
f"to your question, then invoke the tool again."
),
"content_type": "text",
"metadata": {
"total_estimated_tokens": total_estimated_tokens,
"limit": max_file_tokens,
"file_count": file_count,
"threshold_percent": threshold_percent,
"model_context_window": context_window,
"model_name": model_name,
"instructions": "Reduce file selection and try again - all files must fit within budget. If this persists, please use a model with a larger context window where available.",
},
}
return None # Proceed with ALL files
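The tiered threshold in `check_total_file_size` can be isolated for testing. This sketch reproduces the context-window-to-threshold mapping from the source; the function name is illustrative:

```python
def dynamic_threshold(context_window: int) -> float:
    """Map a model's context window to the file-budget threshold percentage."""
    if context_window >= 1_000_000:  # Gemini-class models: generous
        return 0.8
    elif context_window >= 500_000:  # Mid-range models: moderate
        return 0.7
    return 0.6                       # ~200K-class models: conservative


# The effective cap is then: max_file_tokens = int(file_tokens * threshold)
print(dynamic_threshold(1_048_576))  # 0.8
print(dynamic_threshold(200_000))    # 0.6
```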

94
utils/image_utils.py Normal file

@@ -0,0 +1,94 @@
"""Utility helpers for validating image inputs."""
import base64
import binascii
import os
from collections.abc import Iterable
from typing import Optional
from utils.file_types import IMAGES, get_image_mime_type
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
__all__ = ["DEFAULT_MAX_IMAGE_SIZE_MB", "validate_image"]
def _valid_mime_types() -> Iterable[str]:
"""Return the MIME types permitted by the IMAGES whitelist."""
return (get_image_mime_type(ext) for ext in IMAGES)
def validate_image(image_path: str, max_size_mb: Optional[float] = None) -> tuple[bytes, str]:
"""Validate a user-supplied image path or data URL.
Args:
image_path: Either a filesystem path or a data URL.
max_size_mb: Optional size limit (defaults to ``DEFAULT_MAX_IMAGE_SIZE_MB``).
Returns:
A tuple ``(image_bytes, mime_type)`` ready for upstream providers.
Raises:
ValueError: When the image is missing, malformed, or exceeds limits.
"""
if max_size_mb is None:
max_size_mb = DEFAULT_MAX_IMAGE_SIZE_MB
if image_path.startswith("data:"):
return _validate_data_url(image_path, max_size_mb)
return _validate_file_path(image_path, max_size_mb)
def _validate_data_url(image_data_url: str, max_size_mb: float) -> tuple[bytes, str]:
"""Validate a data URL and return image bytes plus MIME type."""
try:
header, data = image_data_url.split(",", 1)
mime_type = header.split(";")[0].split(":")[1]
except (ValueError, IndexError) as exc:
raise ValueError(f"Invalid data URL format: {exc}") from exc
valid_mime_types = list(_valid_mime_types())
if mime_type not in valid_mime_types:
raise ValueError(
"Unsupported image type: {mime}. Supported types: {supported}".format(
mime=mime_type, supported=", ".join(valid_mime_types)
)
)
try:
image_bytes = base64.b64decode(data)
except binascii.Error as exc:
raise ValueError(f"Invalid base64 data: {exc}") from exc
_validate_size(image_bytes, max_size_mb)
return image_bytes, mime_type
def _validate_file_path(file_path: str, max_size_mb: float) -> tuple[bytes, str]:
"""Validate an image loaded from the filesystem."""
try:
with open(file_path, "rb") as handle:
image_bytes = handle.read()
except FileNotFoundError:
raise ValueError(f"Image file not found: {file_path}") from None
except OSError as exc:
raise ValueError(f"Failed to read image file: {exc}") from exc
ext = os.path.splitext(file_path)[1].lower()
if ext not in IMAGES:
raise ValueError(
"Unsupported image format: {ext}. Supported formats: {supported}".format(
ext=ext, supported=", ".join(sorted(IMAGES))
)
)
mime_type = get_image_mime_type(ext)
_validate_size(image_bytes, max_size_mb)
return image_bytes, mime_type
def _validate_size(image_bytes: bytes, max_size_mb: float) -> None:
"""Ensure the image does not exceed the configured size limit."""
size_mb = len(image_bytes) / (1024 * 1024)
if size_mb > max_size_mb:
raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")
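The data-URL handling in `_validate_data_url` boils down to one split for the payload and two for the MIME type. A minimal standalone sketch of just the parsing step, without the size or whitelist checks:

```python
import base64


def parse_data_url(data_url: str) -> tuple[str, bytes]:
    """Split a data URL into (mime_type, decoded_bytes)."""
    header, data = data_url.split(",", 1)          # "data:image/png;base64" | payload
    mime_type = header.split(";")[0].split(":")[1]  # -> "image/png"
    return mime_type, base64.b64decode(data)


# Round-trip with the 8-byte PNG signature as a toy payload
png_header = b"\x89PNG\r\n\x1a\n"
url = "data:image/png;base64," + base64.b64encode(png_header).decode("ascii")
mime, payload = parse_data_url(url)
print(mime)                    # image/png
print(payload == png_header)   # True
```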

180
utils/model_context.py Normal file

@@ -0,0 +1,180 @@
"""
Model context management for dynamic token allocation.
This module provides a clean abstraction for model-specific token management,
ensuring that token limits are properly calculated based on the current model
being used, not global constants.
CONVERSATION MEMORY INTEGRATION:
This module works closely with the conversation memory system to provide
optimal token allocation for multi-turn conversations:
1. DUAL PRIORITIZATION STRATEGY SUPPORT:
- Provides separate token budgets for conversation history vs. files
- Enables the conversation memory system to apply newest-first prioritization
- Ensures optimal balance between context preservation and new content
2. MODEL-SPECIFIC ALLOCATION:
- Dynamic allocation based on model capabilities (context window size)
- Conservative allocation for smaller models (O3: 200K context)
- Generous allocation for larger models (Gemini: 1M+ context)
- Adapts token distribution ratios based on model capacity
3. CROSS-TOOL CONSISTENCY:
- Provides consistent token budgets across different tools
- Enables seamless conversation continuation between tools
- Supports conversation reconstruction with proper budget management
"""
import logging
from dataclasses import dataclass
from typing import Any, Optional
from config import DEFAULT_MODEL
from providers import ModelCapabilities, ModelProviderRegistry
logger = logging.getLogger(__name__)
@dataclass
class TokenAllocation:
"""Token allocation strategy for a model."""
total_tokens: int
content_tokens: int
response_tokens: int
file_tokens: int
history_tokens: int
@property
def available_for_prompt(self) -> int:
"""Tokens available for the actual prompt after allocations."""
return self.content_tokens - self.file_tokens - self.history_tokens
class ModelContext:
"""
Encapsulates model-specific information and token calculations.
This class provides a single source of truth for all model-related
token calculations, ensuring consistency across the system.
"""
def __init__(self, model_name: str, model_option: Optional[str] = None):
self.model_name = model_name
self.model_option = model_option # Store optional model option (e.g., "for", "against", etc.)
self._provider = None
self._capabilities = None
self._token_allocation = None
@property
def provider(self):
"""Get the model provider lazily."""
if self._provider is None:
self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name)
if not self._provider:
available_models = ModelProviderRegistry.get_available_model_names()
if available_models:
available_text = ", ".join(available_models)
else:
available_text = (
"No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option."
)
raise ValueError(
f"Model '{self.model_name}' is not available with current API keys. Available models: {available_text}."
)
return self._provider
@property
def capabilities(self) -> ModelCapabilities:
"""Get model capabilities lazily."""
if self._capabilities is None:
self._capabilities = self.provider.get_capabilities(self.model_name)
return self._capabilities
def calculate_token_allocation(self, reserved_for_response: Optional[int] = None) -> TokenAllocation:
"""
Calculate token allocation based on model capacity and conversation requirements.
This method implements the core token budget calculation that supports the
dual prioritization strategy used in conversation memory and file processing:
TOKEN ALLOCATION STRATEGY:
1. CONTENT vs RESPONSE SPLIT:
- Smaller models (< 300K): 60% content, 40% response (conservative)
- Larger models (≥ 300K): 80% content, 20% response (generous)
2. CONTENT SUB-ALLOCATION:
- File tokens: 30-40% of content budget for newest file versions
- History tokens: 40-50% of content budget for conversation context
- Remaining: Available for tool-specific prompt content
3. CONVERSATION MEMORY INTEGRATION:
- History allocation enables conversation reconstruction in reconstruct_thread_context()
- File allocation supports newest-first file prioritization in tools
- Remaining budget passed to tools via _remaining_tokens parameter
Args:
reserved_for_response: Override response token reservation
Returns:
TokenAllocation with calculated budgets for dual prioritization strategy
"""
total_tokens = self.capabilities.context_window
# Dynamic allocation based on model capacity
if total_tokens < 300_000:
# Smaller context models (O3): Conservative allocation
content_ratio = 0.6 # 60% for content
response_ratio = 0.4 # 40% for response
file_ratio = 0.3 # 30% of content for files
history_ratio = 0.5 # 50% of content for history
else:
# Larger context models (Gemini): More generous allocation
content_ratio = 0.8 # 80% for content
response_ratio = 0.2 # 20% for response
file_ratio = 0.4 # 40% of content for files
history_ratio = 0.4 # 40% of content for history
# Calculate allocations
content_tokens = int(total_tokens * content_ratio)
response_tokens = reserved_for_response or int(total_tokens * response_ratio)
# Sub-allocations within content budget
file_tokens = int(content_tokens * file_ratio)
history_tokens = int(content_tokens * history_ratio)
allocation = TokenAllocation(
total_tokens=total_tokens,
content_tokens=content_tokens,
response_tokens=response_tokens,
file_tokens=file_tokens,
history_tokens=history_tokens,
)
logger.debug(f"Token allocation for {self.model_name}:")
logger.debug(f" Total: {allocation.total_tokens:,}")
logger.debug(f" Content: {allocation.content_tokens:,} ({content_ratio:.0%})")
logger.debug(f" Response: {allocation.response_tokens:,} ({response_ratio:.0%})")
logger.debug(f" Files: {allocation.file_tokens:,} ({file_ratio:.0%} of content)")
logger.debug(f" History: {allocation.history_tokens:,} ({history_ratio:.0%} of content)")
return allocation
def estimate_tokens(self, text: str) -> int:
"""
Estimate token count for text using model-specific tokenizer.
For now, uses simple estimation. Can be enhanced with model-specific
tokenizers (tiktoken for OpenAI, etc.) in the future.
"""
# TODO: Integrate model-specific tokenizers
# For now, use conservative estimation
return len(text) // 3 # Conservative estimate
@classmethod
def from_arguments(cls, arguments: dict[str, Any]) -> "ModelContext":
"""Create ModelContext from tool arguments."""
model_name = arguments.get("model") or DEFAULT_MODEL
return cls(model_name)
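The two-tier ratio table in `calculate_token_allocation` can be exercised in isolation. This sketch assumes no `reserved_for_response` override and omits logging; the function name is illustrative:

```python
def allocate(context_window: int) -> dict[str, int]:
    """Reproduce the two-tier ratio table from calculate_token_allocation."""
    if context_window < 300_000:
        # Smaller context models: conservative split
        content_ratio, response_ratio, file_ratio, history_ratio = 0.6, 0.4, 0.3, 0.5
    else:
        # Larger context models: more generous split
        content_ratio, response_ratio, file_ratio, history_ratio = 0.8, 0.2, 0.4, 0.4
    content = int(context_window * content_ratio)
    return {
        "content": content,
        "response": int(context_window * response_ratio),
        "files": int(content * file_ratio),      # sub-allocated from content
        "history": int(content * history_ratio),  # sub-allocated from content
    }


print(allocate(200_000))
# {'content': 120000, 'response': 80000, 'files': 36000, 'history': 60000}
```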

226
utils/model_restrictions.py Normal file

@@ -0,0 +1,226 @@
"""
Model Restriction Service
This module provides centralized management of model usage restrictions
based on environment variables. It allows organizations to limit which
models can be used from each provider for cost control, compliance, or
standardization purposes.
Environment Variables:
- OPENAI_ALLOWED_MODELS: Comma-separated list of allowed OpenAI models
- GOOGLE_ALLOWED_MODELS: Comma-separated list of allowed Gemini models
- XAI_ALLOWED_MODELS: Comma-separated list of allowed X.AI GROK models
- OPENROUTER_ALLOWED_MODELS: Comma-separated list of allowed OpenRouter models
- DIAL_ALLOWED_MODELS: Comma-separated list of allowed DIAL models
Example:
OPENAI_ALLOWED_MODELS=o3-mini,o4-mini
GOOGLE_ALLOWED_MODELS=flash
XAI_ALLOWED_MODELS=grok-3,grok-3-fast
OPENROUTER_ALLOWED_MODELS=opus,sonnet,mistral
"""
import logging
import os
from typing import Any, Optional
from providers.shared import ProviderType
logger = logging.getLogger(__name__)
class ModelRestrictionService:
"""Central authority for environment-driven model allowlists.
Role
Interpret ``*_ALLOWED_MODELS`` environment variables, keep their
entries normalised (lowercase), and answer whether a provider/model
pairing is permitted.
Responsibilities
* Parse, cache, and expose per-provider restriction sets
* Validate configuration by cross-checking each entry against the
provider's alias-aware model list
* Offer helper methods such as ``is_allowed`` and ``filter_models`` to
enforce policy everywhere model names appear (tool selection, CLI
commands, etc.).
"""
# Environment variable names
ENV_VARS = {
ProviderType.OPENAI: "OPENAI_ALLOWED_MODELS",
ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS",
ProviderType.XAI: "XAI_ALLOWED_MODELS",
ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS",
ProviderType.DIAL: "DIAL_ALLOWED_MODELS",
}
def __init__(self):
"""Initialize the restriction service by loading from environment."""
self.restrictions: dict[ProviderType, set[str]] = {}
self._load_from_env()
def _load_from_env(self) -> None:
"""Load restrictions from environment variables."""
for provider_type, env_var in self.ENV_VARS.items():
env_value = os.getenv(env_var)
if env_value is None or env_value == "":
# Not set or empty - no restrictions (allow all models)
logger.debug(f"{env_var} not set or empty - all {provider_type.value} models allowed")
continue
# Parse comma-separated list
models = set()
for model in env_value.split(","):
cleaned = model.strip().lower()
if cleaned:
models.add(cleaned)
if models:
self.restrictions[provider_type] = models
logger.info(f"{provider_type.value} allowed models: {sorted(models)}")
else:
# All entries were empty after cleaning - treat as no restrictions
logger.debug(f"{env_var} contains only whitespace - all {provider_type.value} models allowed")
def validate_against_known_models(self, provider_instances: dict[ProviderType, Any]) -> None:
"""
Validate restrictions against known models from providers.
This should be called after providers are initialized to warn about
typos or invalid model names in the restriction lists.
Args:
provider_instances: Dictionary of provider type to provider instance
"""
for provider_type, allowed_models in self.restrictions.items():
provider = provider_instances.get(provider_type)
if not provider:
continue
# Get all supported models using the clean polymorphic interface
try:
# Gather canonical models and aliases with consistent formatting
all_models = provider.list_models(
respect_restrictions=False,
include_aliases=True,
lowercase=True,
unique=True,
)
supported_models = set(all_models)
except Exception as e:
logger.debug(f"Could not get model list from {provider_type.value} provider: {e}")
supported_models = set()
# Check each allowed model
for allowed_model in allowed_models:
if allowed_model not in supported_models:
logger.warning(
f"Model '{allowed_model}' in {self.ENV_VARS[provider_type]} "
f"is not a recognized {provider_type.value} model. "
f"Please check for typos. Known models: {sorted(supported_models)}"
)
def is_allowed(self, provider_type: ProviderType, model_name: str, original_name: Optional[str] = None) -> bool:
"""
Check if a model is allowed for a specific provider.
Args:
provider_type: The provider type (OPENAI, GOOGLE, etc.)
model_name: The canonical model name (after alias resolution)
original_name: The original model name before alias resolution (optional)
Returns:
True if allowed (or no restrictions), False if restricted
"""
if provider_type not in self.restrictions:
# No restrictions for this provider
return True
allowed_set = self.restrictions[provider_type]
if len(allowed_set) == 0:
# Defensive only: _load_from_env never stores empty sets; treat one as unrestricted
return True
# Check both the resolved name and original name (if different)
names_to_check = {model_name.lower()}
if original_name and original_name.lower() != model_name.lower():
names_to_check.add(original_name.lower())
# If any of the names is in the allowed set, it's allowed
return any(name in allowed_set for name in names_to_check)
def get_allowed_models(self, provider_type: ProviderType) -> Optional[set[str]]:
"""
Get the set of allowed models for a provider.
Args:
provider_type: The provider type
Returns:
Set of allowed model names, or None if no restrictions
"""
return self.restrictions.get(provider_type)
def has_restrictions(self, provider_type: ProviderType) -> bool:
"""
Check if a provider has any restrictions.
Args:
provider_type: The provider type
Returns:
True if restrictions exist, False otherwise
"""
return provider_type in self.restrictions
def filter_models(self, provider_type: ProviderType, models: list[str]) -> list[str]:
"""
Filter a list of models based on restrictions.
Args:
provider_type: The provider type
models: List of model names to filter
Returns:
Filtered list containing only allowed models
"""
if not self.has_restrictions(provider_type):
return models
return [m for m in models if self.is_allowed(provider_type, m)]
def get_restriction_summary(self) -> dict[str, Any]:
"""
Get a summary of all restrictions for logging/debugging.
Returns:
Dictionary with provider names and their restrictions
"""
summary = {}
for provider_type, allowed_set in self.restrictions.items():
if allowed_set:
summary[provider_type.value] = sorted(allowed_set)
else:
summary[provider_type.value] = "none (provider disabled)"
return summary
# Global instance (singleton pattern)
_restriction_service: Optional[ModelRestrictionService] = None
def get_restriction_service() -> ModelRestrictionService:
"""
Get the global restriction service instance.
Returns:
The singleton ModelRestrictionService instance
"""
global _restriction_service
if _restriction_service is None:
_restriction_service = ModelRestrictionService()
return _restriction_service
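The allowlist parsing in `_load_from_env` reduces to split, strip, lowercase, and drop empty entries. A standalone sketch of that normalisation:

```python
import os


def parse_allowlist(env_var: str) -> set[str]:
    """Parse a *_ALLOWED_MODELS variable the way _load_from_env does."""
    raw = os.getenv(env_var, "")
    # Strip whitespace, lowercase, and drop empty entries from trailing commas
    return {entry.strip().lower() for entry in raw.split(",") if entry.strip()}


os.environ["OPENAI_ALLOWED_MODELS"] = "o3-mini, O4-mini , "
allowed = parse_allowlist("OPENAI_ALLOWED_MODELS")
print(sorted(allowed))       # ['o3-mini', 'o4-mini']
print("o4-mini" in allowed)  # True  (entries are normalised to lowercase)
```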

104
utils/security_config.py Normal file

@@ -0,0 +1,104 @@
"""
Security configuration and path validation constants
This module contains security-related constants and configurations
for file access control.
"""
from pathlib import Path
# Dangerous paths that should never be scanned
# These would give overly broad access and pose security risks
DANGEROUS_PATHS = {
"/",
"/etc",
"/usr",
"/bin",
"/var",
"/root",
"/home",
"C:\\",
"C:\\Windows",
"C:\\Program Files",
"C:\\Users",
}
# Directories to exclude from recursive file search
# These typically contain generated code, dependencies, or build artifacts
EXCLUDED_DIRS = {
# Python
"__pycache__",
".venv",
"venv",
"env",
".env",
"*.egg-info",
".eggs",
"wheels",
".Python",
".mypy_cache",
".pytest_cache",
".tox",
"htmlcov",
".coverage",
"coverage",
# Node.js / JavaScript
"node_modules",
".next",
".nuxt",
"bower_components",
".sass-cache",
# Version Control
".git",
".svn",
".hg",
# Build Output
"build",
"dist",
"target",
"out",
# IDEs
".idea",
".vscode",
".sublime",
".atom",
".brackets",
# Temporary / Cache
".cache",
".temp",
".tmp",
"*.swp",
"*.swo",
"*~",
# OS-specific
".DS_Store",
"Thumbs.db",
# Java / JVM
".gradle",
".m2",
# Documentation build
"_build",
"site",
# Mobile development
".expo",
".flutter",
# Package managers
"vendor",
}
def is_dangerous_path(path: Path) -> bool:
"""
Check if a path is in the dangerous paths list.
Args:
path: Path to check
Returns:
True if the path is dangerous and should not be accessed
"""
try:
resolved = path.resolve()
return str(resolved) in DANGEROUS_PATHS or resolved.parent == resolved
except Exception:
return True # If we can't resolve, consider it dangerous
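The root-detection trick in `is_dangerous_path` relies on `Path.resolve()` plus the fact that a filesystem root is its own parent. A minimal sketch with a truncated blocklist (POSIX paths assumed; the set here is illustrative, not the full `DANGEROUS_PATHS`):

```python
from pathlib import Path

BLOCKLIST = {"/", "/etc", "/root"}


def dangerous(path: Path) -> bool:
    try:
        resolved = path.resolve()
        # Dangerous if blocklisted, or if it is a filesystem root:
        # after resolution, only a root satisfies parent == self.
        return str(resolved) in BLOCKLIST or resolved.parent == resolved
    except OSError:
        return True  # unresolvable paths are treated as dangerous


print(dangerous(Path("/")))    # True
print(dangerous(Path("/etc"))) # True
```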

113
utils/storage_backend.py Normal file

@@ -0,0 +1,113 @@
"""
In-memory storage backend for conversation threads
This module provides a thread-safe, in-memory alternative to Redis for storing
conversation contexts. It's designed for ephemeral MCP server sessions where
conversations only need to persist during a single Claude session.
PROCESS-SPECIFIC STORAGE: This storage is confined to a single Python process.
Data stored in one process is NOT accessible from other processes or subprocesses.
This is why simulator tests that run server.py as separate subprocesses cannot
share conversation state between tool calls.
Key Features:
- Thread-safe operations using locks
- TTL support with automatic expiration
- Background cleanup thread for memory management
- Singleton pattern for consistent state within a single process
- Drop-in replacement for Redis storage (for single-process scenarios)
"""
import logging
import os
import threading
import time
from typing import Optional
logger = logging.getLogger(__name__)
class InMemoryStorage:
"""Thread-safe in-memory storage for conversation threads"""
def __init__(self):
self._store: dict[str, tuple[str, float]] = {}
self._lock = threading.Lock()
# Match Redis behavior: cleanup interval based on conversation timeout
# Run cleanup at 1/10th of timeout interval (e.g., 18 mins for 3 hour timeout)
timeout_hours = int(os.getenv("CONVERSATION_TIMEOUT_HOURS", "3"))
self._cleanup_interval = (timeout_hours * 3600) // 10
self._cleanup_interval = max(300, self._cleanup_interval) # Minimum 5 minutes
self._shutdown = False
# Start background cleanup thread
self._cleanup_thread = threading.Thread(target=self._cleanup_worker, daemon=True)
self._cleanup_thread.start()
logger.info(
f"In-memory storage initialized with {timeout_hours}h timeout, cleanup every {self._cleanup_interval//60}m"
)
def set_with_ttl(self, key: str, ttl_seconds: int, value: str) -> None:
"""Store value with expiration time"""
with self._lock:
expires_at = time.time() + ttl_seconds
self._store[key] = (value, expires_at)
logger.debug(f"Stored key {key} with TTL {ttl_seconds}s")
def get(self, key: str) -> Optional[str]:
"""Retrieve value if not expired"""
with self._lock:
if key in self._store:
value, expires_at = self._store[key]
if time.time() < expires_at:
logger.debug(f"Retrieved key {key}")
return value
else:
# Clean up expired entry
del self._store[key]
logger.debug(f"Key {key} expired and removed")
return None
def setex(self, key: str, ttl_seconds: int, value: str) -> None:
"""Redis-compatible setex method"""
self.set_with_ttl(key, ttl_seconds, value)
def _cleanup_worker(self):
"""Background thread that periodically cleans up expired entries"""
while not self._shutdown:
time.sleep(self._cleanup_interval)
self._cleanup_expired()
def _cleanup_expired(self):
"""Remove all expired entries"""
with self._lock:
current_time = time.time()
expired_keys = [k for k, (_, exp) in self._store.items() if exp < current_time]
for key in expired_keys:
del self._store[key]
if expired_keys:
logger.debug(f"Cleaned up {len(expired_keys)} expired conversation threads")
def shutdown(self):
"""Graceful shutdown of background thread"""
self._shutdown = True
if self._cleanup_thread.is_alive():
self._cleanup_thread.join(timeout=1)
# Global singleton instance
_storage_instance = None
_storage_lock = threading.Lock()
def get_storage_backend() -> InMemoryStorage:
"""Get the global storage instance (singleton pattern)"""
global _storage_instance
if _storage_instance is None:
with _storage_lock:
if _storage_instance is None:
_storage_instance = InMemoryStorage()
logger.info("Initialized in-memory conversation storage")
return _storage_instance
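The store-with-absolute-expiry plus lazy-cleanup-on-read pattern above can be exercised in isolation. The sketch below is a minimal, hypothetical `TTLStoreSketch` written for illustration only: it mirrors the `setex`/`get` semantics of `InMemoryStorage` but omits the singleton accessor and the background cleanup thread.

```python
import threading
import time

class TTLStoreSketch:
    """Minimal illustration of the TTL semantics used by InMemoryStorage."""

    def __init__(self):
        # key -> (value, absolute expiry timestamp)
        self._store: dict[str, tuple[str, float]] = {}
        self._lock = threading.Lock()

    def setex(self, key: str, ttl_seconds: int, value: str) -> None:
        with self._lock:
            self._store[key] = (value, time.time() + ttl_seconds)

    def get(self, key: str):
        with self._lock:
            entry = self._store.get(key)
            if entry is None:
                return None
            value, expires_at = entry
            if time.time() < expires_at:
                return value
            # Lazy cleanup: expired entries are purged on read
            del self._store[key]
            return None

store = TTLStoreSketch()
store.setex("thread:abc", 3600, '{"turns": []}')
print(store.get("thread:abc"))   # live entry -> {"turns": []}
store.setex("thread:old", -1, "stale")
print(store.get("thread:old"))   # already expired -> None
```

Because expiry is checked on every read, the store never returns stale data even between background cleanup passes; the cleanup thread in the real module only bounds memory growth for keys that are never read again.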

54
utils/token_utils.py Normal file

@@ -0,0 +1,54 @@
"""
Token counting utilities for managing API context limits
This module provides functions for estimating token counts to ensure
requests stay within the Gemini API's context window limits.
Note: The estimation uses a simple character-to-token ratio which is
approximate. For production systems requiring precise token counts,
consider using the actual tokenizer for the specific model.
"""
# Default fallback for token limit (conservative estimate)
DEFAULT_CONTEXT_WINDOW = 200_000 # Conservative fallback for unknown models
def estimate_tokens(text: str) -> int:
"""
Estimate token count using a character-based approximation.
    This uses a rough heuristic where 1 token ≈ 4 characters, which is
a reasonable approximation for English text. The actual token count
may vary based on:
- Language (non-English text may have different ratios)
- Code vs prose (code often has more tokens per character)
- Special characters and formatting
Args:
text: The text to estimate tokens for
Returns:
int: Estimated number of tokens
"""
return len(text) // 4
def check_token_limit(text: str, context_window: int = DEFAULT_CONTEXT_WINDOW) -> tuple[bool, int]:
"""
Check if text exceeds the specified token limit.
This function is used to validate that prepared prompts will fit
within the model's context window, preventing API errors and ensuring
reliable operation.
Args:
text: The text to check
context_window: The model's context window size (defaults to conservative fallback)
Returns:
        tuple[bool, int]: (is_within_limit, estimated_tokens)
- is_within_limit: True if the text fits within context_window
- estimated_tokens: The estimated token count
"""
estimated = estimate_tokens(text)
return estimated <= context_window, estimated
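A quick standalone check of the heuristic. The two helpers are redefined here so the snippet runs on its own; the canonical versions live in utils/token_utils.py.

```python
DEFAULT_CONTEXT_WINDOW = 200_000  # same conservative fallback as the module

def estimate_tokens(text: str) -> int:
    # ~4 characters per token, a rough approximation for English prose
    return len(text) // 4

def check_token_limit(text: str, context_window: int = DEFAULT_CONTEXT_WINDOW) -> tuple[bool, int]:
    estimated = estimate_tokens(text)
    return estimated <= context_window, estimated

prompt = "Write five subject lines for our spring launch email."
ok, tokens = check_token_limit(prompt)
print(ok, tokens)   # a short prompt fits easily: True 13

ok, tokens = check_token_limit("x" * 1_000_000)
print(ok, tokens)   # ~250k estimated tokens exceed the 200k fallback: False 250000
```

Since the estimate can undercount for code or non-English text, callers that are close to the limit should treat a `True` result near the window boundary with caution rather than as a guarantee.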