Initial commit: Zen-Marketing MCP Server v0.1.0
- Core architecture from zen-mcp-server - OpenRouter and Gemini provider configuration - Content variant generator tool (first marketing tool) - Chat tool for marketing strategy - Version and model listing tools - Configuration system with .env support - Logging infrastructure - Ready for Claude Desktop integration
This commit is contained in:
commit
371806488d
56 changed files with 16196 additions and 0 deletions
40
.env.example
Normal file
40
.env.example
Normal file
|
|
@ -0,0 +1,40 @@
|
||||||
|
# Zen-Marketing MCP Server Configuration
|
||||||
|
|
||||||
|
# API Keys
|
||||||
|
# Required: At least one API key must be configured
|
||||||
|
OPENROUTER_API_KEY=your-openrouter-api-key-here
|
||||||
|
GEMINI_API_KEY=your-gemini-api-key-here
|
||||||
|
|
||||||
|
# Model Configuration
|
||||||
|
# DEFAULT_MODEL: Primary model for analytical and strategic work
|
||||||
|
# Options: google/gemini-2.5-pro-latest, minimax/minimax-m2, or any OpenRouter model
|
||||||
|
DEFAULT_MODEL=google/gemini-2.5-pro-latest
|
||||||
|
|
||||||
|
# FAST_MODEL: Model for quick generation (subject lines, variations)
|
||||||
|
FAST_MODEL=google/gemini-2.5-flash-preview-09-2025
|
||||||
|
|
||||||
|
# CREATIVE_MODEL: Model for creative content generation
|
||||||
|
CREATIVE_MODEL=minimax/minimax-m2
|
||||||
|
|
||||||
|
# Web Search
|
||||||
|
# Enable web search for fact-checking and current information
|
||||||
|
ENABLE_WEB_SEARCH=true
|
||||||
|
|
||||||
|
# Tool Configuration
|
||||||
|
# Comma-separated list of tools to disable
|
||||||
|
# Available tools: contentvariant, platformadapt, subjectlines, styleguide,
|
||||||
|
# seooptimize, guestedit, linkstrategy, factcheck,
|
||||||
|
# voiceanalysis, campaignmap, chat, thinkdeep, planner
|
||||||
|
DISABLED_TOOLS=
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
# Options: DEBUG, INFO, WARNING, ERROR
|
||||||
|
LOG_LEVEL=INFO
|
||||||
|
|
||||||
|
# Language/Locale
|
||||||
|
# Leave empty for English, or specify locale (e.g., fr-FR, es-ES, de-DE)
|
||||||
|
LOCALE=
|
||||||
|
|
||||||
|
# Thinking Mode for Deep Analysis
|
||||||
|
# Options: minimal, low, medium, high, max
|
||||||
|
DEFAULT_THINKING_MODE_THINKDEEP=high
|
||||||
190
.gitignore
vendored
Normal file
190
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,190 @@
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
.pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
pdm.lock
|
||||||
|
|
||||||
|
# PEP 582
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.env~
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# VS Code
|
||||||
|
.vscode/
|
||||||
|
|
||||||
|
# macOS
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# API Keys and secrets
|
||||||
|
*.key
|
||||||
|
*.pem
|
||||||
|
.env.local
|
||||||
|
.env.*.local
|
||||||
|
|
||||||
|
# Test outputs
|
||||||
|
test_output/
|
||||||
|
*.test.log
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
coverage.xml
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Test simulation artifacts (dynamically created during testing)
|
||||||
|
test_simulation_files/.claude/
|
||||||
|
|
||||||
|
# Temporary test directories
|
||||||
|
test-setup/
|
||||||
|
|
||||||
|
# Scratch feature documentation files
|
||||||
|
FEATURE_*.md
|
||||||
|
# Temporary files
|
||||||
|
/tmp/
|
||||||
|
|
||||||
|
# Local user instructions
|
||||||
|
CLAUDE.local.md
|
||||||
|
|
||||||
|
# Claude Code personal settings
|
||||||
|
.claude/settings.local.json
|
||||||
|
|
||||||
|
# Standalone mode files
|
||||||
|
.zen_venv/
|
||||||
|
.docker_cleaned
|
||||||
|
logs/
|
||||||
|
*.backup
|
||||||
|
/.desktop_configured
|
||||||
|
|
||||||
|
/worktrees/
|
||||||
|
test_simulation_files/
|
||||||
|
.mcp.json
|
||||||
566
CLAUDE.md
Normal file
566
CLAUDE.md
Normal file
|
|
@ -0,0 +1,566 @@
|
||||||
|
# Claude Development Guide for Zen-Marketing MCP Server
|
||||||
|
|
||||||
|
This file contains essential commands and workflows for developing the Zen-Marketing MCP Server - a specialized marketing-focused fork of Zen MCP Server.
|
||||||
|
|
||||||
|
## Project Context
|
||||||
|
|
||||||
|
**What is Zen-Marketing?**
|
||||||
|
A Claude Desktop MCP server providing AI-powered marketing tools focused on:
|
||||||
|
- Content variation generation for A/B testing
|
||||||
|
- Cross-platform content adaptation
|
||||||
|
- Writing style enforcement
|
||||||
|
- SEO optimization for WordPress
|
||||||
|
- Guest content editing with voice preservation
|
||||||
|
- Technical fact verification
|
||||||
|
- Internal linking strategy
|
||||||
|
- Multi-channel campaign planning
|
||||||
|
|
||||||
|
**Target User:** Solo marketing professionals managing technical B2B content, particularly in industries like HVAC, SaaS, and technical education.
|
||||||
|
|
||||||
|
**Key Difference from Zen Code:** This is for marketing/content work, not software development. Tools generate content variations, enforce writing styles, and optimize for platforms like LinkedIn, newsletters, and WordPress - not code review or debugging.
|
||||||
|
|
||||||
|
## Quick Reference Commands
|
||||||
|
|
||||||
|
### Initial Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Navigate to project directory
|
||||||
|
cd ~/mcp/zen-marketing
|
||||||
|
|
||||||
|
# Copy core files from zen-mcp-server (if starting fresh)
|
||||||
|
# We'll do this in the new session
|
||||||
|
|
||||||
|
# Create virtual environment
|
||||||
|
python3 -m venv .venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# Install dependencies (once requirements.txt is created)
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Create .env file
|
||||||
|
cp .env.example .env
|
||||||
|
# Edit .env with your API keys
|
||||||
|
```
|
||||||
|
|
||||||
|
### Development Workflow
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Activate environment
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# Run code quality checks (once implemented)
|
||||||
|
./code_quality_checks.sh
|
||||||
|
|
||||||
|
# Run server locally for testing
|
||||||
|
python server.py
|
||||||
|
|
||||||
|
# View logs
|
||||||
|
tail -f logs/mcp_server.log
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
python -m pytest tests/ -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Claude Desktop Configuration
|
||||||
|
|
||||||
|
Add to `~/.claude.json`:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"zen-marketing": {
|
||||||
|
"command": "/home/ben/mcp/zen-marketing/.venv/bin/python",
|
||||||
|
"args": ["/home/ben/mcp/zen-marketing/server.py"],
|
||||||
|
"env": {
|
||||||
|
"OPENROUTER_API_KEY": "your-openrouter-key",
|
||||||
|
"GEMINI_API_KEY": "your-gemini-key",
|
||||||
|
"DEFAULT_MODEL": "gemini-2.5-pro",
|
||||||
|
"FAST_MODEL": "gemini-flash",
|
||||||
|
"CREATIVE_MODEL": "minimax-m2",
|
||||||
|
"ENABLE_WEB_SEARCH": "true",
|
||||||
|
"DISABLED_TOOLS": "",
|
||||||
|
"LOG_LEVEL": "INFO"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**After modifying config:** Restart Claude Desktop for changes to take effect.
|
||||||
|
|
||||||
|
## Tool Development Guidelines
|
||||||
|
|
||||||
|
### Tool Categories
|
||||||
|
|
||||||
|
**Simple Tools** (single-shot, fast response):
|
||||||
|
- Inherit from `SimpleTool` base class
|
||||||
|
- Focus on speed and iteration
|
||||||
|
- Examples: `contentvariant`, `platformadapt`, `subjectlines`, `factcheck`
|
||||||
|
- Use fast models (gemini-flash) when possible
|
||||||
|
|
||||||
|
**Workflow Tools** (multi-step processes):
|
||||||
|
- Inherit from `WorkflowTool` base class
|
||||||
|
- Systematic step-by-step workflows
|
||||||
|
- Track progress, confidence, findings
|
||||||
|
- Examples: `styleguide`, `seooptimize`, `guestedit`, `linkstrategy`
|
||||||
|
|
||||||
|
### Temperature Guidelines for Marketing Tools
|
||||||
|
|
||||||
|
- **High (0.7-0.8)**: Content variation, creative adaptation
|
||||||
|
- **Medium (0.5-0.6)**: Balanced tasks, campaign planning
|
||||||
|
- **Low (0.3-0.4)**: Analytical work, SEO optimization
|
||||||
|
- **Very Low (0.2)**: Fact-checking, technical verification
|
||||||
|
|
||||||
|
### Model Selection Strategy
|
||||||
|
|
||||||
|
**Gemini 2.5 Pro** (`gemini-2.5-pro`):
|
||||||
|
- Analytical and strategic work
|
||||||
|
- SEO optimization
|
||||||
|
- Guest editing
|
||||||
|
- Internal linking analysis
|
||||||
|
- Voice analysis
|
||||||
|
- Campaign planning
|
||||||
|
- Fact-checking
|
||||||
|
|
||||||
|
**Gemini Flash** (`gemini-flash`):
|
||||||
|
- Fast bulk generation
|
||||||
|
- Subject line creation
|
||||||
|
- Quick variations
|
||||||
|
- Cost-effective iterations
|
||||||
|
|
||||||
|
**Minimax M2** (`minimax-m2`):
|
||||||
|
- Creative content generation
|
||||||
|
- Platform adaptation
|
||||||
|
- Content repurposing
|
||||||
|
- Marketing copy variations
|
||||||
|
|
||||||
|
### System Prompt Best Practices
|
||||||
|
|
||||||
|
Marketing tool prompts should:
|
||||||
|
1. **Specify output format clearly** (JSON, markdown, numbered list)
|
||||||
|
2. **Include platform constraints** (character limits, formatting rules)
|
||||||
|
3. **Emphasize preservation** (voice, expertise, technical accuracy)
|
||||||
|
4. **Request rationale** (why certain variations work, what to test)
|
||||||
|
5. **Avoid code terminology** (use "content" not "implementation")
|
||||||
|
|
||||||
|
Example prompt structure:
|
||||||
|
```python
|
||||||
|
CONTENTVARIANT_PROMPT = """
|
||||||
|
You are a marketing content strategist specializing in A/B testing and variation generation.
|
||||||
|
|
||||||
|
TASK: Generate multiple variations of marketing content for testing different approaches.
|
||||||
|
|
||||||
|
OUTPUT FORMAT:
|
||||||
|
Return variations as numbered list, each with:
|
||||||
|
1. The variation text
|
||||||
|
2. The testing angle (what makes it different)
|
||||||
|
3. Predicted audience response
|
||||||
|
|
||||||
|
CONSTRAINTS:
|
||||||
|
- Maintain core message across variations
|
||||||
|
- Respect platform character limits if specified
|
||||||
|
- Preserve brand voice characteristics
|
||||||
|
- Generate genuinely different approaches, not just word swaps
|
||||||
|
|
||||||
|
VARIATION TYPES:
|
||||||
|
- Hook variations: Different opening angles
|
||||||
|
- Length variations: Short, medium, long
|
||||||
|
- Tone variations: Professional, conversational, urgent
|
||||||
|
- Structure variations: Question, statement, story
|
||||||
|
- CTA variations: Different calls-to-action
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
## Implementation Phases
|
||||||
|
|
||||||
|
### Phase 1: Foundation ✓ (You Are Here)
|
||||||
|
- [x] Create project directory
|
||||||
|
- [x] Write implementation plan (PLAN.md)
|
||||||
|
- [x] Create development guide (CLAUDE.md)
|
||||||
|
- [ ] Copy core architecture from zen-mcp-server
|
||||||
|
- [ ] Configure minimax provider
|
||||||
|
- [ ] Remove code-specific tools
|
||||||
|
- [ ] Test basic chat functionality
|
||||||
|
|
||||||
|
### Phase 2: Simple Tools (Priority: High)
|
||||||
|
Implementation order based on real-world usage:
|
||||||
|
1. **`contentvariant`** - Most frequently used (subject lines, social posts)
|
||||||
|
2. **`subjectlines`** - Specific workflow mentioned in project memories
|
||||||
|
3. **`platformadapt`** - Multi-channel content distribution
|
||||||
|
4. **`factcheck`** - Technical accuracy verification
|
||||||
|
|
||||||
|
### Phase 3: Workflow Tools (Priority: Medium)
|
||||||
|
5. **`styleguide`** - Writing rule enforcement (no em-dashes, etc.)
|
||||||
|
6. **`seooptimize`** - WordPress SEO optimization
|
||||||
|
7. **`guestedit`** - Guest content editing workflow
|
||||||
|
8. **`linkstrategy`** - Internal linking and cross-platform integration
|
||||||
|
|
||||||
|
### Phase 4: Advanced Features (Priority: Lower)
|
||||||
|
9. **`voiceanalysis`** - Voice extraction and consistency checking
|
||||||
|
10. **`campaignmap`** - Multi-touch campaign planning
|
||||||
|
|
||||||
|
## Tool Implementation Checklist
|
||||||
|
|
||||||
|
For each new tool:
|
||||||
|
|
||||||
|
**Code Files:**
|
||||||
|
- [ ] Create tool file in `tools/` (e.g., `tools/contentvariant.py`)
|
||||||
|
- [ ] Create system prompt in `systemprompts/` (e.g., `systemprompts/contentvariant_prompt.py`)
|
||||||
|
- [ ] Create test file in `tests/` (e.g., `tests/test_contentvariant.py`)
|
||||||
|
- [ ] Register tool in `server.py`
|
||||||
|
|
||||||
|
**Tool Class Requirements:**
|
||||||
|
- [ ] Inherit from `SimpleTool` or `WorkflowTool`
|
||||||
|
- [ ] Implement `get_name()` - tool name
|
||||||
|
- [ ] Implement `get_description()` - what it does
|
||||||
|
- [ ] Implement `get_system_prompt()` - behavior instructions
|
||||||
|
- [ ] Implement `get_default_temperature()` - creativity level
|
||||||
|
- [ ] Implement `get_model_category()` - FAST_RESPONSE or DEEP_THINKING
|
||||||
|
- [ ] Implement `get_request_model()` - Pydantic request schema
|
||||||
|
- [ ] Implement `get_input_schema()` - MCP tool schema
|
||||||
|
- [ ] Implement request/response formatting hooks
|
||||||
|
|
||||||
|
**Testing:**
|
||||||
|
- [ ] Unit tests for request validation
|
||||||
|
- [ ] Unit tests for response formatting
|
||||||
|
- [ ] Integration test with real model (optional)
|
||||||
|
- [ ] Add to quality checks script
|
||||||
|
|
||||||
|
**Documentation:**
|
||||||
|
- [ ] Add tool to README.md
|
||||||
|
- [ ] Create examples in docs/tools/
|
||||||
|
- [ ] Update PLAN.md progress
|
||||||
|
|
||||||
|
## Common Development Tasks
|
||||||
|
|
||||||
|
### Adding a New Simple Tool
|
||||||
|
|
||||||
|
```python
|
||||||
|
# tools/mynewtool.py
|
||||||
|
from typing import Optional
|
||||||
|
from pydantic import Field
|
||||||
|
from tools.shared.base_models import ToolRequest
|
||||||
|
from .simple.base import SimpleTool
|
||||||
|
from systemprompts import MYNEWTOOL_PROMPT
|
||||||
|
from config import TEMPERATURE_BALANCED
|
||||||
|
|
||||||
|
class MyNewToolRequest(ToolRequest):
|
||||||
|
"""Request model for MyNewTool"""
|
||||||
|
prompt: str = Field(..., description="What you want to accomplish")
|
||||||
|
files: Optional[list[str]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
class MyNewTool(SimpleTool):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "mynewtool"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Brief description of what this tool does"
|
||||||
|
|
||||||
|
def get_system_prompt(self) -> str:
|
||||||
|
return MYNEWTOOL_PROMPT
|
||||||
|
|
||||||
|
def get_default_temperature(self) -> float:
|
||||||
|
return TEMPERATURE_BALANCED
|
||||||
|
|
||||||
|
def get_model_category(self) -> "ToolModelCategory":
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
return ToolModelCategory.FAST_RESPONSE
|
||||||
|
|
||||||
|
def get_request_model(self):
|
||||||
|
return MyNewToolRequest
|
||||||
|
```
|
||||||
|
|
||||||
|
### Adding a New Workflow Tool
|
||||||
|
|
||||||
|
```python
|
||||||
|
# tools/mynewworkflow.py
|
||||||
|
from typing import Optional
|
||||||
|
from pydantic import Field
|
||||||
|
from tools.shared.base_models import WorkflowRequest
|
||||||
|
from .workflow.base import WorkflowTool
|
||||||
|
from systemprompts import MYNEWWORKFLOW_PROMPT
|
||||||
|
|
||||||
|
class MyNewWorkflowRequest(WorkflowRequest):
|
||||||
|
"""Request model for workflow tool"""
|
||||||
|
step: str = Field(description="Current step content")
|
||||||
|
step_number: int = Field(ge=1)
|
||||||
|
total_steps: int = Field(ge=1)
|
||||||
|
next_step_required: bool
|
||||||
|
findings: str = Field(description="What was discovered")
|
||||||
|
# Add workflow-specific fields
|
||||||
|
|
||||||
|
class MyNewWorkflow(WorkflowTool):
|
||||||
|
# Implementation similar to Simple Tool
|
||||||
|
# but with workflow-specific logic
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testing a Tool Manually
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start server with debug logging
|
||||||
|
LOG_LEVEL=DEBUG python server.py
|
||||||
|
|
||||||
|
# In another terminal, watch logs
|
||||||
|
tail -f logs/mcp_server.log | grep -E "(TOOL_CALL|ERROR|MyNewTool)"
|
||||||
|
|
||||||
|
# In Claude Desktop, test:
|
||||||
|
# "Use zen-marketing to generate 10 subject lines about HVAC maintenance"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
zen-marketing/
|
||||||
|
├── server.py # Main MCP server entry point
|
||||||
|
├── config.py # Configuration constants
|
||||||
|
├── PLAN.md # Implementation plan (this doc)
|
||||||
|
├── CLAUDE.md # Development guide
|
||||||
|
├── README.md # User-facing documentation
|
||||||
|
├── requirements.txt # Python dependencies
|
||||||
|
├── .env.example # Environment variable template
|
||||||
|
├── .env # Local config (gitignored)
|
||||||
|
├── run-server.sh # Setup and run script
|
||||||
|
├── code_quality_checks.sh # Linting and testing
|
||||||
|
│
|
||||||
|
├── tools/ # Tool implementations
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── contentvariant.py # Bulk variation generator
|
||||||
|
│ ├── platformadapt.py # Cross-platform adapter
|
||||||
|
│ ├── subjectlines.py # Email subject line generator
|
||||||
|
│ ├── styleguide.py # Writing style enforcer
|
||||||
|
│ ├── seooptimize.py # SEO optimizer
|
||||||
|
│ ├── guestedit.py # Guest content editor
|
||||||
|
│ ├── linkstrategy.py # Internal linking strategist
|
||||||
|
│ ├── factcheck.py # Technical fact checker
|
||||||
|
│ ├── voiceanalysis.py # Voice extractor/validator
|
||||||
|
│ ├── campaignmap.py # Campaign planner
|
||||||
|
│ ├── chat.py # General chat (from zen)
|
||||||
|
│ ├── thinkdeep.py # Deep thinking (from zen)
|
||||||
|
│ ├── planner.py # Planning (from zen)
|
||||||
|
│ ├── models.py # Shared models
|
||||||
|
│ ├── simple/ # Simple tool base classes
|
||||||
|
│ │ └── base.py
|
||||||
|
│ ├── workflow/ # Workflow tool base classes
|
||||||
|
│ │ └── base.py
|
||||||
|
│ └── shared/ # Shared utilities
|
||||||
|
│ └── base_models.py
|
||||||
|
│
|
||||||
|
├── providers/ # AI provider implementations
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── base.py # Base provider interface
|
||||||
|
│ ├── gemini.py # Google Gemini
|
||||||
|
│ ├── minimax.py # Minimax (NEW)
|
||||||
|
│ ├── openrouter.py # OpenRouter fallback
|
||||||
|
│ ├── registry.py # Provider registry
|
||||||
|
│ └── shared/
|
||||||
|
│
|
||||||
|
├── systemprompts/ # System prompts for tools
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── contentvariant_prompt.py
|
||||||
|
│ ├── platformadapt_prompt.py
|
||||||
|
│ ├── subjectlines_prompt.py
|
||||||
|
│ ├── styleguide_prompt.py
|
||||||
|
│ ├── seooptimize_prompt.py
|
||||||
|
│ ├── guestedit_prompt.py
|
||||||
|
│ ├── linkstrategy_prompt.py
|
||||||
|
│ ├── factcheck_prompt.py
|
||||||
|
│ ├── voiceanalysis_prompt.py
|
||||||
|
│ ├── campaignmap_prompt.py
|
||||||
|
│ ├── chat_prompt.py # From zen
|
||||||
|
│ ├── thinkdeep_prompt.py # From zen
|
||||||
|
│ └── planner_prompt.py # From zen
|
||||||
|
│
|
||||||
|
├── utils/ # Utility functions
|
||||||
|
│ ├── conversation_memory.py # Conversation continuity
|
||||||
|
│ ├── file_utils.py # File handling
|
||||||
|
│ └── web_search.py # Web search integration
|
||||||
|
│
|
||||||
|
├── tests/ # Test suite
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── test_contentvariant.py
|
||||||
|
│ ├── test_platformadapt.py
|
||||||
|
│ ├── test_subjectlines.py
|
||||||
|
│ └── ...
|
||||||
|
│
|
||||||
|
├── logs/ # Log files (gitignored)
|
||||||
|
│ ├── mcp_server.log
|
||||||
|
│ └── mcp_activity.log
|
||||||
|
│
|
||||||
|
└── docs/ # Documentation
|
||||||
|
├── getting-started.md
|
||||||
|
├── tools/
|
||||||
|
│ ├── contentvariant.md
|
||||||
|
│ ├── platformadapt.md
|
||||||
|
│ └── ...
|
||||||
|
└── examples/
|
||||||
|
└── marketing-workflows.md
|
||||||
|
```
|
||||||
|
|
||||||
|
## Key Concepts from Zen Architecture
|
||||||
|
|
||||||
|
### Conversation Continuity
|
||||||
|
Every tool supports `continuation_id` to maintain context across interactions:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# First call
|
||||||
|
result1 = await tool.execute({
|
||||||
|
"prompt": "Analyze this brand voice",
|
||||||
|
"files": ["brand_samples/post1.txt", "brand_samples/post2.txt"]
|
||||||
|
})
|
||||||
|
# Returns: continuation_id: "abc123"
|
||||||
|
|
||||||
|
# Follow-up call (remembers previous context)
|
||||||
|
result2 = await tool.execute({
|
||||||
|
"prompt": "Now check if this new draft matches the voice",
|
||||||
|
"files": ["new_draft.txt"],
|
||||||
|
"continuation_id": "abc123" # Preserves context
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### File Handling
|
||||||
|
Tools automatically:
|
||||||
|
- Expand directories to individual files
|
||||||
|
- Deduplicate file lists
|
||||||
|
- Handle absolute paths
|
||||||
|
- Process images (screenshots, brand assets)
|
||||||
|
|
||||||
|
### Web Search Integration
|
||||||
|
Tools can request Claude to perform web searches:
|
||||||
|
```python
|
||||||
|
# In system prompt:
|
||||||
|
"If you need current information about [topic], request a web search from Claude."
|
||||||
|
|
||||||
|
# Claude will then use WebSearch tool and provide results
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multi-Model Orchestration
|
||||||
|
Tools specify model category, server selects best available:
|
||||||
|
- `FAST_RESPONSE` → gemini-flash or equivalent
|
||||||
|
- `DEEP_THINKING` → gemini-2.5-pro or equivalent
|
||||||
|
- User can override with `model` parameter
|
||||||
|
|
||||||
|
## Debugging Common Issues
|
||||||
|
|
||||||
|
### Tool Not Appearing in Claude Desktop
|
||||||
|
1. Check `server.py` registers the tool
|
||||||
|
2. Verify tool is not in `DISABLED_TOOLS` env var
|
||||||
|
3. Restart Claude Desktop after config changes
|
||||||
|
4. Check logs: `tail -f logs/mcp_server.log`
|
||||||
|
|
||||||
|
### Model Selection Issues
|
||||||
|
1. Verify API keys in `.env`
|
||||||
|
2. Check provider registration in `providers/registry.py`
|
||||||
|
3. Test with explicit model name: `"model": "gemini-2.5-pro"`
|
||||||
|
4. Check logs for provider errors
|
||||||
|
|
||||||
|
### Response Formatting Issues
|
||||||
|
1. Validate system prompt specifies output format
|
||||||
|
2. Check response doesn't exceed token limits
|
||||||
|
3. Test with simpler input first
|
||||||
|
4. Review logs for truncation warnings
|
||||||
|
|
||||||
|
### Conversation Continuity Not Working
|
||||||
|
1. Verify `continuation_id` is being passed correctly
|
||||||
|
2. Check conversation hasn't expired (default 6 hours)
|
||||||
|
3. Validate conversation memory storage
|
||||||
|
4. Review logs: `grep "continuation_id" logs/mcp_server.log`
|
||||||
|
|
||||||
|
## Code Quality Standards
|
||||||
|
|
||||||
|
Before committing:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all quality checks
|
||||||
|
./code_quality_checks.sh
|
||||||
|
|
||||||
|
# Manual checks:
|
||||||
|
ruff check . --fix # Linting
|
||||||
|
black . # Formatting
|
||||||
|
isort . # Import sorting
|
||||||
|
pytest tests/ -v # Run tests
|
||||||
|
```
|
||||||
|
|
||||||
|
## Marketing-Specific Considerations
|
||||||
|
|
||||||
|
### Character Limits by Platform
|
||||||
|
Tools should be aware of:
|
||||||
|
- **Twitter/Bluesky**: 280 characters
|
||||||
|
- **LinkedIn**: 3000 chars (1300 optimal)
|
||||||
|
- **Instagram**: 2200 characters
|
||||||
|
- **Facebook**: No hard limit (500 chars optimal)
|
||||||
|
- **Email subject**: 60 characters optimal
|
||||||
|
- **Email preview**: 90-100 characters
|
||||||
|
- **Meta description**: 156 characters
|
||||||
|
- **Page title**: 60 characters
|
||||||
|
|
||||||
|
### Writing Style Rules from Project Memories
|
||||||
|
- No em-dashes (use periods or semicolons)
|
||||||
|
- No "This isn't X, it's Y" constructions
|
||||||
|
- Direct affirmative statements over negations
|
||||||
|
- Semantic variety in paragraph openings
|
||||||
|
- Concrete metrics over abstract claims
|
||||||
|
- Technical accuracy preserved
|
||||||
|
- Author voice maintained
|
||||||
|
|
||||||
|
### Testing Angles for Variations
|
||||||
|
- Technical curiosity
|
||||||
|
- Contrarian/provocative
|
||||||
|
- Knowledge gap emphasis
|
||||||
|
- Urgency/timeliness
|
||||||
|
- Insider knowledge positioning
|
||||||
|
- Problem-solution framing
|
||||||
|
- Before-after transformation
|
||||||
|
- Social proof/credibility
|
||||||
|
- FOMO (fear of missing out)
|
||||||
|
- Educational value
|
||||||
|
|
||||||
|
## Next Session Goals
|
||||||
|
|
||||||
|
When you start the new session in `~/mcp/zen-marketing/`:
|
||||||
|
|
||||||
|
1. **Copy Core Files from Zen**
|
||||||
|
- Copy base architecture preserving git history
|
||||||
|
- Remove code-specific tools
|
||||||
|
- Update imports and references
|
||||||
|
|
||||||
|
2. **Configure Minimax Provider**
|
||||||
|
- Add minimax support to providers/
|
||||||
|
- Register in provider registry
|
||||||
|
- Test basic model calls
|
||||||
|
|
||||||
|
3. **Implement First Simple Tool**
|
||||||
|
- Start with `contentvariant` (highest priority)
|
||||||
|
- Create tool, system prompt, and tests
|
||||||
|
- Test end-to-end with Claude Desktop
|
||||||
|
|
||||||
|
4. **Validate Architecture**
|
||||||
|
- Ensure conversation continuity works
|
||||||
|
- Verify file handling
|
||||||
|
- Test web search integration
|
||||||
|
|
||||||
|
## Questions to Consider
|
||||||
|
|
||||||
|
Before implementing each tool:
|
||||||
|
1. What real-world workflow does this solve? (Reference project memories)
|
||||||
|
2. What's the minimum viable version?
|
||||||
|
3. What can go wrong? (Character limits, API errors, invalid input)
|
||||||
|
4. How will users test variations? (Output format)
|
||||||
|
5. Does it need web search? (Current info, fact-checking)
|
||||||
|
6. What's the right temperature? (Creative vs analytical)
|
||||||
|
7. Simple or workflow tool? (Single-shot vs multi-step)
|
||||||
|
|
||||||
|
## Resources
|
||||||
|
|
||||||
|
- **Zen MCP Server Repo**: Source for architecture and patterns
|
||||||
|
- **MCP Protocol Docs**: https://modelcontextprotocol.com
|
||||||
|
- **Claude Desktop Config**: `~/.claude.json`
|
||||||
|
- **Project Memories**: See PLAN.md for user workflow examples
|
||||||
|
- **Platform Best Practices**: Research current 2025 guidelines
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Ready to build?** Start the new session with:
|
||||||
|
```bash
|
||||||
|
cd ~/mcp/zen-marketing
|
||||||
|
# Then ask Claude to begin Phase 1 implementation
|
||||||
|
```
|
||||||
197
LICENSE
Normal file
197
LICENSE
Normal file
|
|
@ -0,0 +1,197 @@
|
||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship covered by this License,
|
||||||
|
whether in source or binary form, which is made available under the
|
||||||
|
License, as indicated by a copyright notice that is included in or
|
||||||
|
attached to the work. (The copyright notice requirement does not
|
||||||
|
apply to derivative works of the License holder.)
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based upon (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and derivative works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control
|
||||||
|
systems, and issue tracking systems that are managed by, or on behalf
|
||||||
|
of, the Licensor for the purpose of discussing and improving the Work,
|
||||||
|
but excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to use, reproduce, modify, distribute, and otherwise
|
||||||
|
transfer the Work as part of a Derivative Work.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright notice to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Support. You may choose to offer, and to
|
||||||
|
charge a fee for, warranty, support, indemnity or other liability
|
||||||
|
obligations and/or rights consistent with this License. However,
|
||||||
|
in accepting such obligations, You may act only on Your own behalf
|
||||||
|
and on Your sole responsibility, not on behalf of any other
|
||||||
|
Contributor, and only if You agree to indemnify, defend, and hold
|
||||||
|
each Contributor harmless for any liability incurred by, or claims
|
||||||
|
asserted against, such Contributor by reason of your accepting any
|
||||||
|
such warranty or support.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in comments for the
|
||||||
|
particular file format. An identification line is also useful.
|
||||||
|
|
||||||
|
Copyright 2025 Beehive Innovations
|
||||||
|
https://github.com/BeehiveInnovations
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
513
PLAN.md
Normal file
513
PLAN.md
Normal file
|
|
@ -0,0 +1,513 @@
|
||||||
|
# Zen-Marketing MCP Server - Implementation Plan
|
||||||
|
|
||||||
|
## Project Overview
|
||||||
|
|
||||||
|
A specialized MCP server for Claude Desktop focused on marketing workflows, derived from the Zen MCP Server codebase. Designed for solo marketing professionals working in technical B2B content, particularly HVAC/technical education and software marketing.
|
||||||
|
|
||||||
|
**Target User Profile:** Marketing professional managing:
|
||||||
|
- Technical newsletters with guest contributors
|
||||||
|
- Multi-platform social media campaigns
|
||||||
|
- Educational content with SEO optimization
|
||||||
|
- A/B testing and content variation generation
|
||||||
|
- WordPress content management with internal linking
|
||||||
|
- Brand voice consistency across channels
|
||||||
|
|
||||||
|
## Core Principles
|
||||||
|
|
||||||
|
1. **Variation at Scale**: Generate 5-20 variations of content for A/B testing
|
||||||
|
2. **Platform Intelligence**: Understand character limits, tone, and best practices per platform
|
||||||
|
3. **Technical Accuracy**: Verify facts via web search before publishing
|
||||||
|
4. **Voice Preservation**: Maintain authentic voices (guest authors, brand persona)
|
||||||
|
5. **Cross-Platform Integration**: Connect content across blog, social, email, video
|
||||||
|
6. **Style Enforcement**: Apply specific writing guidelines automatically
|
||||||
|
|
||||||
|
## Planned Tools
|
||||||
|
|
||||||
|
### 1. `contentvariant` (Simple Tool)
|
||||||
|
**Purpose:** Rapid generation of multiple content variations for testing
|
||||||
|
|
||||||
|
**Real-world use case from project memory:**
|
||||||
|
- "Generate 15 email subject line variations testing different angles: technical curiosity, contrarian statements, FOMO, educational value"
|
||||||
|
- "Create 10 LinkedIn post variations of this announcement, mixing lengths and hooks"
|
||||||
|
|
||||||
|
**Key features:**
|
||||||
|
- Generate 5-20 variations in one call
|
||||||
|
- Specify variation dimensions (tone, length, hook, CTA placement)
|
||||||
|
- Fast model (gemini-flash) for speed
|
||||||
|
- Output formatted for easy copy-paste testing
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `content` (str): Base content to vary
|
||||||
|
- `variation_count` (int): Number of variations (default 10)
|
||||||
|
- `variation_types` (list): Types to vary - "tone", "length", "hook", "structure", "cta"
|
||||||
|
- `platform` (str): Target platform for context
|
||||||
|
- `constraints` (str): Character limits, style requirements
|
||||||
|
|
||||||
|
**Model:** gemini-flash (fast, cost-effective for bulk generation)
|
||||||
|
**Temperature:** 0.8 (creative variation while maintaining message)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. `platformadapt` (Simple Tool)
|
||||||
|
**Purpose:** Adapt single content piece across multiple social platforms
|
||||||
|
|
||||||
|
**Real-world use case from project memory:**
|
||||||
|
- "Take this blog post intro and create versions for LinkedIn (1300 chars), Twitter (280 chars), Instagram caption (2200 chars), Facebook (500 chars), and Bluesky (300 chars)"
|
||||||
|
|
||||||
|
**Key features:**
|
||||||
|
- Single source → multiple platform versions
|
||||||
|
- Respects character limits automatically
|
||||||
|
- Platform-specific best practices (hashtags, tone, formatting)
|
||||||
|
- Preserves core message across adaptations
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `source_content` (str): Original content
|
||||||
|
- `source_platform` (str): Where content originated
|
||||||
|
- `target_platforms` (list): ["linkedin", "twitter", "instagram", "facebook", "bluesky"]
|
||||||
|
- `preserve_urls` (bool): Keep links intact vs. adapt
|
||||||
|
- `brand_voice` (str): Voice guidelines to maintain
|
||||||
|
|
||||||
|
**Model:** minimax/minimax-m2 (creative adaptations)
|
||||||
|
**Temperature:** 0.7
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. `subjectlines` (Simple Tool)
|
||||||
|
**Purpose:** Specialized tool for email subject line generation with angle testing
|
||||||
|
|
||||||
|
**Real-world use case from project memory:**
|
||||||
|
- "Generate subject lines testing: technical curiosity hook, contrarian/provocative angle, knowledge gap emphasis, urgency/timeliness, insider knowledge positioning"
|
||||||
|
|
||||||
|
**Key features:**
|
||||||
|
- Generates 15-25 subject lines by default
|
||||||
|
- Groups by psychological angle/hook
|
||||||
|
- Includes A/B testing rationale for each
|
||||||
|
- Character count validation (under 60 chars optimal)
|
||||||
|
- Emoji suggestions (optional)
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `email_topic` (str): Main topic/content
|
||||||
|
- `target_audience` (str): Who receives this
|
||||||
|
- `angles_to_test` (list): Psychological hooks to explore
|
||||||
|
- `include_emoji` (bool): Add emoji options
|
||||||
|
- `preview_text` (str): Optional - generate matching preview text
|
||||||
|
|
||||||
|
**Model:** gemini-flash (fast generation, creative but focused)
|
||||||
|
**Temperature:** 0.8
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4. `styleguide` (Workflow Tool)
|
||||||
|
**Purpose:** Enforce specific writing guidelines and polish content
|
||||||
|
|
||||||
|
**Real-world use case from project memory:**
|
||||||
|
- "Check this content for: em-dashes (remove), 'This isn't X, it's Y' constructions (rewrite), semantic variety in paragraph structure, direct affirmative statements instead of negations"
|
||||||
|
|
||||||
|
**Key features:**
|
||||||
|
- Multi-step workflow: detection → flagging → rewriting → validation
|
||||||
|
- Tracks style violations with severity
|
||||||
|
- Provides before/after comparisons
|
||||||
|
- Custom rule definitions
|
||||||
|
|
||||||
|
**Workflow steps:**
|
||||||
|
1. Analyze content for style violations
|
||||||
|
2. Flag issues with line numbers and severity
|
||||||
|
3. Generate corrected version
|
||||||
|
4. Validate improvements
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `content` (str): Content to check
|
||||||
|
- `rules` (list): Style rules to enforce
|
||||||
|
- `rewrite_violations` (bool): Auto-fix vs. flag only
|
||||||
|
- `preserve_technical_terms` (bool): Don't change HVAC/technical terminology
|
||||||
|
|
||||||
|
**Common rules (from project memory):**
|
||||||
|
- No em-dashes (use periods or semicolons)
|
||||||
|
- No "This isn't X, it's Y" constructions
|
||||||
|
- Direct affirmative statements over negations
|
||||||
|
- Semantic variety in paragraph openings
|
||||||
|
- No clichéd structures
|
||||||
|
- Concrete metrics over abstract claims
|
||||||
|
|
||||||
|
**Model:** gemini-2.5-pro (analytical precision)
|
||||||
|
**Temperature:** 0.3 (consistency enforcement)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 5. `seooptimize` (Workflow Tool)
|
||||||
|
**Purpose:** Comprehensive SEO optimization for WordPress content
|
||||||
|
|
||||||
|
**Real-world use case from project memory:**
|
||||||
|
- "Create SEO title under 60 chars, excerpt under 156 chars, suggest 5-7 WordPress tags, internal linking opportunities to foundational content"
|
||||||
|
|
||||||
|
**Workflow steps:**
|
||||||
|
1. Analyze content for primary keywords and topics
|
||||||
|
2. Generate SEO title (under 60 chars)
|
||||||
|
3. Create meta description/excerpt (under 156 chars)
|
||||||
|
4. Suggest WordPress tags (5-10)
|
||||||
|
5. Identify internal linking opportunities
|
||||||
|
6. Validate character limits and keyword density
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `content` (str): Full article content
|
||||||
|
- `existing_tags` (list): Current WordPress tag library
|
||||||
|
- `target_keywords` (list): Keywords to optimize for
|
||||||
|
- `internal_links_context` (str): Description of site structure for linking
|
||||||
|
- `platform` (str): "wordpress" (default), "ghost", "medium"
|
||||||
|
|
||||||
|
**Model:** gemini-2.5-pro (analytical, search-aware)
|
||||||
|
**Temperature:** 0.4
|
||||||
|
**Web search:** Enabled for keyword research
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 6. `guestedit` (Workflow Tool)
|
||||||
|
**Purpose:** Edit guest content while preserving author expertise and voice
|
||||||
|
|
||||||
|
**Real-world use case from project memory:**
|
||||||
|
- "Edit this guest article on PCB components: preserve technical authority, add key takeaways section, suggest internal links to foundational content, enhance educational flow, maintain expert's voice"
|
||||||
|
|
||||||
|
**Workflow steps:**
|
||||||
|
1. Analyze guest author's voice and technical expertise level
|
||||||
|
2. Identify areas needing clarification vs. areas to preserve
|
||||||
|
3. Generate educational enhancements (key takeaways, callouts)
|
||||||
|
4. Suggest internal linking without fabricating URLs
|
||||||
|
5. Validate technical accuracy via web search
|
||||||
|
6. Present edits with rationale
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `guest_content` (str): Original article
|
||||||
|
- `author_name` (str): Guest author for voice reference
|
||||||
|
- `expertise_level` (str): "expert", "practitioner", "educator"
|
||||||
|
- `enhancements_needed` (list): ["key_takeaways", "internal_links", "clarity", "seo"]
|
||||||
|
- `site_context` (str): Description of broader content ecosystem
|
||||||
|
- `verify_technical` (bool): Use web search for fact-checking
|
||||||
|
|
||||||
|
**Model:** gemini-2.5-pro (analytical + creative balance)
|
||||||
|
**Temperature:** 0.5
|
||||||
|
**Web search:** Enabled for technical verification
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 7. `linkstrategy` (Workflow Tool)
|
||||||
|
**Purpose:** Internal linking and cross-platform content integration strategy
|
||||||
|
|
||||||
|
**Real-world use case from project memory:**
|
||||||
|
- "Find opportunities to link this advanced PCB article to foundational HVAC knowledge. Connect blog with related podcast episode, Instagram post, and YouTube video from our content database."
|
||||||
|
|
||||||
|
**Workflow steps:**
|
||||||
|
1. Analyze content topics and technical depth
|
||||||
|
2. Identify prerequisite concepts needing internal links
|
||||||
|
3. Search for related content across platforms (blog, podcast, video, social)
|
||||||
|
4. Generate linking strategy with anchor text suggestions
|
||||||
|
5. Validate URLs against actual content (no fabrication)
|
||||||
|
6. Create cross-platform promotion plan
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `primary_content` (str): Main piece to link from/to
|
||||||
|
- `content_type` (str): "blog", "newsletter", "social_post"
|
||||||
|
- `available_platforms` (list): ["blog", "podcast", "youtube", "instagram"]
|
||||||
|
- `technical_depth` (str): "foundational", "intermediate", "advanced"
|
||||||
|
- `verify_links` (bool): Validate URLs exist before suggesting
|
||||||
|
|
||||||
|
**Model:** gemini-2.5-pro (analytical, relationship mapping)
|
||||||
|
**Temperature:** 0.4
|
||||||
|
**Web search:** Enabled for content discovery
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 8. `factcheck` (Simple Tool)
|
||||||
|
**Purpose:** Quick technical fact verification via web search
|
||||||
|
|
||||||
|
**Real-world use case from project memory:**
|
||||||
|
- "Verify these technical claims about voltage regulation chains in HVAC PCBs"
|
||||||
|
- "Check if this White-Rodgers universal control model number and compatibility claims are accurate"
|
||||||
|
|
||||||
|
**Key features:**
|
||||||
|
- Web search-powered verification
|
||||||
|
- Cites sources for each claim
|
||||||
|
- Flags unsupported or questionable statements
|
||||||
|
- Quick turnaround for pre-publish checks
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `content` (str): Content with claims to verify
|
||||||
|
- `technical_domain` (str): "hvac", "software", "general"
|
||||||
|
- `claim_type` (str): "product_specs", "technical_process", "statistics", "general"
|
||||||
|
- `confidence_threshold` (str): "high_confidence_only", "balanced", "comprehensive"
|
||||||
|
|
||||||
|
**Model:** gemini-2.5-pro (search-augmented, analytical)
|
||||||
|
**Temperature:** 0.2 (precision for facts)
|
||||||
|
**Web search:** Required (core functionality)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 9. `voiceanalysis` (Workflow Tool)
|
||||||
|
**Purpose:** Extract and codify brand/author voice for consistency
|
||||||
|
|
||||||
|
**Real-world use case from project memory:**
|
||||||
|
- "Analyze these 10 pieces of Gary McCreadie's content to extract voice patterns, then check if this new draft matches his authentic teaching voice"
|
||||||
|
|
||||||
|
**Workflow steps:**
|
||||||
|
1. Analyze sample content for voice characteristics
|
||||||
|
2. Extract patterns: sentence structure, vocabulary, tone markers
|
||||||
|
3. Create voice profile with examples
|
||||||
|
4. Compare new content against profile
|
||||||
|
5. Flag inconsistencies with suggested rewrites
|
||||||
|
6. Track confidence in voice matching
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `sample_content` (list): 5-15 pieces showing authentic voice
|
||||||
|
- `author_name` (str): Voice owner for reference
|
||||||
|
- `new_content` (str): Content to validate against voice
|
||||||
|
- `voice_aspects` (list): ["tone", "vocabulary", "structure", "educational_style"]
|
||||||
|
- `generate_profile` (bool): Create reusable voice profile
|
||||||
|
|
||||||
|
**Model:** gemini-2.5-pro (pattern analysis)
|
||||||
|
**Temperature:** 0.3
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 10. `campaignmap` (Workflow Tool)
|
||||||
|
**Purpose:** Map content campaign across multiple touchpoints
|
||||||
|
|
||||||
|
**Real-world use case from project memory:**
|
||||||
|
- "Plan content campaign for measureQuick National Championship: social teaser posts, email sequence, on-site promotion, post-event follow-up, blog recap with internal links"
|
||||||
|
|
||||||
|
**Workflow steps:**
|
||||||
|
1. Define campaign goals and timeline
|
||||||
|
2. Map content across platforms and stages (awareness → consideration → action → retention)
|
||||||
|
3. Create content calendar with dependencies
|
||||||
|
4. Generate messaging for each touchpoint
|
||||||
|
5. Plan cross-promotion and internal linking
|
||||||
|
6. Set success metrics per channel
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `campaign_goal` (str): Primary objective
|
||||||
|
- `timeline_days` (int): Campaign duration
|
||||||
|
- `platforms` (list): Channels to use
|
||||||
|
- `content_pillars` (list): Key messages to reinforce
|
||||||
|
- `target_audience` (str): Audience description
|
||||||
|
- `existing_assets` (list): Content to repurpose
|
||||||
|
|
||||||
|
**Model:** gemini-2.5-pro (strategic planning)
|
||||||
|
**Temperature:** 0.6
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Tool Architecture
|
||||||
|
|
||||||
|
### Simple Tools (3)
|
||||||
|
Fast, single-shot interactions for rapid iteration:
|
||||||
|
- `contentvariant` - Bulk variation generation
|
||||||
|
- `platformadapt` - Cross-platform adaptation
|
||||||
|
- `subjectlines` - Email subject line testing
|
||||||
|
- `factcheck` - Quick verification
|
||||||
|
|
||||||
|
### Workflow Tools (6)
|
||||||
|
Multi-step systematic processes:
|
||||||
|
- `styleguide` - Style enforcement workflow
|
||||||
|
- `seooptimize` - Comprehensive SEO process
|
||||||
|
- `guestedit` - Guest content editing workflow
|
||||||
|
- `linkstrategy` - Internal linking analysis
|
||||||
|
- `voiceanalysis` - Voice extraction and validation
|
||||||
|
- `campaignmap` - Multi-touch campaign planning
|
||||||
|
|
||||||
|
### Keep from Original Zen
|
||||||
|
- `chat` - General brainstorming
|
||||||
|
- `thinkdeep` - Deep strategic thinking
|
||||||
|
- `planner` - Project planning
|
||||||
|
|
||||||
|
## Model Assignment Strategy
|
||||||
|
|
||||||
|
**Primary Models:**
|
||||||
|
- **minimax/minimax-m2**: Creative content generation (contentvariant, platformadapt)
|
||||||
|
- **google/gemini-2.5-pro**: Analytical and strategic work (seooptimize, guestedit, linkstrategy, voiceanalysis, campaignmap, styleguide, factcheck)
|
||||||
|
- **google/gemini-flash**: Fast bulk generation (subjectlines)
|
||||||
|
|
||||||
|
**Fallback Chain:**
|
||||||
|
1. Configured provider (minimax or gemini)
|
||||||
|
2. OpenRouter (if API key present)
|
||||||
|
3. Error (no fallback to Anthropic to avoid confusion)
|
||||||
|
|
||||||
|
## Configuration Defaults
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"zen-marketing": {
|
||||||
|
"command": "/home/ben/mcp/zen-marketing/.venv/bin/python",
|
||||||
|
"args": ["/home/ben/mcp/zen-marketing/server.py"],
|
||||||
|
"env": {
|
||||||
|
"OPENROUTER_API_KEY": "your-key",
|
||||||
|
"DEFAULT_MODEL": "gemini-2.5-pro",
|
||||||
|
"FAST_MODEL": "gemini-flash",
|
||||||
|
"CREATIVE_MODEL": "minimax-m2",
|
||||||
|
"DISABLED_TOOLS": "analyze,refactor,testgen,secaudit,docgen,tracer,precommit,codereview,debug,consensus,challenge",
|
||||||
|
"ENABLE_WEB_SEARCH": "true",
|
||||||
|
"LOG_LEVEL": "INFO"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Tool Temperature Defaults
|
||||||
|
|
||||||
|
| Tool | Temperature | Rationale |
|
||||||
|
|------|-------------|-----------|
|
||||||
|
| contentvariant | 0.8 | High creativity for diverse variations |
|
||||||
|
| platformadapt | 0.7 | Creative while maintaining message |
|
||||||
|
| subjectlines | 0.8 | Explore diverse psychological angles |
|
||||||
|
| styleguide | 0.3 | Precision for consistency |
|
||||||
|
| seooptimize | 0.4 | Balanced analytical + creative |
|
||||||
|
| guestedit | 0.5 | Balance preservation + enhancement |
|
||||||
|
| linkstrategy | 0.4 | Analytical relationship mapping |
|
||||||
|
| factcheck | 0.2 | Precision for accuracy |
|
||||||
|
| voiceanalysis | 0.3 | Pattern recognition precision |
|
||||||
|
| campaignmap | 0.6 | Strategic creativity |
|
||||||
|
|
||||||
|
## Differences from Zen Code
|
||||||
|
|
||||||
|
### What's Different
|
||||||
|
1. **Content-first**: Tools designed for writing, not coding
|
||||||
|
2. **Variation generation**: Built-in support for A/B testing at scale
|
||||||
|
3. **Platform awareness**: Character limits, formatting, best practices per channel
|
||||||
|
4. **Voice preservation**: Tools explicitly maintain author authenticity
|
||||||
|
5. **SEO native**: WordPress-specific optimization built in
|
||||||
|
6. **Verification emphasis**: Web search for fact-checking integrated
|
||||||
|
7. **No code tools**: Remove debug, codereview, precommit, refactor, testgen, secaudit, tracer
|
||||||
|
|
||||||
|
### What's Kept
|
||||||
|
1. **Conversation continuity**: `continuation_id` across tools
|
||||||
|
2. **Multi-model orchestration**: Different models for different tasks
|
||||||
|
3. **Workflow architecture**: Multi-step systematic processes
|
||||||
|
4. **File handling**: Attach content files, brand guidelines
|
||||||
|
5. **Image support**: Analyze screenshots, brand assets
|
||||||
|
6. **Thinking modes**: Control depth vs. cost
|
||||||
|
7. **Web search**: Current information access
|
||||||
|
|
||||||
|
### What's Enhanced
|
||||||
|
1. **Bulk operations**: Generate 5-20 variations in one call
|
||||||
|
2. **Style enforcement**: Custom writing rules
|
||||||
|
3. **Cross-platform thinking**: Native multi-channel workflows
|
||||||
|
4. **Technical verification**: Domain-specific fact-checking
|
||||||
|
|
||||||
|
## Implementation Phases
|
||||||
|
|
||||||
|
### Phase 1: Foundation (Week 1)
|
||||||
|
- Fork Zen MCP Server codebase
|
||||||
|
- Remove code-specific tools
|
||||||
|
- Configure minimax and gemini providers
|
||||||
|
- Update system prompts for marketing context
|
||||||
|
- Basic testing with chat tool
|
||||||
|
|
||||||
|
### Phase 2: Simple Tools (Week 2)
|
||||||
|
- Implement `contentvariant`
|
||||||
|
- Implement `platformadapt`
|
||||||
|
- Implement `subjectlines`
|
||||||
|
- Implement `factcheck`
|
||||||
|
- Test variation generation workflows
|
||||||
|
|
||||||
|
### Phase 3: Workflow Tools (Week 3-4)
|
||||||
|
- Implement `styleguide`
|
||||||
|
- Implement `seooptimize`
|
||||||
|
- Implement `guestedit`
|
||||||
|
- Implement `linkstrategy`
|
||||||
|
- Test multi-step workflows
|
||||||
|
|
||||||
|
### Phase 4: Advanced Features (Week 5)
|
||||||
|
- Implement `voiceanalysis`
|
||||||
|
- Implement `campaignmap`
|
||||||
|
- Add voice profile storage
|
||||||
|
- Test complex campaign workflows
|
||||||
|
|
||||||
|
### Phase 5: Polish & Documentation (Week 6)
|
||||||
|
- User documentation
|
||||||
|
- Example workflows
|
||||||
|
- Configuration guide
|
||||||
|
- Testing with real project memories
|
||||||
|
|
||||||
|
## Success Metrics
|
||||||
|
|
||||||
|
1. **Variation Quality**: Can generate 10+ usable subject lines in one call
|
||||||
|
2. **Platform Accuracy**: Respects character limits and formatting rules
|
||||||
|
3. **Voice Consistency**: Successfully preserves guest author expertise
|
||||||
|
4. **SEO Effectiveness**: Generated titles/descriptions pass WordPress SEO checks
|
||||||
|
5. **Cross-platform Integration**: Correctly maps content across blog/social/email
|
||||||
|
6. **Fact Accuracy**: Catches technical errors before publication
|
||||||
|
7. **Speed**: Simple tools respond in <10 seconds, workflow tools in <30 seconds
|
||||||
|
|
||||||
|
## File Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
zen-marketing/
|
||||||
|
├── server.py # Main MCP server (adapted from zen)
|
||||||
|
├── config.py # Marketing-specific configuration
|
||||||
|
├── PLAN.md # This file
|
||||||
|
├── CLAUDE.md # Development guide
|
||||||
|
├── README.md # User documentation
|
||||||
|
├── requirements.txt # Python dependencies
|
||||||
|
├── .env.example # Environment template
|
||||||
|
├── run-server.sh # Setup and run script
|
||||||
|
├── tools/
|
||||||
|
│ ├── contentvariant.py # Bulk variation generation
|
||||||
|
│ ├── platformadapt.py # Cross-platform adaptation
|
||||||
|
│ ├── subjectlines.py # Email subject lines
|
||||||
|
│ ├── styleguide.py # Writing style enforcement
|
||||||
|
│ ├── seooptimize.py # WordPress SEO workflow
|
||||||
|
│ ├── guestedit.py # Guest content editing
|
||||||
|
│ ├── linkstrategy.py # Internal linking strategy
|
||||||
|
│ ├── factcheck.py # Technical verification
|
||||||
|
│ ├── voiceanalysis.py # Voice extraction/validation
|
||||||
|
│ ├── campaignmap.py # Campaign planning
|
||||||
|
│ └── shared/ # Shared utilities from zen
|
||||||
|
├── providers/
|
||||||
|
│ ├── gemini.py # Google Gemini provider
|
||||||
|
│ ├── minimax.py # Minimax provider (NEW)
|
||||||
|
│ ├── openrouter.py # OpenRouter fallback
|
||||||
|
│ └── registry.py # Provider registration
|
||||||
|
├── systemprompts/
|
||||||
|
│ ├── contentvariant_prompt.py
|
||||||
|
│ ├── platformadapt_prompt.py
|
||||||
|
│ ├── subjectlines_prompt.py
|
||||||
|
│ ├── styleguide_prompt.py
|
||||||
|
│ ├── seooptimize_prompt.py
|
||||||
|
│ ├── guestedit_prompt.py
|
||||||
|
│ ├── linkstrategy_prompt.py
|
||||||
|
│ ├── factcheck_prompt.py
|
||||||
|
│ ├── voiceanalysis_prompt.py
|
||||||
|
│ └── campaignmap_prompt.py
|
||||||
|
└── tests/
|
||||||
|
├── test_contentvariant.py
|
||||||
|
├── test_platformadapt.py
|
||||||
|
└── ...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Development Priorities
|
||||||
|
|
||||||
|
### High Priority (Essential)
|
||||||
|
1. `contentvariant` - Most used feature from project memories
|
||||||
|
2. `subjectlines` - Specific workflow mentioned multiple times
|
||||||
|
3. `platformadapt` - Core multi-channel need
|
||||||
|
4. `styleguide` - Explicit writing rules enforcement
|
||||||
|
5. `factcheck` - Technical accuracy verification
|
||||||
|
|
||||||
|
### Medium Priority (Important)
|
||||||
|
6. `seooptimize` - WordPress-specific SEO needs
|
||||||
|
7. `guestedit` - Guest content workflow
|
||||||
|
8. `linkstrategy` - Internal linking mentioned frequently
|
||||||
|
|
||||||
|
### Lower Priority (Nice to Have)
|
||||||
|
9. `voiceanalysis` - Advanced voice preservation
|
||||||
|
10. `campaignmap` - Strategic planning (can use planner tool initially)
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
1. Create CLAUDE.md with development guide for zen-marketing
|
||||||
|
2. Start new session in ~/mcp/zen-marketing/ directory
|
||||||
|
3. Begin Phase 1 implementation:
|
||||||
|
- Copy core architecture from zen-mcp-server
|
||||||
|
- Configure minimax provider
|
||||||
|
- Remove code-specific tools
|
||||||
|
- Test with basic chat functionality
|
||||||
|
4. Implement Phase 2 simple tools starting with `contentvariant`
|
||||||
163
README.md
Normal file
163
README.md
Normal file
|
|
@ -0,0 +1,163 @@
|
||||||
|
# Zen-Marketing MCP Server
|
||||||
|
|
||||||
|
> **Status:** 🚧 In Development - Phase 1
|
||||||
|
|
||||||
|
AI-powered marketing tools for Claude Desktop, specialized for technical B2B content creation, multi-platform campaigns, and content variation testing.
|
||||||
|
|
||||||
|
## What is This?
|
||||||
|
|
||||||
|
A fork of the [Zen MCP Server](https://github.com/BeehiveInnovations/zen-mcp-server) optimized for marketing workflows instead of software development. Provides Claude Desktop with specialized tools for:
|
||||||
|
|
||||||
|
- **Content Variation** - Generate 5-20 variations for A/B testing
|
||||||
|
- **Platform Adaptation** - Adapt content across LinkedIn, Twitter, newsletters, blogs
|
||||||
|
- **Style Enforcement** - Apply writing guidelines automatically
|
||||||
|
- **SEO Optimization** - WordPress-specific SEO workflows
|
||||||
|
- **Guest Editing** - Edit external content while preserving author voice
|
||||||
|
- **Fact Verification** - Technical accuracy checking via web search
|
||||||
|
- **Internal Linking** - Cross-platform content integration strategy
|
||||||
|
- **Campaign Planning** - Multi-touch campaign mapping
|
||||||
|
|
||||||
|
## For Whom?
|
||||||
|
|
||||||
|
Solo marketing professionals managing:
|
||||||
|
- Technical newsletters and educational content
|
||||||
|
- Multi-platform social media campaigns
|
||||||
|
- WordPress blogs with SEO requirements
|
||||||
|
- Guest contributor content
|
||||||
|
- A/B testing and content experimentation
|
||||||
|
- B2B SaaS or technical product marketing
|
||||||
|
|
||||||
|
## Key Differences from Zen Code
|
||||||
|
|
||||||
|
| Zen Code | Zen Marketing |
|
||||||
|
|----------|---------------|
|
||||||
|
| Code review, debugging, testing | Content variation, SEO, style guide |
|
||||||
|
| Software architecture analysis | Campaign planning, voice analysis |
|
||||||
|
| Development workflows | Marketing workflows |
|
||||||
|
| Technical accuracy for code | Technical accuracy for content |
|
||||||
|
| GitHub/git integration | WordPress/platform integration |
|
||||||
|
|
||||||
|
## Tools (Planned)
|
||||||
|
|
||||||
|
### Simple Tools (Fast, Single-Shot)
|
||||||
|
- `contentvariant` - Generate 5-20 variations for testing
|
||||||
|
- `platformadapt` - Cross-platform content adaptation
|
||||||
|
- `subjectlines` - Email subject line generation
|
||||||
|
- `factcheck` - Technical fact verification
|
||||||
|
|
||||||
|
### Workflow Tools (Multi-Step)
|
||||||
|
- `styleguide` - Writing style enforcement
|
||||||
|
- `seooptimize` - WordPress SEO optimization
|
||||||
|
- `guestedit` - Guest content editing
|
||||||
|
- `linkstrategy` - Internal linking strategy
|
||||||
|
- `voiceanalysis` - Voice extraction and validation
|
||||||
|
- `campaignmap` - Campaign planning
|
||||||
|
|
||||||
|
### Kept from Zen
|
||||||
|
- `chat` - General brainstorming
|
||||||
|
- `thinkdeep` - Deep strategic thinking
|
||||||
|
- `planner` - Project planning
|
||||||
|
|
||||||
|
## Models
|
||||||
|
|
||||||
|
**Primary Models:**
|
||||||
|
- **Gemini 2.5 Pro** - Analytical and strategic work
|
||||||
|
- **Gemini Flash** - Fast bulk generation
|
||||||
|
- **Minimax M2** - Creative content generation
|
||||||
|
|
||||||
|
**Fallback:**
|
||||||
|
- OpenRouter (if configured)
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
> **Note:** Not yet functional - see [PLAN.md](PLAN.md) for implementation roadmap
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone and setup
|
||||||
|
git clone [repo-url]
|
||||||
|
cd zen-marketing
|
||||||
|
./run-server.sh
|
||||||
|
|
||||||
|
# Configure Claude Desktop (~/.claude.json)
|
||||||
|
{
|
||||||
|
"mcpServers": {
|
||||||
|
"zen-marketing": {
|
||||||
|
"command": "/home/ben/mcp/zen-marketing/.venv/bin/python",
|
||||||
|
"args": ["/home/ben/mcp/zen-marketing/server.py"],
|
||||||
|
"env": {
|
||||||
|
"GEMINI_API_KEY": "your-key",
|
||||||
|
"OPENROUTER_API_KEY": "your-key",
|
||||||
|
"DEFAULT_MODEL": "gemini-2.5-pro"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Restart Claude Desktop
|
||||||
|
```
|
||||||
|
|
||||||
|
## Example Usage
|
||||||
|
|
||||||
|
```
|
||||||
|
# Generate subject line variations
|
||||||
|
"Use zen-marketing contentvariant to generate 15 subject lines for an HVAC newsletter about PCB diagnostics. Test angles: technical curiosity, contrarian, knowledge gap, urgency."
|
||||||
|
|
||||||
|
# Adapt content across platforms
|
||||||
|
"Use zen-marketing platformadapt to take this blog post intro and create versions for LinkedIn (1300 chars), Twitter (280), Instagram (2200), and newsletter."
|
||||||
|
|
||||||
|
# Enforce writing style
|
||||||
|
"Use zen-marketing styleguide to check this draft for em-dashes, 'This isn't X, it's Y' constructions, and ensure direct affirmative statements."
|
||||||
|
|
||||||
|
# SEO optimize
|
||||||
|
"Use zen-marketing seooptimize to create SEO title under 60 chars, excerpt under 156 chars, and suggest WordPress tags for this article about HVAC voltage regulation."
|
||||||
|
```
|
||||||
|
|
||||||
|
## Development Status
|
||||||
|
|
||||||
|
### Phase 1: Foundation (Current)
|
||||||
|
- [x] Create directory structure
|
||||||
|
- [x] Write implementation plan
|
||||||
|
- [x] Create development guide
|
||||||
|
- [ ] Copy core architecture from zen-mcp-server
|
||||||
|
- [ ] Configure minimax provider
|
||||||
|
- [ ] Test basic functionality
|
||||||
|
|
||||||
|
### Phase 2: Simple Tools (Next)
|
||||||
|
- [ ] Implement `contentvariant`
|
||||||
|
- [ ] Implement `subjectlines`
|
||||||
|
- [ ] Implement `platformadapt`
|
||||||
|
- [ ] Implement `factcheck`
|
||||||
|
|
||||||
|
See [PLAN.md](PLAN.md) for complete roadmap.
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
- **[PLAN.md](PLAN.md)** - Detailed implementation plan with tool designs
|
||||||
|
- **[CLAUDE.md](CLAUDE.md)** - Development guide for contributors
|
||||||
|
- **[Project Memories](PLAN.md#project-overview)** - Real-world usage examples
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
Based on Zen MCP Server's proven architecture:
|
||||||
|
- **Conversation continuity** via `continuation_id`
|
||||||
|
- **Multi-model orchestration** (Gemini, Minimax, OpenRouter)
|
||||||
|
- **Simple tools** for fast iteration
|
||||||
|
- **Workflow tools** for multi-step processes
|
||||||
|
- **Web search integration** for current information
|
||||||
|
- **File handling** for content and brand assets
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
Apache 2.0 License (inherited from Zen MCP Server)
|
||||||
|
|
||||||
|
## Acknowledgments
|
||||||
|
|
||||||
|
Built on the foundation of:
|
||||||
|
- [Zen MCP Server](https://github.com/BeehiveInnovations/zen-mcp-server) by Fahad Gilani
|
||||||
|
- [Model Context Protocol](https://modelcontextprotocol.com) by Anthropic
|
||||||
|
- [Claude Desktop](https://claude.ai/download) - AI interface
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Status:** Planning phase complete, ready for implementation
|
||||||
|
**Next:** Start new Claude session in `~/mcp/zen-marketing/` to begin Phase 1
|
||||||
107
config.py
Normal file
107
config.py
Normal file
|
|
@ -0,0 +1,107 @@
|
||||||
|
"""
|
||||||
|
Configuration and constants for Zen-Marketing MCP Server
|
||||||
|
|
||||||
|
This module centralizes all configuration settings for the Zen-Marketing MCP Server.
|
||||||
|
It defines model configurations, token limits, temperature defaults, and other
|
||||||
|
constants used throughout the application.
|
||||||
|
|
||||||
|
Configuration values can be overridden by environment variables where appropriate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Version and metadata
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
__updated__ = "2025-11-07"
|
||||||
|
__author__ = "Ben (based on Zen MCP Server by Fahad Gilani)"
|
||||||
|
|
||||||
|
# Model configuration
|
||||||
|
# DEFAULT_MODEL: The default model used for all AI operations
|
||||||
|
# Can be overridden by setting DEFAULT_MODEL environment variable
|
||||||
|
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "google/gemini-2.5-pro-latest")
|
||||||
|
|
||||||
|
# Fast model for quick operations (variations, subject lines)
|
||||||
|
FAST_MODEL = os.getenv("FAST_MODEL", "google/gemini-2.5-flash-preview-09-2025")
|
||||||
|
|
||||||
|
# Creative model for content generation
|
||||||
|
CREATIVE_MODEL = os.getenv("CREATIVE_MODEL", "minimax/minimax-m2")
|
||||||
|
|
||||||
|
# Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model
|
||||||
|
IS_AUTO_MODE = DEFAULT_MODEL.lower() == "auto"
|
||||||
|
|
||||||
|
# Temperature defaults for different content types
|
||||||
|
# Temperature controls the randomness/creativity of model responses
|
||||||
|
# Lower values (0.0-0.3) produce more deterministic, focused responses
|
||||||
|
# Higher values (0.7-1.0) produce more creative, varied responses
|
||||||
|
|
||||||
|
# TEMPERATURE_PRECISION: Used for fact-checking and technical verification
|
||||||
|
TEMPERATURE_PRECISION = 0.2 # For factcheck, technical verification
|
||||||
|
|
||||||
|
# TEMPERATURE_ANALYTICAL: Used for style enforcement and SEO optimization
|
||||||
|
TEMPERATURE_ANALYTICAL = 0.3 # For styleguide, seooptimize, voiceanalysis
|
||||||
|
|
||||||
|
# TEMPERATURE_BALANCED: Used for strategic planning
|
||||||
|
TEMPERATURE_BALANCED = 0.5 # For guestedit, linkstrategy, campaignmap
|
||||||
|
|
||||||
|
# TEMPERATURE_CREATIVE: Used for content variation and adaptation
|
||||||
|
TEMPERATURE_CREATIVE = 0.7 # For platformadapt
|
||||||
|
|
||||||
|
# TEMPERATURE_HIGHLY_CREATIVE: Used for bulk variation generation
|
||||||
|
TEMPERATURE_HIGHLY_CREATIVE = 0.8 # For contentvariant, subjectlines
|
||||||
|
|
||||||
|
# Thinking Mode Defaults
|
||||||
|
DEFAULT_THINKING_MODE_THINKDEEP = os.getenv("DEFAULT_THINKING_MODE_THINKDEEP", "high")
|
||||||
|
|
||||||
|
# MCP Protocol Transport Limits
|
||||||
|
def _calculate_mcp_prompt_limit() -> int:
|
||||||
|
"""
|
||||||
|
Calculate MCP prompt size limit based on MAX_MCP_OUTPUT_TOKENS environment variable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Maximum character count for user input prompts
|
||||||
|
"""
|
||||||
|
max_tokens_str = os.getenv("MAX_MCP_OUTPUT_TOKENS")
|
||||||
|
|
||||||
|
if max_tokens_str:
|
||||||
|
try:
|
||||||
|
max_tokens = int(max_tokens_str)
|
||||||
|
# Allocate 60% of tokens for input, convert to characters (~4 chars per token)
|
||||||
|
input_token_budget = int(max_tokens * 0.6)
|
||||||
|
character_limit = input_token_budget * 4
|
||||||
|
return character_limit
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Default fallback: 60,000 characters
|
||||||
|
return 60_000
|
||||||
|
|
||||||
|
|
||||||
|
MCP_PROMPT_SIZE_LIMIT = _calculate_mcp_prompt_limit()
|
||||||
|
|
||||||
|
# Language/Locale Configuration
|
||||||
|
LOCALE = os.getenv("LOCALE", "")
|
||||||
|
|
||||||
|
# Platform character limits
|
||||||
|
PLATFORM_LIMITS = {
|
||||||
|
"twitter": 280,
|
||||||
|
"bluesky": 300,
|
||||||
|
"linkedin": 3000,
|
||||||
|
"linkedin_optimal": 1300,
|
||||||
|
"instagram": 2200,
|
||||||
|
"facebook": 500, # Optimal length
|
||||||
|
"email_subject": 60,
|
||||||
|
"email_preview": 100,
|
||||||
|
"meta_description": 156,
|
||||||
|
"page_title": 60,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Web search configuration
|
||||||
|
ENABLE_WEB_SEARCH = os.getenv("ENABLE_WEB_SEARCH", "true").lower() == "true"
|
||||||
|
|
||||||
|
# Tool disabling
|
||||||
|
# Comma-separated list of tools to disable
|
||||||
|
DISABLED_TOOLS_STR = os.getenv("DISABLED_TOOLS", "")
|
||||||
|
DISABLED_TOOLS = set(tool.strip() for tool in DISABLED_TOOLS_STR.split(",") if tool.strip())
|
||||||
|
|
||||||
|
# Logging configuration
|
||||||
|
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
|
||||||
20
providers/__init__.py
Normal file
20
providers/__init__.py
Normal file
|
|
@ -0,0 +1,20 @@
|
||||||
|
"""Model provider abstractions for supporting multiple AI providers."""
|
||||||
|
|
||||||
|
from .base import ModelProvider
|
||||||
|
from .gemini import GeminiModelProvider
|
||||||
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .openai_provider import OpenAIModelProvider
|
||||||
|
from .openrouter import OpenRouterProvider
|
||||||
|
from .registry import ModelProviderRegistry
|
||||||
|
from .shared import ModelCapabilities, ModelResponse
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ModelProvider",
|
||||||
|
"ModelResponse",
|
||||||
|
"ModelCapabilities",
|
||||||
|
"ModelProviderRegistry",
|
||||||
|
"GeminiModelProvider",
|
||||||
|
"OpenAIModelProvider",
|
||||||
|
"OpenAICompatibleProvider",
|
||||||
|
"OpenRouterProvider",
|
||||||
|
]
|
||||||
268
providers/base.py
Normal file
268
providers/base.py
Normal file
|
|
@ -0,0 +1,268 @@
|
||||||
|
"""Base interfaces and common behaviour for model providers."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProvider(ABC):
|
||||||
|
"""Abstract base class for all model backends in the MCP server.
|
||||||
|
|
||||||
|
Role
|
||||||
|
Defines the interface every provider must implement so the registry,
|
||||||
|
restriction service, and tools have a uniform surface for listing
|
||||||
|
models, resolving aliases, and executing requests.
|
||||||
|
|
||||||
|
Responsibilities
|
||||||
|
* expose static capability metadata for each supported model via
|
||||||
|
:class:`ModelCapabilities`
|
||||||
|
* accept user prompts, forward them to the underlying SDK, and wrap
|
||||||
|
responses in :class:`ModelResponse`
|
||||||
|
* report tokenizer counts for budgeting and validation logic
|
||||||
|
* advertise provider identity (``ProviderType``) so restriction
|
||||||
|
policies can map environment configuration onto providers
|
||||||
|
* validate whether a model name or alias is recognised by the provider
|
||||||
|
|
||||||
|
Shared helpers like temperature validation, alias resolution, and
|
||||||
|
restriction-aware ``list_models`` live here so concrete subclasses only
|
||||||
|
need to supply their catalogue and wire up SDK-specific behaviour.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# All concrete providers must define their supported models
|
||||||
|
MODEL_CAPABILITIES: dict[str, Any] = {}
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, **kwargs):
|
||||||
|
"""Initialize the provider with API key and optional configuration."""
|
||||||
|
self.api_key = api_key
|
||||||
|
self.config = kwargs
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Provider identity & capability surface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
@abstractmethod
|
||||||
|
def get_provider_type(self) -> ProviderType:
|
||||||
|
"""Return the concrete provider identity."""
|
||||||
|
|
||||||
|
def get_capabilities(self, model_name: str) -> ModelCapabilities:
|
||||||
|
"""Resolve capability metadata for a model name.
|
||||||
|
|
||||||
|
This centralises the alias resolution → lookup → restriction check
|
||||||
|
pipeline so providers only override the pieces they genuinely need to
|
||||||
|
customise. Subclasses usually only override ``_lookup_capabilities`` to
|
||||||
|
integrate a registry or dynamic source, or ``_finalise_capabilities`` to
|
||||||
|
tweak the returned object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
capabilities = self._lookup_capabilities(resolved_name, model_name)
|
||||||
|
|
||||||
|
if capabilities is None:
|
||||||
|
self._raise_unsupported_model(model_name)
|
||||||
|
|
||||||
|
self._ensure_model_allowed(capabilities, resolved_name, model_name)
|
||||||
|
return self._finalise_capabilities(capabilities, resolved_name, model_name)
|
||||||
|
|
||||||
|
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||||
|
"""Return statically declared capabilities when available."""
|
||||||
|
|
||||||
|
model_map = getattr(self, "MODEL_CAPABILITIES", None)
|
||||||
|
if isinstance(model_map, dict) and model_map:
|
||||||
|
return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def list_models(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
respect_restrictions: bool = True,
|
||||||
|
include_aliases: bool = True,
|
||||||
|
lowercase: bool = False,
|
||||||
|
unique: bool = False,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Return formatted model names supported by this provider."""
|
||||||
|
|
||||||
|
model_configs = self.get_all_model_capabilities()
|
||||||
|
if not model_configs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
restriction_service = None
|
||||||
|
if respect_restrictions:
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service()
|
||||||
|
|
||||||
|
if restriction_service:
|
||||||
|
allowed_configs = {}
|
||||||
|
for model_name, config in model_configs.items():
|
||||||
|
if restriction_service.is_allowed(self.get_provider_type(), model_name):
|
||||||
|
allowed_configs[model_name] = config
|
||||||
|
model_configs = allowed_configs
|
||||||
|
|
||||||
|
if not model_configs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return ModelCapabilities.collect_model_names(
|
||||||
|
model_configs,
|
||||||
|
include_aliases=include_aliases,
|
||||||
|
lowercase=lowercase,
|
||||||
|
unique=unique,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Request execution
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
@abstractmethod
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.3,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using the model."""
|
||||||
|
|
||||||
|
def count_tokens(self, text: str, model_name: str) -> int:
|
||||||
|
"""Estimate token usage for a piece of text."""
|
||||||
|
|
||||||
|
resolved_model = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
estimated = max(1, len(text) // 4)
|
||||||
|
logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model)
|
||||||
|
return estimated
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Clean up any resources held by the provider."""
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Validation hooks
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def validate_model_name(self, model_name: str) -> bool:
|
||||||
|
"""Return ``True`` when the model resolves to an allowed capability."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.get_capabilities(model_name)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
|
||||||
|
"""Validate model parameters against capabilities."""
|
||||||
|
|
||||||
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
|
||||||
|
if not capabilities.temperature_constraint.validate(temperature):
|
||||||
|
constraint_desc = capabilities.temperature_constraint.get_description()
|
||||||
|
raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. {constraint_desc}")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Preference / registry hooks
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
|
||||||
|
"""Get the preferred model from this provider for a given category."""
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_model_registry(self) -> Optional[dict[str, Any]]:
|
||||||
|
"""Return the model registry backing this provider, if any."""
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Capability lookup pipeline
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def _lookup_capabilities(
|
||||||
|
self,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: Optional[str] = None,
|
||||||
|
) -> Optional[ModelCapabilities]:
|
||||||
|
"""Return ``ModelCapabilities`` for the canonical model name."""
|
||||||
|
|
||||||
|
return self.get_all_model_capabilities().get(canonical_name)
|
||||||
|
|
||||||
|
def _ensure_model_allowed(
|
||||||
|
self,
|
||||||
|
capabilities: ModelCapabilities,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``ValueError`` if the model violates restriction policy."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
except Exception: # pragma: no cover - only triggered if service import breaks
|
||||||
|
return
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service()
|
||||||
|
if not restriction_service:
|
||||||
|
return
|
||||||
|
|
||||||
|
if restriction_service.is_allowed(self.get_provider_type(), canonical_name, requested_name):
|
||||||
|
return
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.get_provider_type().value} model '{canonical_name}' is not allowed by restriction policy."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _finalise_capabilities(
|
||||||
|
self,
|
||||||
|
capabilities: ModelCapabilities,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: str,
|
||||||
|
) -> ModelCapabilities:
|
||||||
|
"""Allow subclasses to adjust capability metadata before returning."""
|
||||||
|
|
||||||
|
return capabilities
|
||||||
|
|
||||||
|
def _raise_unsupported_model(self, model_name: str) -> None:
|
||||||
|
"""Raise the canonical unsupported-model error."""
|
||||||
|
|
||||||
|
raise ValueError(f"Unsupported model '{model_name}' for provider {self.get_provider_type().value}.")
|
||||||
|
|
||||||
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
|
"""Resolve model shorthand to full name.
|
||||||
|
|
||||||
|
This implementation uses the hook methods to support different
|
||||||
|
model configuration sources.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Model name that may be an alias
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resolved model name
|
||||||
|
"""
|
||||||
|
# Get model configurations from the hook method
|
||||||
|
model_configs = self.get_all_model_capabilities()
|
||||||
|
|
||||||
|
# First check if it's already a base model name (case-sensitive exact match)
|
||||||
|
if model_name in model_configs:
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
# Check case-insensitively for both base models and aliases
|
||||||
|
model_name_lower = model_name.lower()
|
||||||
|
|
||||||
|
# Check base model names case-insensitively
|
||||||
|
for base_model in model_configs:
|
||||||
|
if base_model.lower() == model_name_lower:
|
||||||
|
return base_model
|
||||||
|
|
||||||
|
# Check aliases from the model configurations
|
||||||
|
alias_map = ModelCapabilities.collect_aliases(model_configs)
|
||||||
|
for base_model, aliases in alias_map.items():
|
||||||
|
if any(alias.lower() == model_name_lower for alias in aliases):
|
||||||
|
return base_model
|
||||||
|
|
||||||
|
# If not found, return as-is
|
||||||
|
return model_name
|
||||||
196
providers/custom.py
Normal file
196
providers/custom.py
Normal file
|
|
@ -0,0 +1,196 @@
|
||||||
|
"""Custom API provider implementation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .openrouter_registry import OpenRouterModelRegistry
|
||||||
|
from .shared import ModelCapabilities, ModelResponse, ProviderType
|
||||||
|
|
||||||
|
|
||||||
|
class CustomProvider(OpenAICompatibleProvider):
|
||||||
|
"""Adapter for self-hosted or local OpenAI-compatible endpoints.
|
||||||
|
|
||||||
|
Role
|
||||||
|
Provide a uniform bridge between the MCP server and user-managed
|
||||||
|
OpenAI-compatible services (Ollama, vLLM, LM Studio, bespoke gateways).
|
||||||
|
By subclassing :class:`OpenAICompatibleProvider` it inherits request and
|
||||||
|
token handling, while the custom registry exposes locally defined model
|
||||||
|
metadata.
|
||||||
|
|
||||||
|
Notable behaviour
|
||||||
|
* Uses :class:`OpenRouterModelRegistry` to load model definitions and
|
||||||
|
aliases so custom deployments share the same metadata pipeline as
|
||||||
|
OpenRouter itself.
|
||||||
|
* Normalises version-tagged model names (``model:latest``) and applies
|
||||||
|
restriction policies just like cloud providers, ensuring consistent
|
||||||
|
behaviour across environments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
FRIENDLY_NAME = "Custom API"
|
||||||
|
|
||||||
|
# Model registry for managing configurations and aliases (shared with OpenRouter)
|
||||||
|
_registry: Optional[OpenRouterModelRegistry] = None
|
||||||
|
|
||||||
|
def __init__(self, api_key: str = "", base_url: str = "", **kwargs):
|
||||||
|
"""Initialize Custom provider for local/self-hosted models.
|
||||||
|
|
||||||
|
This provider supports any OpenAI-compatible API endpoint including:
|
||||||
|
- Ollama (typically no API key required)
|
||||||
|
- vLLM (may require API key)
|
||||||
|
- LM Studio (may require API key)
|
||||||
|
- Text Generation WebUI (may require API key)
|
||||||
|
- Enterprise/self-hosted APIs (typically require API key)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API key for the custom endpoint. Can be empty string for
|
||||||
|
providers that don't require authentication (like Ollama).
|
||||||
|
Falls back to CUSTOM_API_KEY environment variable if not provided.
|
||||||
|
base_url: Base URL for the custom API endpoint (e.g., 'http://localhost:11434/v1').
|
||||||
|
Falls back to CUSTOM_API_URL environment variable if not provided.
|
||||||
|
**kwargs: Additional configuration passed to parent OpenAI-compatible provider
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no base_url is provided via parameter or environment variable
|
||||||
|
"""
|
||||||
|
# Fall back to environment variables only if not provided
|
||||||
|
if not base_url:
|
||||||
|
base_url = os.getenv("CUSTOM_API_URL", "")
|
||||||
|
if not api_key:
|
||||||
|
api_key = os.getenv("CUSTOM_API_KEY", "")
|
||||||
|
|
||||||
|
if not base_url:
|
||||||
|
raise ValueError(
|
||||||
|
"Custom API URL must be provided via base_url parameter or CUSTOM_API_URL environment variable"
|
||||||
|
)
|
||||||
|
|
||||||
|
# For Ollama and other providers that don't require authentication,
|
||||||
|
# set a dummy API key to avoid OpenAI client header issues
|
||||||
|
if not api_key:
|
||||||
|
api_key = "dummy-key-for-unauthenticated-endpoint"
|
||||||
|
logging.debug("Using dummy API key for unauthenticated custom endpoint")
|
||||||
|
|
||||||
|
logging.info(f"Initializing Custom provider with endpoint: {base_url}")
|
||||||
|
|
||||||
|
super().__init__(api_key, base_url=base_url, **kwargs)
|
||||||
|
|
||||||
|
# Initialize model registry (shared with OpenRouter for consistent aliases)
|
||||||
|
if CustomProvider._registry is None:
|
||||||
|
CustomProvider._registry = OpenRouterModelRegistry()
|
||||||
|
# Log loaded models and aliases only on first load
|
||||||
|
models = self._registry.list_models()
|
||||||
|
aliases = self._registry.list_aliases()
|
||||||
|
logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Capability surface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
def _lookup_capabilities(
|
||||||
|
self,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: Optional[str] = None,
|
||||||
|
) -> Optional[ModelCapabilities]:
|
||||||
|
"""Return capabilities for models explicitly marked as custom."""
|
||||||
|
|
||||||
|
builtin = super()._lookup_capabilities(canonical_name, requested_name)
|
||||||
|
if builtin is not None:
|
||||||
|
return builtin
|
||||||
|
|
||||||
|
registry_entry = self._registry.resolve(canonical_name)
|
||||||
|
if registry_entry and getattr(registry_entry, "is_custom", False):
|
||||||
|
registry_entry.provider = ProviderType.CUSTOM
|
||||||
|
return registry_entry
|
||||||
|
|
||||||
|
logging.debug(
|
||||||
|
"Custom provider cannot resolve model '%s'; ensure it is declared with 'is_custom': true in custom_models.json",
|
||||||
|
canonical_name,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_provider_type(self) -> ProviderType:
|
||||||
|
"""Identify this provider for restriction and logging logic."""
|
||||||
|
|
||||||
|
return ProviderType.CUSTOM
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Validation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Request execution
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.3,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using the custom API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt to send to the model
|
||||||
|
model_name: Name of the model to use
|
||||||
|
system_prompt: Optional system prompt for model behavior
|
||||||
|
temperature: Sampling temperature
|
||||||
|
max_output_tokens: Maximum tokens to generate
|
||||||
|
**kwargs: Additional provider-specific parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelResponse with generated content and metadata
|
||||||
|
"""
|
||||||
|
# Resolve model alias to actual model name
|
||||||
|
resolved_model = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
|
# Call parent method with resolved model name
|
||||||
|
return super().generate_content(
|
||||||
|
prompt=prompt,
|
||||||
|
model_name=resolved_model,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
max_output_tokens=max_output_tokens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Registry helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
|
"""Resolve registry aliases and strip version tags for local models."""
|
||||||
|
|
||||||
|
config = self._registry.resolve(model_name)
|
||||||
|
if config:
|
||||||
|
if config.model_name != model_name:
|
||||||
|
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||||
|
return config.model_name
|
||||||
|
|
||||||
|
if ":" in model_name:
|
||||||
|
base_model = model_name.split(":")[0]
|
||||||
|
logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'")
|
||||||
|
|
||||||
|
base_config = self._registry.resolve(base_model)
|
||||||
|
if base_config:
|
||||||
|
logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'")
|
||||||
|
return base_config.model_name
|
||||||
|
return base_model
|
||||||
|
|
||||||
|
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||||
|
"""Expose registry capabilities for models marked as custom."""
|
||||||
|
|
||||||
|
if not self._registry:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
capabilities: dict[str, ModelCapabilities] = {}
|
||||||
|
for model_name in self._registry.list_models():
|
||||||
|
config = self._registry.resolve(model_name)
|
||||||
|
if config and getattr(config, "is_custom", False):
|
||||||
|
capabilities[model_name] = config
|
||||||
|
return capabilities
|
||||||
473
providers/dial.py
Normal file
473
providers/dial.py
Normal file
|
|
@ -0,0 +1,473 @@
|
||||||
|
"""DIAL (Data & AI Layer) model provider implementation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DIALModelProvider(OpenAICompatibleProvider):
|
||||||
|
"""Client for the DIAL (Data & AI Layer) aggregation service.
|
||||||
|
|
||||||
|
DIAL exposes several third-party models behind a single OpenAI-compatible
|
||||||
|
endpoint. This provider wraps the service, publishes capability metadata
|
||||||
|
for the known deployments, and centralises retry/backoff settings tailored
|
||||||
|
to DIAL's latency characteristics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
FRIENDLY_NAME = "DIAL"
|
||||||
|
|
||||||
|
# Retry configuration for API calls
|
||||||
|
MAX_RETRIES = 4
|
||||||
|
RETRY_DELAYS = [1, 3, 5, 8] # seconds
|
||||||
|
|
||||||
|
# Model configurations using ModelCapabilities objects
|
||||||
|
MODEL_CAPABILITIES = {
|
||||||
|
"o3-2025-04-16": ModelCapabilities(
|
||||||
|
provider=ProviderType.DIAL,
|
||||||
|
model_name="o3-2025-04-16",
|
||||||
|
friendly_name="DIAL (O3)",
|
||||||
|
context_window=200_000,
|
||||||
|
max_output_tokens=100_000,
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=False, # DIAL may not expose function calling
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True,
|
||||||
|
max_image_size_mb=20.0,
|
||||||
|
supports_temperature=False, # O3 models don't accept temperature
|
||||||
|
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||||
|
description="OpenAI O3 via DIAL - Strong reasoning model",
|
||||||
|
aliases=["o3"],
|
||||||
|
),
|
||||||
|
"o4-mini-2025-04-16": ModelCapabilities(
|
||||||
|
provider=ProviderType.DIAL,
|
||||||
|
model_name="o4-mini-2025-04-16",
|
||||||
|
friendly_name="DIAL (O4-mini)",
|
||||||
|
context_window=200_000,
|
||||||
|
max_output_tokens=100_000,
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=False, # DIAL may not expose function calling
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True,
|
||||||
|
max_image_size_mb=20.0,
|
||||||
|
supports_temperature=False, # O4 models don't accept temperature
|
||||||
|
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||||
|
description="OpenAI O4-mini via DIAL - Fast reasoning model",
|
||||||
|
aliases=["o4-mini"],
|
||||||
|
),
|
||||||
|
"anthropic.claude-sonnet-4.1-20250805-v1:0": ModelCapabilities(
|
||||||
|
provider=ProviderType.DIAL,
|
||||||
|
model_name="anthropic.claude-sonnet-4.1-20250805-v1:0",
|
||||||
|
friendly_name="DIAL (Sonnet 4.1)",
|
||||||
|
context_window=200_000,
|
||||||
|
max_output_tokens=64_000,
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=False, # Claude doesn't have function calling
|
||||||
|
supports_json_mode=False, # Claude doesn't have JSON mode
|
||||||
|
supports_images=True,
|
||||||
|
max_image_size_mb=5.0,
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
description="Claude Sonnet 4.1 via DIAL - Balanced performance",
|
||||||
|
aliases=["sonnet-4.1", "sonnet-4"],
|
||||||
|
),
|
||||||
|
"anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking": ModelCapabilities(
|
||||||
|
provider=ProviderType.DIAL,
|
||||||
|
model_name="anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking",
|
||||||
|
friendly_name="DIAL (Sonnet 4.1 Thinking)",
|
||||||
|
context_window=200_000,
|
||||||
|
max_output_tokens=64_000,
|
||||||
|
supports_extended_thinking=True, # Thinking mode variant
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=False, # Claude doesn't have function calling
|
||||||
|
supports_json_mode=False, # Claude doesn't have JSON mode
|
||||||
|
supports_images=True,
|
||||||
|
max_image_size_mb=5.0,
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
description="Claude Sonnet 4.1 with thinking mode via DIAL",
|
||||||
|
aliases=["sonnet-4.1-thinking", "sonnet-4-thinking"],
|
||||||
|
),
|
||||||
|
"anthropic.claude-opus-4.1-20250805-v1:0": ModelCapabilities(
|
||||||
|
provider=ProviderType.DIAL,
|
||||||
|
model_name="anthropic.claude-opus-4.1-20250805-v1:0",
|
||||||
|
friendly_name="DIAL (Opus 4.1)",
|
||||||
|
context_window=200_000,
|
||||||
|
max_output_tokens=64_000,
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=False, # Claude doesn't have function calling
|
||||||
|
supports_json_mode=False, # Claude doesn't have JSON mode
|
||||||
|
supports_images=True,
|
||||||
|
max_image_size_mb=5.0,
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
description="Claude Opus 4.1 via DIAL - Most capable Claude model",
|
||||||
|
aliases=["opus-4.1", "opus-4"],
|
||||||
|
),
|
||||||
|
"anthropic.claude-opus-4.1-20250805-v1:0-with-thinking": ModelCapabilities(
|
||||||
|
provider=ProviderType.DIAL,
|
||||||
|
model_name="anthropic.claude-opus-4.1-20250805-v1:0-with-thinking",
|
||||||
|
friendly_name="DIAL (Opus 4.1 Thinking)",
|
||||||
|
context_window=200_000,
|
||||||
|
max_output_tokens=64_000,
|
||||||
|
supports_extended_thinking=True, # Thinking mode variant
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=False, # Claude doesn't have function calling
|
||||||
|
supports_json_mode=False, # Claude doesn't have JSON mode
|
||||||
|
supports_images=True,
|
||||||
|
max_image_size_mb=5.0,
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
description="Claude Opus 4.1 with thinking mode via DIAL",
|
||||||
|
aliases=["opus-4.1-thinking", "opus-4-thinking"],
|
||||||
|
),
|
||||||
|
"gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities(
|
||||||
|
provider=ProviderType.DIAL,
|
||||||
|
model_name="gemini-2.5-pro-preview-03-25-google-search",
|
||||||
|
friendly_name="DIAL (Gemini 2.5 Pro Search)",
|
||||||
|
context_window=1_000_000,
|
||||||
|
max_output_tokens=65_536,
|
||||||
|
supports_extended_thinking=False, # DIAL doesn't expose thinking mode
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=False, # DIAL may not expose function calling
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True,
|
||||||
|
max_image_size_mb=20.0,
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
description="Gemini 2.5 Pro with Google Search via DIAL",
|
||||||
|
aliases=["gemini-2.5-pro-search"],
|
||||||
|
),
|
||||||
|
"gemini-2.5-pro-preview-05-06": ModelCapabilities(
|
||||||
|
provider=ProviderType.DIAL,
|
||||||
|
model_name="gemini-2.5-pro-preview-05-06",
|
||||||
|
friendly_name="DIAL (Gemini 2.5 Pro)",
|
||||||
|
context_window=1_000_000,
|
||||||
|
max_output_tokens=65_536,
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=False, # DIAL may not expose function calling
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True,
|
||||||
|
max_image_size_mb=20.0,
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
description="Gemini 2.5 Pro via DIAL - Deep reasoning",
|
||||||
|
aliases=["gemini-2.5-pro"],
|
||||||
|
),
|
||||||
|
"gemini-2.5-flash-preview-05-20": ModelCapabilities(
|
||||||
|
provider=ProviderType.DIAL,
|
||||||
|
model_name="gemini-2.5-flash-preview-05-20",
|
||||||
|
friendly_name="DIAL (Gemini Flash 2.5)",
|
||||||
|
context_window=1_000_000,
|
||||||
|
max_output_tokens=65_536,
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=False, # DIAL may not expose function calling
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True,
|
||||||
|
max_image_size_mb=20.0,
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
description="Gemini 2.5 Flash via DIAL - Ultra-fast",
|
||||||
|
aliases=["gemini-2.5-flash"],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, **kwargs):
|
||||||
|
"""Initialize DIAL provider with API key and host.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: DIAL API key for authentication
|
||||||
|
**kwargs: Additional configuration options
|
||||||
|
"""
|
||||||
|
# Get DIAL API host from environment or kwargs
|
||||||
|
dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai"
|
||||||
|
|
||||||
|
# DIAL uses /openai endpoint for OpenAI-compatible API
|
||||||
|
if not dial_host.endswith("/openai"):
|
||||||
|
dial_host = f"{dial_host.rstrip('/')}/openai"
|
||||||
|
|
||||||
|
kwargs["base_url"] = dial_host
|
||||||
|
|
||||||
|
# Get API version from environment or use default
|
||||||
|
self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview")
|
||||||
|
|
||||||
|
# Add DIAL-specific headers
|
||||||
|
# DIAL uses Api-Key header instead of Authorization: Bearer
|
||||||
|
# Reference: https://dialx.ai/dial_api#section/Authorization
|
||||||
|
self.DEFAULT_HEADERS = {
|
||||||
|
"Api-Key": api_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Store the actual API key for use in Api-Key header
|
||||||
|
self._dial_api_key = api_key
|
||||||
|
|
||||||
|
# Pass a placeholder API key to OpenAI client - we'll override the auth header in httpx
|
||||||
|
# The actual authentication happens via the Api-Key header in the httpx client
|
||||||
|
super().__init__("placeholder-not-used", **kwargs)
|
||||||
|
|
||||||
|
# Cache for deployment-specific clients to avoid recreating them on each request
|
||||||
|
self._deployment_clients = {}
|
||||||
|
# Lock to ensure thread-safe client creation
|
||||||
|
self._client_lock = threading.Lock()
|
||||||
|
|
||||||
|
# Create a SINGLE shared httpx client for the provider instance
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
# Create custom event hooks to remove Authorization header
|
||||||
|
def remove_auth_header(request):
|
||||||
|
"""Remove Authorization header that OpenAI client adds."""
|
||||||
|
# httpx headers are case-insensitive, so we need to check all variations
|
||||||
|
headers_to_remove = []
|
||||||
|
for header_name in request.headers:
|
||||||
|
if header_name.lower() == "authorization":
|
||||||
|
headers_to_remove.append(header_name)
|
||||||
|
|
||||||
|
for header_name in headers_to_remove:
|
||||||
|
del request.headers[header_name]
|
||||||
|
|
||||||
|
self._http_client = httpx.Client(
|
||||||
|
timeout=self.timeout_config,
|
||||||
|
verify=True,
|
||||||
|
follow_redirects=True,
|
||||||
|
headers=self.DEFAULT_HEADERS.copy(), # Include DIAL headers including Api-Key
|
||||||
|
limits=httpx.Limits(
|
||||||
|
max_keepalive_connections=5,
|
||||||
|
max_connections=10,
|
||||||
|
keepalive_expiry=30.0,
|
||||||
|
),
|
||||||
|
event_hooks={"request": [remove_auth_header]},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}")
|
||||||
|
|
||||||
|
def get_provider_type(self) -> ProviderType:
|
||||||
|
"""Get the provider type."""
|
||||||
|
return ProviderType.DIAL
|
||||||
|
|
||||||
|
def _get_deployment_client(self, deployment: str):
|
||||||
|
"""Get or create a cached client for a specific deployment.
|
||||||
|
|
||||||
|
This avoids recreating OpenAI clients on every request, improving performance.
|
||||||
|
Reuses the shared HTTP client for connection pooling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
deployment: The deployment/model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OpenAI client configured for the specific deployment
|
||||||
|
"""
|
||||||
|
# Check if client already exists without locking for performance
|
||||||
|
if deployment in self._deployment_clients:
|
||||||
|
return self._deployment_clients[deployment]
|
||||||
|
|
||||||
|
# Use lock to ensure thread-safe client creation
|
||||||
|
with self._client_lock:
|
||||||
|
# Double-check pattern: check again inside the lock
|
||||||
|
if deployment not in self._deployment_clients:
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# Build deployment-specific URL
|
||||||
|
base_url = str(self.client.base_url)
|
||||||
|
if base_url.endswith("/"):
|
||||||
|
base_url = base_url[:-1]
|
||||||
|
|
||||||
|
# Remove /openai suffix if present to reconstruct properly
|
||||||
|
if base_url.endswith("/openai"):
|
||||||
|
base_url = base_url[:-7]
|
||||||
|
|
||||||
|
deployment_url = f"{base_url}/openai/deployments/{deployment}"
|
||||||
|
|
||||||
|
# Create and cache the client, REUSING the shared http_client
|
||||||
|
# Use placeholder API key - Authorization header will be removed by http_client event hook
|
||||||
|
self._deployment_clients[deployment] = OpenAI(
|
||||||
|
api_key="placeholder-not-used",
|
||||||
|
base_url=deployment_url,
|
||||||
|
http_client=self._http_client, # Pass the shared client with Api-Key header
|
||||||
|
default_query={"api-version": self.api_version}, # Add api-version as query param
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._deployment_clients[deployment]
|
||||||
|
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.3,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
images: Optional[list[str]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using DIAL's deployment-specific endpoint.
|
||||||
|
|
||||||
|
DIAL uses Azure OpenAI-style deployment endpoints:
|
||||||
|
/openai/deployments/{deployment}/chat/completions
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt
|
||||||
|
model_name: Model name or alias
|
||||||
|
system_prompt: Optional system prompt
|
||||||
|
temperature: Sampling temperature
|
||||||
|
max_output_tokens: Maximum tokens to generate
|
||||||
|
**kwargs: Additional provider-specific parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelResponse with generated content and metadata
|
||||||
|
"""
|
||||||
|
# Validate model name against allow-list
|
||||||
|
if not self.validate_model_name(model_name):
|
||||||
|
raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")
|
||||||
|
|
||||||
|
# Validate parameters and fetch capabilities
|
||||||
|
self.validate_parameters(model_name, temperature)
|
||||||
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
|
||||||
|
# Prepare messages
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
# Build user message content
|
||||||
|
user_message_content = []
|
||||||
|
if prompt:
|
||||||
|
user_message_content.append({"type": "text", "text": prompt})
|
||||||
|
|
||||||
|
if images and capabilities.supports_images:
|
||||||
|
for img_path in images:
|
||||||
|
processed_image = self._process_image(img_path)
|
||||||
|
if processed_image:
|
||||||
|
user_message_content.append(processed_image)
|
||||||
|
elif images:
|
||||||
|
logger.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")
|
||||||
|
|
||||||
|
# Add user message. If only text, content will be a string, otherwise a list.
|
||||||
|
if len(user_message_content) == 1 and user_message_content[0]["type"] == "text":
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
else:
|
||||||
|
messages.append({"role": "user", "content": user_message_content})
|
||||||
|
|
||||||
|
# Resolve model name
|
||||||
|
resolved_model = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
|
# Build completion parameters
|
||||||
|
completion_params = {
|
||||||
|
"model": resolved_model,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Determine temperature support from capabilities
|
||||||
|
supports_temperature = capabilities.supports_temperature
|
||||||
|
|
||||||
|
# Add temperature parameter if supported
|
||||||
|
if supports_temperature:
|
||||||
|
completion_params["temperature"] = temperature
|
||||||
|
|
||||||
|
# Add max tokens if specified and model supports it
|
||||||
|
if max_output_tokens and supports_temperature:
|
||||||
|
completion_params["max_tokens"] = max_output_tokens
|
||||||
|
|
||||||
|
# Add additional parameters
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
|
||||||
|
if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]:
|
||||||
|
continue
|
||||||
|
completion_params[key] = value
|
||||||
|
|
||||||
|
# DIAL-specific: Get cached client for deployment endpoint
|
||||||
|
deployment_client = self._get_deployment_client(resolved_model)
|
||||||
|
|
||||||
|
# Retry logic with progressive delays
|
||||||
|
last_exception = None
|
||||||
|
|
||||||
|
for attempt in range(self.MAX_RETRIES):
|
||||||
|
try:
|
||||||
|
# Generate completion using deployment-specific client
|
||||||
|
response = deployment_client.chat.completions.create(**completion_params)
|
||||||
|
|
||||||
|
# Extract content and usage
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
usage = self._extract_usage(response)
|
||||||
|
|
||||||
|
return ModelResponse(
|
||||||
|
content=content,
|
||||||
|
usage=usage,
|
||||||
|
model_name=model_name,
|
||||||
|
friendly_name=self.FRIENDLY_NAME,
|
||||||
|
provider=self.get_provider_type(),
|
||||||
|
metadata={
|
||||||
|
"finish_reason": response.choices[0].finish_reason,
|
||||||
|
"model": response.model,
|
||||||
|
"id": response.id,
|
||||||
|
"created": response.created,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
# Check if this is a retryable error
|
||||||
|
is_retryable = self._is_error_retryable(e)
|
||||||
|
|
||||||
|
if not is_retryable:
|
||||||
|
# Non-retryable error, raise immediately
|
||||||
|
raise ValueError(f"DIAL API error for model {model_name}: {str(e)}")
|
||||||
|
|
||||||
|
# If this isn't the last attempt and error is retryable, wait and retry
|
||||||
|
if attempt < self.MAX_RETRIES - 1:
|
||||||
|
delay = self.RETRY_DELAYS[attempt]
|
||||||
|
logger.info(
|
||||||
|
f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}"
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# All retries exhausted
|
||||||
|
raise ValueError(
|
||||||
|
f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Clean up HTTP clients when provider is closed."""
|
||||||
|
logger.info("Closing DIAL provider HTTP clients...")
|
||||||
|
|
||||||
|
# Clear the deployment clients cache
|
||||||
|
# Note: We don't need to close individual OpenAI clients since they
|
||||||
|
# use the shared httpx.Client which we close separately
|
||||||
|
self._deployment_clients.clear()
|
||||||
|
|
||||||
|
# Close the shared HTTP client
|
||||||
|
if hasattr(self, "_http_client"):
|
||||||
|
try:
|
||||||
|
self._http_client.close()
|
||||||
|
logger.debug("Closed shared HTTP client")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing shared HTTP client: {e}")
|
||||||
|
|
||||||
|
# Also close the client created by the superclass (OpenAICompatibleProvider)
|
||||||
|
# as it holds its own httpx.Client instance that is not used by DIAL's generate_content
|
||||||
|
if hasattr(self, "client") and self.client and hasattr(self.client, "close"):
|
||||||
|
try:
|
||||||
|
self.client.close()
|
||||||
|
logger.debug("Closed superclass's OpenAI client")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error closing superclass's OpenAI client: {e}")
|
||||||
578
providers/gemini.py
Normal file
578
providers/gemini.py
Normal file
|
|
@ -0,0 +1,578 @@
|
||||||
|
"""Gemini model provider implementation."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types
|
||||||
|
|
||||||
|
from utils.image_utils import validate_image
|
||||||
|
|
||||||
|
from .base import ModelProvider
|
||||||
|
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiModelProvider(ModelProvider):
|
||||||
|
"""First-party Gemini integration built on the official Google SDK.
|
||||||
|
|
||||||
|
The provider advertises detailed thinking-mode budgets, handles optional
|
||||||
|
custom endpoints, and performs image pre-processing before forwarding a
|
||||||
|
request to the Gemini APIs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Model configurations using ModelCapabilities objects
|
||||||
|
MODEL_CAPABILITIES = {
|
||||||
|
"gemini-2.5-pro": ModelCapabilities(
|
||||||
|
provider=ProviderType.GOOGLE,
|
||||||
|
model_name="gemini-2.5-pro",
|
||||||
|
friendly_name="Gemini (Pro 2.5)",
|
||||||
|
context_window=1_048_576, # 1M tokens
|
||||||
|
max_output_tokens=65_536,
|
||||||
|
supports_extended_thinking=True,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True, # Vision capability
|
||||||
|
max_image_size_mb=32.0, # Higher limit for Pro model
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
max_thinking_tokens=32768, # Max thinking tokens for Pro model
|
||||||
|
description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
|
||||||
|
aliases=["pro", "gemini pro", "gemini-pro"],
|
||||||
|
),
|
||||||
|
"gemini-2.0-flash": ModelCapabilities(
|
||||||
|
provider=ProviderType.GOOGLE,
|
||||||
|
model_name="gemini-2.0-flash",
|
||||||
|
friendly_name="Gemini (Flash 2.0)",
|
||||||
|
context_window=1_048_576, # 1M tokens
|
||||||
|
max_output_tokens=65_536,
|
||||||
|
supports_extended_thinking=True, # Experimental thinking mode
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True, # Vision capability
|
||||||
|
max_image_size_mb=20.0, # Conservative 20MB limit for reliability
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
max_thinking_tokens=24576, # Same as 2.5 flash for consistency
|
||||||
|
description="Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
|
||||||
|
aliases=["flash-2.0", "flash2"],
|
||||||
|
),
|
||||||
|
"gemini-2.0-flash-lite": ModelCapabilities(
|
||||||
|
provider=ProviderType.GOOGLE,
|
||||||
|
model_name="gemini-2.0-flash-lite",
|
||||||
|
friendly_name="Gemin (Flash Lite 2.0)",
|
||||||
|
context_window=1_048_576, # 1M tokens
|
||||||
|
max_output_tokens=65_536,
|
||||||
|
supports_extended_thinking=False, # Not supported per user request
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=False, # Does not support images
|
||||||
|
max_image_size_mb=0.0, # No image support
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
description="Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
|
||||||
|
aliases=["flashlite", "flash-lite"],
|
||||||
|
),
|
||||||
|
"gemini-2.5-flash": ModelCapabilities(
|
||||||
|
provider=ProviderType.GOOGLE,
|
||||||
|
model_name="gemini-2.5-flash",
|
||||||
|
friendly_name="Gemini (Flash 2.5)",
|
||||||
|
context_window=1_048_576, # 1M tokens
|
||||||
|
max_output_tokens=65_536,
|
||||||
|
supports_extended_thinking=True,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True, # Vision capability
|
||||||
|
max_image_size_mb=20.0, # Conservative 20MB limit for reliability
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
max_thinking_tokens=24576, # Flash 2.5 thinking budget limit
|
||||||
|
description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
|
||||||
|
aliases=["flash", "flash2.5"],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Thinking mode configurations - percentages of model's max_thinking_tokens
|
||||||
|
# These percentages work across all models that support thinking
|
||||||
|
THINKING_BUDGETS = {
|
||||||
|
"minimal": 0.005, # 0.5% of max - minimal thinking for fast responses
|
||||||
|
"low": 0.08, # 8% of max - light reasoning tasks
|
||||||
|
"medium": 0.33, # 33% of max - balanced reasoning (default)
|
||||||
|
"high": 0.67, # 67% of max - complex analysis
|
||||||
|
"max": 1.0, # 100% of max - full thinking budget
|
||||||
|
}
|
||||||
|
|
||||||
|
# Model-specific thinking token limits
|
||||||
|
MAX_THINKING_TOKENS = {
|
||||||
|
"gemini-2.0-flash": 24576, # Same as 2.5 flash for consistency
|
||||||
|
"gemini-2.0-flash-lite": 0, # No thinking support
|
||||||
|
"gemini-2.5-flash": 24576, # Flash 2.5 thinking budget limit
|
||||||
|
"gemini-2.5-pro": 32768, # Pro 2.5 thinking budget limit
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, **kwargs):
|
||||||
|
"""Initialize Gemini provider with API key and optional base URL."""
|
||||||
|
super().__init__(api_key, **kwargs)
|
||||||
|
self._client = None
|
||||||
|
self._token_counters = {} # Cache for token counting
|
||||||
|
self._base_url = kwargs.get("base_url", None) # Optional custom endpoint
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Capability surface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Client access
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self):
|
||||||
|
"""Lazy initialization of Gemini client."""
|
||||||
|
if self._client is None:
|
||||||
|
# Check if custom base URL is provided
|
||||||
|
if self._base_url:
|
||||||
|
# Use HttpOptions to set custom endpoint
|
||||||
|
http_options = types.HttpOptions(baseUrl=self._base_url)
|
||||||
|
logger.debug(f"Initializing Gemini client with custom endpoint: {self._base_url}")
|
||||||
|
self._client = genai.Client(api_key=self.api_key, http_options=http_options)
|
||||||
|
else:
|
||||||
|
# Use default Google endpoint
|
||||||
|
self._client = genai.Client(api_key=self.api_key)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Request execution
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.3,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
thinking_mode: str = "medium",
|
||||||
|
images: Optional[list[str]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using Gemini model."""
|
||||||
|
# Validate parameters and fetch capabilities
|
||||||
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
self.validate_parameters(model_name, temperature)
|
||||||
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
|
||||||
|
# Prepare content parts (text and potentially images)
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
# Add system and user prompts as text
|
||||||
|
if system_prompt:
|
||||||
|
full_prompt = f"{system_prompt}\n\n{prompt}"
|
||||||
|
else:
|
||||||
|
full_prompt = prompt
|
||||||
|
|
||||||
|
parts.append({"text": full_prompt})
|
||||||
|
|
||||||
|
# Add images if provided and model supports vision
|
||||||
|
if images and capabilities.supports_images:
|
||||||
|
for image_path in images:
|
||||||
|
try:
|
||||||
|
image_part = self._process_image(image_path)
|
||||||
|
if image_part:
|
||||||
|
parts.append(image_part)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to process image {image_path}: {e}")
|
||||||
|
# Continue with other images and text
|
||||||
|
continue
|
||||||
|
elif images and not capabilities.supports_images:
|
||||||
|
logger.warning(f"Model {resolved_name} does not support images, ignoring {len(images)} image(s)")
|
||||||
|
|
||||||
|
# Create contents structure
|
||||||
|
contents = [{"parts": parts}]
|
||||||
|
|
||||||
|
# Prepare generation config
|
||||||
|
generation_config = types.GenerateContentConfig(
|
||||||
|
temperature=temperature,
|
||||||
|
candidate_count=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add max output tokens if specified
|
||||||
|
if max_output_tokens:
|
||||||
|
generation_config.max_output_tokens = max_output_tokens
|
||||||
|
|
||||||
|
# Add thinking configuration for models that support it
|
||||||
|
if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
|
||||||
|
# Get model's max thinking tokens and calculate actual budget
|
||||||
|
model_config = self.MODEL_CAPABILITIES.get(resolved_name)
|
||||||
|
if model_config and model_config.max_thinking_tokens > 0:
|
||||||
|
max_thinking_tokens = model_config.max_thinking_tokens
|
||||||
|
actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
|
||||||
|
generation_config.thinking_config = types.ThinkingConfig(thinking_budget=actual_thinking_budget)
|
||||||
|
|
||||||
|
# Retry logic with progressive delays
|
||||||
|
max_retries = 4 # Total of 4 attempts
|
||||||
|
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
|
||||||
|
|
||||||
|
last_exception = None
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
# Generate content
|
||||||
|
response = self.client.models.generate_content(
|
||||||
|
model=resolved_name,
|
||||||
|
contents=contents,
|
||||||
|
config=generation_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract usage information if available
|
||||||
|
usage = self._extract_usage(response)
|
||||||
|
|
||||||
|
# Intelligently determine finish reason and safety blocks
|
||||||
|
finish_reason_str = "UNKNOWN"
|
||||||
|
is_blocked_by_safety = False
|
||||||
|
safety_feedback_details = None
|
||||||
|
|
||||||
|
if response.candidates:
|
||||||
|
candidate = response.candidates[0]
|
||||||
|
|
||||||
|
# Safely get finish reason
|
||||||
|
try:
|
||||||
|
finish_reason_enum = candidate.finish_reason
|
||||||
|
if finish_reason_enum:
|
||||||
|
# Handle both enum objects and string values
|
||||||
|
try:
|
||||||
|
finish_reason_str = finish_reason_enum.name
|
||||||
|
except AttributeError:
|
||||||
|
finish_reason_str = str(finish_reason_enum)
|
||||||
|
else:
|
||||||
|
finish_reason_str = "STOP"
|
||||||
|
except AttributeError:
|
||||||
|
finish_reason_str = "STOP"
|
||||||
|
|
||||||
|
# If content is empty, check safety ratings for the definitive cause
|
||||||
|
if not response.text:
|
||||||
|
try:
|
||||||
|
safety_ratings = candidate.safety_ratings
|
||||||
|
if safety_ratings: # Check it's not None or empty
|
||||||
|
for rating in safety_ratings:
|
||||||
|
try:
|
||||||
|
if rating.blocked:
|
||||||
|
is_blocked_by_safety = True
|
||||||
|
# Provide details for logging/debugging
|
||||||
|
category_name = "UNKNOWN"
|
||||||
|
probability_name = "UNKNOWN"
|
||||||
|
|
||||||
|
try:
|
||||||
|
category_name = rating.category.name
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
probability_name = rating.probability.name
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
safety_feedback_details = (
|
||||||
|
f"Category: {category_name}, Probability: {probability_name}"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
# Individual rating doesn't have expected attributes
|
||||||
|
continue
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
# candidate doesn't have safety_ratings or it's not iterable
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Also check for prompt-level blocking (request rejected entirely)
|
||||||
|
elif response.candidates is not None and len(response.candidates) == 0:
|
||||||
|
# No candidates is the primary indicator of a prompt-level block
|
||||||
|
is_blocked_by_safety = True
|
||||||
|
finish_reason_str = "SAFETY"
|
||||||
|
safety_feedback_details = "Prompt blocked, reason unavailable" # Default message
|
||||||
|
|
||||||
|
try:
|
||||||
|
prompt_feedback = response.prompt_feedback
|
||||||
|
if prompt_feedback and prompt_feedback.block_reason:
|
||||||
|
try:
|
||||||
|
block_reason_name = prompt_feedback.block_reason.name
|
||||||
|
except AttributeError:
|
||||||
|
block_reason_name = str(prompt_feedback.block_reason)
|
||||||
|
safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
# prompt_feedback doesn't exist or has unexpected attributes; stick with the default message
|
||||||
|
pass
|
||||||
|
|
||||||
|
return ModelResponse(
|
||||||
|
content=response.text,
|
||||||
|
usage=usage,
|
||||||
|
model_name=resolved_name,
|
||||||
|
friendly_name="Gemini",
|
||||||
|
provider=ProviderType.GOOGLE,
|
||||||
|
metadata={
|
||||||
|
"thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
|
||||||
|
"finish_reason": finish_reason_str,
|
||||||
|
"is_blocked_by_safety": is_blocked_by_safety,
|
||||||
|
"safety_feedback": safety_feedback_details,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
# Check if this is a retryable error using structured error codes
|
||||||
|
is_retryable = self._is_error_retryable(e)
|
||||||
|
|
||||||
|
# If this is the last attempt or not retryable, give up
|
||||||
|
if attempt == max_retries - 1 or not is_retryable:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Get progressive delay
|
||||||
|
delay = retry_delays[attempt]
|
||||||
|
|
||||||
|
# Log retry attempt
|
||||||
|
logger.warning(
|
||||||
|
f"Gemini API error for model {resolved_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
# If we get here, all retries failed
|
||||||
|
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
|
||||||
|
error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
|
||||||
|
raise RuntimeError(error_msg) from last_exception
|
||||||
|
|
||||||
|
def get_provider_type(self) -> ProviderType:
|
||||||
|
"""Get the provider type."""
|
||||||
|
return ProviderType.GOOGLE
|
||||||
|
|
||||||
|
def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
|
||||||
|
"""Get actual thinking token budget for a model and thinking mode."""
|
||||||
|
resolved_name = self._resolve_model_name(model_name)
|
||||||
|
model_config = self.MODEL_CAPABILITIES.get(resolved_name)
|
||||||
|
|
||||||
|
if not model_config or not model_config.supports_extended_thinking:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if thinking_mode not in self.THINKING_BUDGETS:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
max_thinking_tokens = model_config.max_thinking_tokens
|
||||||
|
if max_thinking_tokens == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
|
||||||
|
|
||||||
|
def _extract_usage(self, response) -> dict[str, int]:
|
||||||
|
"""Extract token usage from Gemini response."""
|
||||||
|
usage = {}
|
||||||
|
|
||||||
|
# Try to extract usage metadata from response
|
||||||
|
# Note: The actual structure depends on the SDK version and response format
|
||||||
|
try:
|
||||||
|
metadata = response.usage_metadata
|
||||||
|
if metadata:
|
||||||
|
# Extract token counts with explicit None checks
|
||||||
|
input_tokens = None
|
||||||
|
output_tokens = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
value = metadata.prompt_token_count
|
||||||
|
if value is not None:
|
||||||
|
input_tokens = value
|
||||||
|
usage["input_tokens"] = value
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
value = metadata.candidates_token_count
|
||||||
|
if value is not None:
|
||||||
|
output_tokens = value
|
||||||
|
usage["output_tokens"] = value
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Calculate total only if both values are available and valid
|
||||||
|
if input_tokens is not None and output_tokens is not None:
|
||||||
|
usage["total_tokens"] = input_tokens + output_tokens
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
# response doesn't have usage_metadata
|
||||||
|
pass
|
||||||
|
|
||||||
|
return usage
|
||||||
|
|
||||||
|
def _is_error_retryable(self, error: Exception) -> bool:
|
||||||
|
"""Determine if an error should be retried based on structured error codes.
|
||||||
|
|
||||||
|
Uses Gemini API error structure instead of text pattern matching for reliability.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: Exception from Gemini API call
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if error should be retried, False otherwise
|
||||||
|
"""
|
||||||
|
error_str = str(error).lower()
|
||||||
|
|
||||||
|
# Check for 429 errors first - these need special handling
|
||||||
|
if "429" in error_str or "quota" in error_str or "resource_exhausted" in error_str:
|
||||||
|
# For Gemini, check for specific non-retryable error indicators
|
||||||
|
# These typically indicate permanent failures or quota/size limits
|
||||||
|
non_retryable_indicators = [
|
||||||
|
"quota exceeded",
|
||||||
|
"resource exhausted",
|
||||||
|
"context length",
|
||||||
|
"token limit",
|
||||||
|
"request too large",
|
||||||
|
"invalid request",
|
||||||
|
"quota_exceeded",
|
||||||
|
"resource_exhausted",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Also check if this is a structured error from Gemini SDK
|
||||||
|
try:
|
||||||
|
# Try to access error details if available
|
||||||
|
error_details = None
|
||||||
|
try:
|
||||||
|
error_details = error.details
|
||||||
|
except AttributeError:
|
||||||
|
try:
|
||||||
|
error_details = error.reason
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if error_details:
|
||||||
|
error_details_str = str(error_details).lower()
|
||||||
|
# Check for non-retryable error codes/reasons
|
||||||
|
if any(indicator in error_details_str for indicator in non_retryable_indicators):
|
||||||
|
logger.debug(f"Non-retryable Gemini error: {error_details}")
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Check main error string for non-retryable patterns
|
||||||
|
if any(indicator in error_str for indicator in non_retryable_indicators):
|
||||||
|
logger.debug(f"Non-retryable Gemini error based on message: {error_str[:200]}...")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# If it's a 429/quota error but doesn't match non-retryable patterns, it might be retryable rate limiting
|
||||||
|
logger.debug(f"Retryable Gemini rate limiting error: {error_str[:100]}...")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# For non-429 errors, check if they're retryable
|
||||||
|
retryable_indicators = [
|
||||||
|
"timeout",
|
||||||
|
"connection",
|
||||||
|
"network",
|
||||||
|
"temporary",
|
||||||
|
"unavailable",
|
||||||
|
"retry",
|
||||||
|
"internal error",
|
||||||
|
"408", # Request timeout
|
||||||
|
"500", # Internal server error
|
||||||
|
"502", # Bad gateway
|
||||||
|
"503", # Service unavailable
|
||||||
|
"504", # Gateway timeout
|
||||||
|
"ssl", # SSL errors
|
||||||
|
"handshake", # Handshake failures
|
||||||
|
]
|
||||||
|
|
||||||
|
return any(indicator in error_str for indicator in retryable_indicators)
|
||||||
|
|
||||||
|
def _process_image(self, image_path: str) -> Optional[dict]:
|
||||||
|
"""Process an image for Gemini API."""
|
||||||
|
try:
|
||||||
|
# Use base class validation
|
||||||
|
image_bytes, mime_type = validate_image(image_path)
|
||||||
|
|
||||||
|
# For data URLs, extract the base64 data directly
|
||||||
|
if image_path.startswith("data:"):
|
||||||
|
# Extract base64 data from data URL
|
||||||
|
_, data = image_path.split(",", 1)
|
||||||
|
return {"inline_data": {"mime_type": mime_type, "data": data}}
|
||||||
|
else:
|
||||||
|
# For file paths, encode the bytes
|
||||||
|
image_data = base64.b64encode(image_bytes).decode()
|
||||||
|
return {"inline_data": {"mime_type": mime_type, "data": image_data}}
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(str(e))
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing image {image_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
|
||||||
|
"""Get Gemini's preferred model for a given category from allowed models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: The tool category requiring a model
|
||||||
|
allowed_models: Pre-filtered list of models allowed by restrictions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preferred model name or None
|
||||||
|
"""
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
if not allowed_models:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Helper to find best model from candidates
|
||||||
|
def find_best(candidates: list[str]) -> Optional[str]:
|
||||||
|
"""Return best model from candidates (sorted for consistency)."""
|
||||||
|
return sorted(candidates, reverse=True)[0] if candidates else None
|
||||||
|
|
||||||
|
if category == ToolModelCategory.EXTENDED_REASONING:
|
||||||
|
# For extended reasoning, prefer models with thinking support
|
||||||
|
# First try Pro models that support thinking
|
||||||
|
pro_thinking = [
|
||||||
|
m
|
||||||
|
for m in allowed_models
|
||||||
|
if "pro" in m and m in self.MODEL_CAPABILITIES and self.MODEL_CAPABILITIES[m].supports_extended_thinking
|
||||||
|
]
|
||||||
|
if pro_thinking:
|
||||||
|
return find_best(pro_thinking)
|
||||||
|
|
||||||
|
# Then any model that supports thinking
|
||||||
|
any_thinking = [
|
||||||
|
m
|
||||||
|
for m in allowed_models
|
||||||
|
if m in self.MODEL_CAPABILITIES and self.MODEL_CAPABILITIES[m].supports_extended_thinking
|
||||||
|
]
|
||||||
|
if any_thinking:
|
||||||
|
return find_best(any_thinking)
|
||||||
|
|
||||||
|
# Finally, just prefer Pro models even without thinking
|
||||||
|
pro_models = [m for m in allowed_models if "pro" in m]
|
||||||
|
if pro_models:
|
||||||
|
return find_best(pro_models)
|
||||||
|
|
||||||
|
elif category == ToolModelCategory.FAST_RESPONSE:
|
||||||
|
# Prefer Flash models for speed
|
||||||
|
flash_models = [m for m in allowed_models if "flash" in m]
|
||||||
|
if flash_models:
|
||||||
|
return find_best(flash_models)
|
||||||
|
|
||||||
|
# Default for BALANCED or as fallback
|
||||||
|
# Prefer Flash for balanced use, then Pro, then anything
|
||||||
|
flash_models = [m for m in allowed_models if "flash" in m]
|
||||||
|
if flash_models:
|
||||||
|
return find_best(flash_models)
|
||||||
|
|
||||||
|
pro_models = [m for m in allowed_models if "pro" in m]
|
||||||
|
if pro_models:
|
||||||
|
return find_best(pro_models)
|
||||||
|
|
||||||
|
# Ultimate fallback to best available model
|
||||||
|
return find_best(allowed_models)
|
||||||
826
providers/openai_compatible.py
Normal file
826
providers/openai_compatible.py
Normal file
|
|
@ -0,0 +1,826 @@
|
||||||
|
"""Base class for OpenAI-compatible API providers."""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import ipaddress
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from utils.image_utils import validate_image
|
||||||
|
|
||||||
|
from .base import ModelProvider
|
||||||
|
from .shared import (
|
||||||
|
ModelCapabilities,
|
||||||
|
ModelResponse,
|
||||||
|
ProviderType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAICompatibleProvider(ModelProvider):
|
||||||
|
"""Shared implementation for OpenAI API lookalikes.
|
||||||
|
|
||||||
|
The class owns HTTP client configuration (timeouts, proxy hardening,
|
||||||
|
custom headers) and normalises the OpenAI SDK responses into
|
||||||
|
:class:`~providers.shared.ModelResponse`. Concrete subclasses only need to
|
||||||
|
provide capability metadata and any provider-specific request tweaks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_HEADERS = {}
|
||||||
|
FRIENDLY_NAME = "OpenAI Compatible"
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, base_url: str = None, **kwargs):
|
||||||
|
"""Initialize the provider with API key and optional base URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: API key for authentication
|
||||||
|
base_url: Base URL for the API endpoint
|
||||||
|
**kwargs: Additional configuration options including timeout
|
||||||
|
"""
|
||||||
|
super().__init__(api_key, **kwargs)
|
||||||
|
self._client = None
|
||||||
|
self.base_url = base_url
|
||||||
|
self.organization = kwargs.get("organization")
|
||||||
|
self.allowed_models = self._parse_allowed_models()
|
||||||
|
|
||||||
|
# Configure timeouts - especially important for custom/local endpoints
|
||||||
|
self.timeout_config = self._configure_timeouts(**kwargs)
|
||||||
|
|
||||||
|
# Validate base URL for security
|
||||||
|
if self.base_url:
|
||||||
|
self._validate_base_url()
|
||||||
|
|
||||||
|
# Warn if using external URL without authentication
|
||||||
|
if self.base_url and not self._is_localhost_url() and not api_key:
|
||||||
|
logging.warning(
|
||||||
|
f"Using external URL '{self.base_url}' without API key. "
|
||||||
|
"This may be insecure. Consider setting an API key for authentication."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _ensure_model_allowed(
|
||||||
|
self,
|
||||||
|
capabilities: ModelCapabilities,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Respect provider-specific allowlists before default restriction checks."""
|
||||||
|
|
||||||
|
super()._ensure_model_allowed(capabilities, canonical_name, requested_name)
|
||||||
|
|
||||||
|
if self.allowed_models is not None:
|
||||||
|
requested = requested_name.lower()
|
||||||
|
canonical = canonical_name.lower()
|
||||||
|
|
||||||
|
if requested not in self.allowed_models and canonical not in self.allowed_models:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_allowed_models(self) -> Optional[set[str]]:
|
||||||
|
"""Parse allowed models from environment variable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of allowed model names (lowercase) or None if not configured
|
||||||
|
"""
|
||||||
|
# Get provider-specific allowed models
|
||||||
|
provider_type = self.get_provider_type().value.upper()
|
||||||
|
env_var = f"{provider_type}_ALLOWED_MODELS"
|
||||||
|
models_str = os.getenv(env_var, "")
|
||||||
|
|
||||||
|
if models_str:
|
||||||
|
# Parse and normalize to lowercase for case-insensitive comparison
|
||||||
|
models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
|
||||||
|
if models:
|
||||||
|
logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
|
||||||
|
return models
|
||||||
|
|
||||||
|
# Log info if no allow-list configured for proxy providers
|
||||||
|
if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
|
||||||
|
logging.info(
|
||||||
|
f"Model allow-list not configured for {self.FRIENDLY_NAME} - all models permitted. "
|
||||||
|
f"To restrict access, set {env_var} with comma-separated model names."
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _configure_timeouts(self, **kwargs):
|
||||||
|
"""Configure timeout settings based on provider type and custom settings.
|
||||||
|
|
||||||
|
Custom URLs and local models often need longer timeouts due to:
|
||||||
|
- Network latency on local networks
|
||||||
|
- Extended thinking models taking longer to respond
|
||||||
|
- Local inference being slower than cloud APIs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
httpx.Timeout object with appropriate timeout settings
|
||||||
|
"""
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
# Default timeouts - more generous for custom/local endpoints
|
||||||
|
default_connect = 30.0 # 30 seconds for connection (vs OpenAI's 5s)
|
||||||
|
default_read = 600.0 # 10 minutes for reading (same as OpenAI default)
|
||||||
|
default_write = 600.0 # 10 minutes for writing
|
||||||
|
default_pool = 600.0 # 10 minutes for pool
|
||||||
|
|
||||||
|
# For custom/local URLs, use even longer timeouts
|
||||||
|
if self.base_url and self._is_localhost_url():
|
||||||
|
default_connect = 60.0 # 1 minute for local connections
|
||||||
|
default_read = 1800.0 # 30 minutes for local models (extended thinking)
|
||||||
|
default_write = 1800.0 # 30 minutes for local models
|
||||||
|
default_pool = 1800.0 # 30 minutes for local models
|
||||||
|
logging.info(f"Using extended timeouts for local endpoint: {self.base_url}")
|
||||||
|
elif self.base_url:
|
||||||
|
default_connect = 45.0 # 45 seconds for custom remote endpoints
|
||||||
|
default_read = 900.0 # 15 minutes for custom remote endpoints
|
||||||
|
default_write = 900.0 # 15 minutes for custom remote endpoints
|
||||||
|
default_pool = 900.0 # 15 minutes for custom remote endpoints
|
||||||
|
logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")
|
||||||
|
|
||||||
|
# Allow override via kwargs or environment variables in future, for now...
|
||||||
|
connect_timeout = kwargs.get("connect_timeout", float(os.getenv("CUSTOM_CONNECT_TIMEOUT", default_connect)))
|
||||||
|
read_timeout = kwargs.get("read_timeout", float(os.getenv("CUSTOM_READ_TIMEOUT", default_read)))
|
||||||
|
write_timeout = kwargs.get("write_timeout", float(os.getenv("CUSTOM_WRITE_TIMEOUT", default_write)))
|
||||||
|
pool_timeout = kwargs.get("pool_timeout", float(os.getenv("CUSTOM_POOL_TIMEOUT", default_pool)))
|
||||||
|
|
||||||
|
timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)
|
||||||
|
|
||||||
|
logging.debug(
|
||||||
|
f"Configured timeouts - Connect: {connect_timeout}s, Read: {read_timeout}s, "
|
||||||
|
f"Write: {write_timeout}s, Pool: {pool_timeout}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
return timeout
|
||||||
|
|
||||||
|
def _is_localhost_url(self) -> bool:
|
||||||
|
"""Check if the base URL points to localhost or local network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if URL is localhost or local network, False otherwise
|
||||||
|
"""
|
||||||
|
if not self.base_url:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = urlparse(self.base_url)
|
||||||
|
hostname = parsed.hostname
|
||||||
|
|
||||||
|
# Check for common localhost patterns
|
||||||
|
if hostname in ["localhost", "127.0.0.1", "::1"]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for private network ranges (local network)
|
||||||
|
if hostname:
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(hostname)
|
||||||
|
return ip.is_private or ip.is_loopback
|
||||||
|
except ValueError:
|
||||||
|
# Not an IP address, might be a hostname
|
||||||
|
pass
|
||||||
|
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _validate_base_url(self) -> None:
|
||||||
|
"""Validate base URL for security (SSRF protection).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If URL is invalid or potentially unsafe
|
||||||
|
"""
|
||||||
|
if not self.base_url:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = urlparse(self.base_url)
|
||||||
|
|
||||||
|
# Check URL scheme - only allow http/https
|
||||||
|
if parsed.scheme not in ("http", "https"):
|
||||||
|
raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")
|
||||||
|
|
||||||
|
# Check hostname exists
|
||||||
|
if not parsed.hostname:
|
||||||
|
raise ValueError("URL must include a hostname")
|
||||||
|
|
||||||
|
# Check port is valid (if specified)
|
||||||
|
port = parsed.port
|
||||||
|
if port is not None and (port < 1 or port > 65535):
|
||||||
|
raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, ValueError):
|
||||||
|
raise
|
||||||
|
raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self):
|
||||||
|
"""Lazy initialization of OpenAI client with security checks and timeout configuration."""
|
||||||
|
if self._client is None:
|
||||||
|
import os
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
# Temporarily disable proxy environment variables to prevent httpx from detecting them
|
||||||
|
original_env = {}
|
||||||
|
proxy_env_vars = ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]
|
||||||
|
|
||||||
|
for var in proxy_env_vars:
|
||||||
|
if var in os.environ:
|
||||||
|
original_env[var] = os.environ[var]
|
||||||
|
del os.environ[var]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create a custom httpx client that explicitly avoids proxy parameters
|
||||||
|
timeout_config = (
|
||||||
|
self.timeout_config
|
||||||
|
if hasattr(self, "timeout_config") and self.timeout_config
|
||||||
|
else httpx.Timeout(30.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create httpx client with minimal config to avoid proxy conflicts
|
||||||
|
# Note: proxies parameter was removed in httpx 0.28.0
|
||||||
|
# Check for test transport injection
|
||||||
|
if hasattr(self, "_test_transport"):
|
||||||
|
# Use custom transport for testing (HTTP recording/replay)
|
||||||
|
http_client = httpx.Client(
|
||||||
|
transport=self._test_transport,
|
||||||
|
timeout=timeout_config,
|
||||||
|
follow_redirects=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Normal production client
|
||||||
|
http_client = httpx.Client(
|
||||||
|
timeout=timeout_config,
|
||||||
|
follow_redirects=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keep client initialization minimal to avoid proxy parameter conflicts
|
||||||
|
client_kwargs = {
|
||||||
|
"api_key": self.api_key,
|
||||||
|
"http_client": http_client,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.base_url:
|
||||||
|
client_kwargs["base_url"] = self.base_url
|
||||||
|
|
||||||
|
if self.organization:
|
||||||
|
client_kwargs["organization"] = self.organization
|
||||||
|
|
||||||
|
# Add default headers if any
|
||||||
|
if self.DEFAULT_HEADERS:
|
||||||
|
client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()
|
||||||
|
|
||||||
|
logging.debug(f"OpenAI client initialized with custom httpx client and timeout: {timeout_config}")
|
||||||
|
|
||||||
|
# Create OpenAI client with custom httpx client
|
||||||
|
self._client = OpenAI(**client_kwargs)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# If all else fails, try absolute minimal client without custom httpx
|
||||||
|
logging.warning(f"Failed to create client with custom httpx, falling back to minimal config: {e}")
|
||||||
|
try:
|
||||||
|
minimal_kwargs = {"api_key": self.api_key}
|
||||||
|
if self.base_url:
|
||||||
|
minimal_kwargs["base_url"] = self.base_url
|
||||||
|
self._client = OpenAI(**minimal_kwargs)
|
||||||
|
except Exception as fallback_error:
|
||||||
|
logging.error(f"Even minimal OpenAI client creation failed: {fallback_error}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Restore original proxy environment variables
|
||||||
|
for var, value in original_env.items():
|
||||||
|
os.environ[var] = value
|
||||||
|
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
def _sanitize_for_logging(self, params: dict) -> dict:
|
||||||
|
"""Sanitize sensitive data from parameters before logging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Dictionary of API parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Sanitized copy of parameters safe for logging
|
||||||
|
"""
|
||||||
|
sanitized = copy.deepcopy(params)
|
||||||
|
|
||||||
|
# Sanitize messages content
|
||||||
|
if "input" in sanitized:
|
||||||
|
for msg in sanitized.get("input", []):
|
||||||
|
if isinstance(msg, dict) and "content" in msg:
|
||||||
|
for content_item in msg.get("content", []):
|
||||||
|
if isinstance(content_item, dict) and "text" in content_item:
|
||||||
|
# Truncate long text and add ellipsis
|
||||||
|
text = content_item["text"]
|
||||||
|
if len(text) > 100:
|
||||||
|
content_item["text"] = text[:100] + "... [truncated]"
|
||||||
|
|
||||||
|
# Remove any API keys that might be in headers/auth
|
||||||
|
sanitized.pop("api_key", None)
|
||||||
|
sanitized.pop("authorization", None)
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
def _safe_extract_output_text(self, response) -> str:
|
||||||
|
"""Safely extract output_text from o3-pro response with validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: Response object from OpenAI SDK
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The output text content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If output_text is missing, None, or not a string
|
||||||
|
"""
|
||||||
|
logging.debug(f"Response object type: {type(response)}")
|
||||||
|
logging.debug(f"Response attributes: {dir(response)}")
|
||||||
|
|
||||||
|
if not hasattr(response, "output_text"):
|
||||||
|
raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")
|
||||||
|
|
||||||
|
content = response.output_text
|
||||||
|
logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})")
|
||||||
|
|
||||||
|
if content is None:
|
||||||
|
raise ValueError("o3-pro returned None for output_text")
|
||||||
|
|
||||||
|
if not isinstance(content, str):
|
||||||
|
raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}")
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _generate_with_responses_endpoint(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
messages: list,
|
||||||
|
temperature: float,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using the /v1/responses endpoint for o3-pro via OpenAI library."""
|
||||||
|
# Convert messages to the correct format for responses endpoint
|
||||||
|
input_messages = []
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
role = message.get("role", "")
|
||||||
|
content = message.get("content", "")
|
||||||
|
|
||||||
|
if role == "system":
|
||||||
|
# For o3-pro, system messages should be handled carefully to avoid policy violations
|
||||||
|
# Instead of prefixing with "System:", we'll include the system content naturally
|
||||||
|
input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]})
|
||||||
|
elif role == "user":
|
||||||
|
input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]})
|
||||||
|
elif role == "assistant":
|
||||||
|
input_messages.append({"role": "assistant", "content": [{"type": "output_text", "text": content}]})
|
||||||
|
|
||||||
|
# Prepare completion parameters for responses endpoint
|
||||||
|
# Based on OpenAI documentation, use nested reasoning object for responses endpoint
|
||||||
|
completion_params = {
|
||||||
|
"model": model_name,
|
||||||
|
"input": input_messages,
|
||||||
|
"reasoning": {"effort": "medium"}, # Use nested object for responses endpoint
|
||||||
|
"store": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add max tokens if specified (using max_completion_tokens for responses endpoint)
|
||||||
|
if max_output_tokens:
|
||||||
|
completion_params["max_completion_tokens"] = max_output_tokens
|
||||||
|
|
||||||
|
# For responses endpoint, we only add parameters that are explicitly supported
|
||||||
|
# Remove unsupported chat completion parameters that may cause API errors
|
||||||
|
|
||||||
|
# Retry logic with progressive delays
|
||||||
|
max_retries = 4
|
||||||
|
retry_delays = [1, 3, 5, 8]
|
||||||
|
last_exception = None
|
||||||
|
actual_attempts = 0
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try: # Log sanitized payload for debugging
|
||||||
|
import json
|
||||||
|
|
||||||
|
sanitized_params = self._sanitize_for_logging(completion_params)
|
||||||
|
logging.info(
|
||||||
|
f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use OpenAI client's responses endpoint
|
||||||
|
response = self.client.responses.create(**completion_params)
|
||||||
|
|
||||||
|
# Extract content from responses endpoint format
|
||||||
|
# Use validation helper to safely extract output_text
|
||||||
|
content = self._safe_extract_output_text(response)
|
||||||
|
|
||||||
|
# Try to extract usage information
|
||||||
|
usage = None
|
||||||
|
if hasattr(response, "usage"):
|
||||||
|
usage = self._extract_usage(response)
|
||||||
|
elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
|
||||||
|
# Safely extract token counts with None handling
|
||||||
|
input_tokens = getattr(response, "input_tokens", 0) or 0
|
||||||
|
output_tokens = getattr(response, "output_tokens", 0) or 0
|
||||||
|
usage = {
|
||||||
|
"input_tokens": input_tokens,
|
||||||
|
"output_tokens": output_tokens,
|
||||||
|
"total_tokens": input_tokens + output_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
return ModelResponse(
|
||||||
|
content=content,
|
||||||
|
usage=usage,
|
||||||
|
model_name=model_name,
|
||||||
|
friendly_name=self.FRIENDLY_NAME,
|
||||||
|
provider=self.get_provider_type(),
|
||||||
|
metadata={
|
||||||
|
"model": getattr(response, "model", model_name),
|
||||||
|
"id": getattr(response, "id", ""),
|
||||||
|
"created": getattr(response, "created_at", 0),
|
||||||
|
"endpoint": "responses",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
# Check if this is a retryable error using structured error codes
|
||||||
|
is_retryable = self._is_error_retryable(e)
|
||||||
|
|
||||||
|
if is_retryable and attempt < max_retries - 1:
|
||||||
|
delay = retry_delays[attempt]
|
||||||
|
logging.warning(
|
||||||
|
f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# If we get here, all retries failed
|
||||||
|
error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
|
||||||
|
logging.error(error_msg)
|
||||||
|
raise RuntimeError(error_msg) from last_exception
|
||||||
|
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.3,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
images: Optional[list[str]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using the OpenAI-compatible API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt to send to the model
|
||||||
|
model_name: Name of the model to use
|
||||||
|
system_prompt: Optional system prompt for model behavior
|
||||||
|
temperature: Sampling temperature
|
||||||
|
max_output_tokens: Maximum tokens to generate
|
||||||
|
**kwargs: Additional provider-specific parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelResponse with generated content and metadata
|
||||||
|
"""
|
||||||
|
# Validate model name against allow-list
|
||||||
|
if not self.validate_model_name(model_name):
|
||||||
|
raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")
|
||||||
|
|
||||||
|
capabilities: Optional[ModelCapabilities]
|
||||||
|
try:
|
||||||
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
except Exception as exc:
|
||||||
|
logging.debug(f"Falling back to generic capabilities for {model_name}: {exc}")
|
||||||
|
capabilities = None
|
||||||
|
|
||||||
|
# Get effective temperature for this model from capabilities when available
|
||||||
|
if capabilities:
|
||||||
|
effective_temperature = capabilities.get_effective_temperature(temperature)
|
||||||
|
if effective_temperature is not None and effective_temperature != temperature:
|
||||||
|
logging.debug(
|
||||||
|
f"Adjusting temperature from {temperature} to {effective_temperature} for model {model_name}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
effective_temperature = temperature
|
||||||
|
|
||||||
|
# Only validate if temperature is not None (meaning the model supports it)
|
||||||
|
if effective_temperature is not None:
|
||||||
|
# Validate parameters with the effective temperature
|
||||||
|
self.validate_parameters(model_name, effective_temperature)
|
||||||
|
|
||||||
|
# Prepare messages
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
|
||||||
|
# Prepare user message with text and potentially images
|
||||||
|
user_content = []
|
||||||
|
user_content.append({"type": "text", "text": prompt})
|
||||||
|
|
||||||
|
# Add images if provided and model supports vision
|
||||||
|
if images and capabilities and capabilities.supports_images:
|
||||||
|
for image_path in images:
|
||||||
|
try:
|
||||||
|
image_content = self._process_image(image_path)
|
||||||
|
if image_content:
|
||||||
|
user_content.append(image_content)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to process image {image_path}: {e}")
|
||||||
|
# Continue with other images and text
|
||||||
|
continue
|
||||||
|
elif images and (not capabilities or not capabilities.supports_images):
|
||||||
|
logging.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)")
|
||||||
|
|
||||||
|
# Add user message
|
||||||
|
if len(user_content) == 1:
|
||||||
|
# Only text content, use simple string format for compatibility
|
||||||
|
messages.append({"role": "user", "content": prompt})
|
||||||
|
else:
|
||||||
|
# Text + images, use content array format
|
||||||
|
messages.append({"role": "user", "content": user_content})
|
||||||
|
|
||||||
|
# Prepare completion parameters
|
||||||
|
completion_params = {
|
||||||
|
"model": model_name,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check model capabilities once to determine parameter support
|
||||||
|
resolved_model = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
|
# Use the effective temperature we calculated earlier
|
||||||
|
supports_sampling = effective_temperature is not None
|
||||||
|
|
||||||
|
if supports_sampling:
|
||||||
|
completion_params["temperature"] = effective_temperature
|
||||||
|
|
||||||
|
# Add max tokens if specified and model supports it
|
||||||
|
# O3/O4 models that don't support temperature also don't support max_tokens
|
||||||
|
if max_output_tokens and supports_sampling:
|
||||||
|
completion_params["max_tokens"] = max_output_tokens
|
||||||
|
|
||||||
|
# Add any additional OpenAI-specific parameters
|
||||||
|
# Use capabilities to filter parameters for reasoning models
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
|
||||||
|
# Reasoning models (those that don't support temperature) also don't support these parameters
|
||||||
|
if not supports_sampling and key in ["top_p", "frequency_penalty", "presence_penalty"]:
|
||||||
|
continue # Skip unsupported parameters for reasoning models
|
||||||
|
completion_params[key] = value
|
||||||
|
|
||||||
|
# Check if this is o3-pro and needs the responses endpoint
|
||||||
|
if resolved_model == "o3-pro":
|
||||||
|
# This model requires the /v1/responses endpoint
|
||||||
|
# If it fails, we should not fall back to chat/completions
|
||||||
|
return self._generate_with_responses_endpoint(
|
||||||
|
model_name=resolved_model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=temperature,
|
||||||
|
max_output_tokens=max_output_tokens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retry logic with progressive delays
|
||||||
|
max_retries = 4 # Total of 4 attempts
|
||||||
|
retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s
|
||||||
|
|
||||||
|
last_exception = None
|
||||||
|
actual_attempts = 0
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count
|
||||||
|
try:
|
||||||
|
# Generate completion
|
||||||
|
response = self.client.chat.completions.create(**completion_params)
|
||||||
|
|
||||||
|
# Extract content and usage
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
usage = self._extract_usage(response)
|
||||||
|
|
||||||
|
return ModelResponse(
|
||||||
|
content=content,
|
||||||
|
usage=usage,
|
||||||
|
model_name=model_name,
|
||||||
|
friendly_name=self.FRIENDLY_NAME,
|
||||||
|
provider=self.get_provider_type(),
|
||||||
|
metadata={
|
||||||
|
"finish_reason": response.choices[0].finish_reason,
|
||||||
|
"model": response.model, # Actual model used
|
||||||
|
"id": response.id,
|
||||||
|
"created": response.created,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
# Check if this is a retryable error using structured error codes
|
||||||
|
is_retryable = self._is_error_retryable(e)
|
||||||
|
|
||||||
|
# If this is the last attempt or not retryable, give up
|
||||||
|
if attempt == max_retries - 1 or not is_retryable:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Get progressive delay
|
||||||
|
delay = retry_delays[attempt]
|
||||||
|
|
||||||
|
# Log retry attempt
|
||||||
|
logging.warning(
|
||||||
|
f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..."
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
# If we get here, all retries failed
|
||||||
|
error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
|
||||||
|
logging.error(error_msg)
|
||||||
|
raise RuntimeError(error_msg) from last_exception
|
||||||
|
|
||||||
|
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
|
||||||
|
"""Validate model parameters.
|
||||||
|
|
||||||
|
For proxy providers, this may use generic capabilities.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Model to validate for
|
||||||
|
temperature: Temperature to validate
|
||||||
|
**kwargs: Additional parameters to validate
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
capabilities = self.get_capabilities(model_name)
|
||||||
|
|
||||||
|
# Check if we're using generic capabilities
|
||||||
|
if hasattr(capabilities, "_is_generic"):
|
||||||
|
logging.debug(
|
||||||
|
f"Using generic parameter validation for {model_name}. Actual model constraints may differ."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate temperature using parent class method
|
||||||
|
super().validate_parameters(model_name, temperature, **kwargs)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# For proxy providers, we might not have accurate capabilities
|
||||||
|
# Log warning but don't fail
|
||||||
|
logging.warning(f"Parameter validation limited for {model_name}: {e}")
|
||||||
|
|
||||||
|
def _extract_usage(self, response) -> dict[str, int]:
|
||||||
|
"""Extract token usage from OpenAI response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: OpenAI API response object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with usage statistics
|
||||||
|
"""
|
||||||
|
usage = {}
|
||||||
|
|
||||||
|
if hasattr(response, "usage") and response.usage:
|
||||||
|
# Safely extract token counts with None handling
|
||||||
|
usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) or 0
|
||||||
|
usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) or 0
|
||||||
|
usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) or 0
|
||||||
|
|
||||||
|
return usage
|
||||||
|
|
||||||
|
def count_tokens(self, text: str, model_name: str) -> int:
|
||||||
|
"""Count tokens using OpenAI-compatible tokenizer tables when available."""
|
||||||
|
|
||||||
|
resolved_model = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
try:
|
||||||
|
encoding = tiktoken.encoding_for_model(resolved_model)
|
||||||
|
except KeyError:
|
||||||
|
encoding = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
|
return len(encoding.encode(text))
|
||||||
|
|
||||||
|
except (ImportError, Exception) as exc:
|
||||||
|
logging.debug("tiktoken unavailable for %s: %s", resolved_model, exc)
|
||||||
|
|
||||||
|
return super().count_tokens(text, model_name)
|
||||||
|
|
||||||
|
def _is_error_retryable(self, error: Exception) -> bool:
|
||||||
|
"""Determine if an error should be retried based on structured error codes.
|
||||||
|
|
||||||
|
Uses OpenAI API error structure instead of text pattern matching for reliability.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: Exception from OpenAI API call
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if error should be retried, False otherwise
|
||||||
|
"""
|
||||||
|
error_str = str(error).lower()
|
||||||
|
|
||||||
|
# Check for 429 errors first - these need special handling
|
||||||
|
if "429" in error_str:
|
||||||
|
# Try to extract structured error information
|
||||||
|
error_type = None
|
||||||
|
error_code = None
|
||||||
|
|
||||||
|
# Parse structured error from OpenAI API response
|
||||||
|
# Format: "Error code: 429 - {'error': {'type': 'tokens', 'code': 'rate_limit_exceeded', ...}}"
|
||||||
|
try:
|
||||||
|
import ast
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Extract JSON part from error string using regex
|
||||||
|
# Look for pattern: {...} (from first { to last })
|
||||||
|
json_match = re.search(r"\{.*\}", str(error))
|
||||||
|
if json_match:
|
||||||
|
json_like_str = json_match.group(0)
|
||||||
|
|
||||||
|
# First try: parse as Python literal (handles single quotes safely)
|
||||||
|
try:
|
||||||
|
error_data = ast.literal_eval(json_like_str)
|
||||||
|
except (ValueError, SyntaxError):
|
||||||
|
# Fallback: try JSON parsing with simple quote replacement
|
||||||
|
# (for cases where it's already valid JSON or simple replacements work)
|
||||||
|
json_str = json_like_str.replace("'", '"')
|
||||||
|
error_data = json.loads(json_str)
|
||||||
|
|
||||||
|
if "error" in error_data:
|
||||||
|
error_info = error_data["error"]
|
||||||
|
error_type = error_info.get("type")
|
||||||
|
error_code = error_info.get("code")
|
||||||
|
|
||||||
|
except (json.JSONDecodeError, ValueError, SyntaxError, AttributeError):
|
||||||
|
# Fall back to checking hasattr for OpenAI SDK exception objects
|
||||||
|
if hasattr(error, "response") and hasattr(error.response, "json"):
|
||||||
|
try:
|
||||||
|
response_data = error.response.json()
|
||||||
|
if "error" in response_data:
|
||||||
|
error_info = response_data["error"]
|
||||||
|
error_type = error_info.get("type")
|
||||||
|
error_code = error_info.get("code")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Determine if 429 is retryable based on structured error codes
|
||||||
|
if error_type == "tokens":
|
||||||
|
# Token-related 429s are typically non-retryable (request too large)
|
||||||
|
logging.debug(f"Non-retryable 429: token-related error (type={error_type}, code={error_code})")
|
||||||
|
return False
|
||||||
|
elif error_code in ["invalid_request_error", "context_length_exceeded"]:
|
||||||
|
# These are permanent failures
|
||||||
|
logging.debug(f"Non-retryable 429: permanent failure (type={error_type}, code={error_code})")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# Other 429s (like requests per minute) are retryable
|
||||||
|
logging.debug(f"Retryable 429: rate limiting (type={error_type}, code={error_code})")
|
||||||
|
return True
|
||||||
|
|
||||||
|
# For non-429 errors, check if they're retryable
|
||||||
|
retryable_indicators = [
|
||||||
|
"timeout",
|
||||||
|
"connection",
|
||||||
|
"network",
|
||||||
|
"temporary",
|
||||||
|
"unavailable",
|
||||||
|
"retry",
|
||||||
|
"408", # Request timeout
|
||||||
|
"500", # Internal server error
|
||||||
|
"502", # Bad gateway
|
||||||
|
"503", # Service unavailable
|
||||||
|
"504", # Gateway timeout
|
||||||
|
"ssl", # SSL errors
|
||||||
|
"handshake", # Handshake failures
|
||||||
|
]
|
||||||
|
|
||||||
|
return any(indicator in error_str for indicator in retryable_indicators)
|
||||||
|
|
||||||
|
def _process_image(self, image_path: str) -> Optional[dict]:
|
||||||
|
"""Process an image for OpenAI-compatible API."""
|
||||||
|
try:
|
||||||
|
if image_path.startswith("data:"):
|
||||||
|
# Validate the data URL
|
||||||
|
validate_image(image_path)
|
||||||
|
# Handle data URL: data:image/png;base64,iVBORw0...
|
||||||
|
return {"type": "image_url", "image_url": {"url": image_path}}
|
||||||
|
else:
|
||||||
|
# Use base class validation
|
||||||
|
image_bytes, mime_type = validate_image(image_path)
|
||||||
|
|
||||||
|
# Read and encode the image
|
||||||
|
import base64
|
||||||
|
|
||||||
|
image_data = base64.b64encode(image_bytes).decode()
|
||||||
|
logging.debug(f"Processing image '{image_path}' as MIME type '{mime_type}'")
|
||||||
|
|
||||||
|
# Create data URL for OpenAI API
|
||||||
|
data_url = f"data:{mime_type};base64,{image_data}"
|
||||||
|
|
||||||
|
return {"type": "image_url", "image_url": {"url": data_url}}
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logging.warning(str(e))
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error processing image {image_path}: {e}")
|
||||||
|
return None
|
||||||
296
providers/openai_provider.py
Normal file
296
providers/openai_provider.py
Normal file
|
|
@ -0,0 +1,296 @@
|
||||||
|
"""OpenAI model provider implementation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIModelProvider(OpenAICompatibleProvider):
|
||||||
|
"""Implementation that talks to api.openai.com using rich model metadata.
|
||||||
|
|
||||||
|
In addition to the built-in catalogue, the provider can surface models
|
||||||
|
defined in ``conf/custom_models.json`` (for organisations running their own
|
||||||
|
OpenAI-compatible gateways) while still respecting restriction policies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Model configurations using ModelCapabilities objects
|
||||||
|
MODEL_CAPABILITIES = {
|
||||||
|
"gpt-5": ModelCapabilities(
|
||||||
|
provider=ProviderType.OPENAI,
|
||||||
|
model_name="gpt-5",
|
||||||
|
friendly_name="OpenAI (GPT-5)",
|
||||||
|
context_window=400_000, # 400K tokens
|
||||||
|
max_output_tokens=128_000, # 128K max output tokens
|
||||||
|
supports_extended_thinking=True, # Supports reasoning tokens
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True, # GPT-5 supports vision
|
||||||
|
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||||
|
supports_temperature=True, # Regular models accept temperature parameter
|
||||||
|
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||||
|
description="GPT-5 (400K context, 128K output) - Advanced model with reasoning support",
|
||||||
|
aliases=["gpt5"],
|
||||||
|
),
|
||||||
|
"gpt-5-mini": ModelCapabilities(
|
||||||
|
provider=ProviderType.OPENAI,
|
||||||
|
model_name="gpt-5-mini",
|
||||||
|
friendly_name="OpenAI (GPT-5-mini)",
|
||||||
|
context_window=400_000, # 400K tokens
|
||||||
|
max_output_tokens=128_000, # 128K max output tokens
|
||||||
|
supports_extended_thinking=True, # Supports reasoning tokens
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True, # GPT-5-mini supports vision
|
||||||
|
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||||
|
description="GPT-5-mini (400K context, 128K output) - Efficient variant with reasoning support",
|
||||||
|
aliases=["gpt5-mini", "gpt5mini", "mini"],
|
||||||
|
),
|
||||||
|
"gpt-5-nano": ModelCapabilities(
|
||||||
|
provider=ProviderType.OPENAI,
|
||||||
|
model_name="gpt-5-nano",
|
||||||
|
friendly_name="OpenAI (GPT-5 nano)",
|
||||||
|
context_window=400_000,
|
||||||
|
max_output_tokens=128_000,
|
||||||
|
supports_extended_thinking=True,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True,
|
||||||
|
max_image_size_mb=20.0,
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||||
|
description="GPT-5 nano (400K context) - Fastest, cheapest version of GPT-5 for summarization and classification tasks",
|
||||||
|
aliases=["gpt5nano", "gpt5-nano", "nano"],
|
||||||
|
),
|
||||||
|
"o3": ModelCapabilities(
|
||||||
|
provider=ProviderType.OPENAI,
|
||||||
|
model_name="o3",
|
||||||
|
friendly_name="OpenAI (O3)",
|
||||||
|
context_window=200_000, # 200K tokens
|
||||||
|
max_output_tokens=65536, # 64K max output tokens
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True, # O3 models support vision
|
||||||
|
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||||
|
supports_temperature=False, # O3 models don't accept temperature parameter
|
||||||
|
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||||
|
description="Strong reasoning (200K context) - Logical problems, code generation, systematic analysis",
|
||||||
|
aliases=[],
|
||||||
|
),
|
||||||
|
"o3-mini": ModelCapabilities(
|
||||||
|
provider=ProviderType.OPENAI,
|
||||||
|
model_name="o3-mini",
|
||||||
|
friendly_name="OpenAI (O3-mini)",
|
||||||
|
context_window=200_000, # 200K tokens
|
||||||
|
max_output_tokens=65536, # 64K max output tokens
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True, # O3 models support vision
|
||||||
|
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||||
|
supports_temperature=False, # O3 models don't accept temperature parameter
|
||||||
|
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||||
|
description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity",
|
||||||
|
aliases=["o3mini"],
|
||||||
|
),
|
||||||
|
"o3-pro": ModelCapabilities(
|
||||||
|
provider=ProviderType.OPENAI,
|
||||||
|
model_name="o3-pro",
|
||||||
|
friendly_name="OpenAI (O3-Pro)",
|
||||||
|
context_window=200_000, # 200K tokens
|
||||||
|
max_output_tokens=65536, # 64K max output tokens
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True, # O3 models support vision
|
||||||
|
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||||
|
supports_temperature=False, # O3 models don't accept temperature parameter
|
||||||
|
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||||
|
description="Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.",
|
||||||
|
aliases=["o3pro"],
|
||||||
|
),
|
||||||
|
"o4-mini": ModelCapabilities(
|
||||||
|
provider=ProviderType.OPENAI,
|
||||||
|
model_name="o4-mini",
|
||||||
|
friendly_name="OpenAI (O4-mini)",
|
||||||
|
context_window=200_000, # 200K tokens
|
||||||
|
max_output_tokens=65536, # 64K max output tokens
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True, # O4 models support vision
|
||||||
|
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||||
|
supports_temperature=False, # O4 models don't accept temperature parameter
|
||||||
|
temperature_constraint=TemperatureConstraint.create("fixed"),
|
||||||
|
description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning",
|
||||||
|
aliases=["o4mini"],
|
||||||
|
),
|
||||||
|
"gpt-4.1": ModelCapabilities(
|
||||||
|
provider=ProviderType.OPENAI,
|
||||||
|
model_name="gpt-4.1",
|
||||||
|
friendly_name="OpenAI (GPT 4.1)",
|
||||||
|
context_window=1_000_000, # 1M tokens
|
||||||
|
max_output_tokens=32_768,
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=True,
|
||||||
|
supports_images=True, # GPT-4.1 supports vision
|
||||||
|
max_image_size_mb=20.0, # 20MB per OpenAI docs
|
||||||
|
supports_temperature=True, # Regular models accept temperature parameter
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
description="GPT-4.1 (1M context) - Advanced reasoning model with large context window",
|
||||||
|
aliases=["gpt4.1"],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, **kwargs):
|
||||||
|
"""Initialize OpenAI provider with API key."""
|
||||||
|
# Set default OpenAI base URL, allow override for regions/custom endpoints
|
||||||
|
kwargs.setdefault("base_url", "https://api.openai.com/v1")
|
||||||
|
super().__init__(api_key, **kwargs)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Capability surface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _lookup_capabilities(
|
||||||
|
self,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: Optional[str] = None,
|
||||||
|
) -> Optional[ModelCapabilities]:
|
||||||
|
"""Look up OpenAI capabilities from built-ins or the custom registry."""
|
||||||
|
|
||||||
|
builtin = super()._lookup_capabilities(canonical_name, requested_name)
|
||||||
|
if builtin is not None:
|
||||||
|
return builtin
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .openrouter_registry import OpenRouterModelRegistry
|
||||||
|
|
||||||
|
registry = OpenRouterModelRegistry()
|
||||||
|
config = registry.get_model_config(canonical_name)
|
||||||
|
|
||||||
|
if config and config.provider == ProviderType.OPENAI:
|
||||||
|
return config
|
||||||
|
|
||||||
|
except Exception as exc: # pragma: no cover - registry failures are non-critical
|
||||||
|
logger.debug(f"Could not resolve custom OpenAI model '{canonical_name}': {exc}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _finalise_capabilities(
|
||||||
|
self,
|
||||||
|
capabilities: ModelCapabilities,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: str,
|
||||||
|
) -> ModelCapabilities:
|
||||||
|
"""Ensure registry-sourced models report the correct provider type."""
|
||||||
|
|
||||||
|
if capabilities.provider != ProviderType.OPENAI:
|
||||||
|
capabilities.provider = ProviderType.OPENAI
|
||||||
|
return capabilities
|
||||||
|
|
||||||
|
def _raise_unsupported_model(self, model_name: str) -> None:
|
||||||
|
raise ValueError(f"Unsupported OpenAI model: {model_name}")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Provider identity
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_provider_type(self) -> ProviderType:
|
||||||
|
"""Get the provider type."""
|
||||||
|
return ProviderType.OPENAI
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Request execution
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.3,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using OpenAI API with proper model name resolution."""
|
||||||
|
# Resolve model alias before making API call
|
||||||
|
resolved_model_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
|
# Call parent implementation with resolved model name
|
||||||
|
return super().generate_content(
|
||||||
|
prompt=prompt,
|
||||||
|
model_name=resolved_model_name,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
max_output_tokens=max_output_tokens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Provider preferences
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
|
||||||
|
"""Get OpenAI's preferred model for a given category from allowed models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: The tool category requiring a model
|
||||||
|
allowed_models: Pre-filtered list of models allowed by restrictions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preferred model name or None
|
||||||
|
"""
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
if not allowed_models:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Helper to find first available from preference list
|
||||||
|
def find_first(preferences: list[str]) -> Optional[str]:
|
||||||
|
"""Return first available model from preference list."""
|
||||||
|
for model in preferences:
|
||||||
|
if model in allowed_models:
|
||||||
|
return model
|
||||||
|
return None
|
||||||
|
|
||||||
|
if category == ToolModelCategory.EXTENDED_REASONING:
|
||||||
|
# Prefer models with extended thinking support
|
||||||
|
preferred = find_first(["o3", "o3-pro", "gpt-5"])
|
||||||
|
return preferred if preferred else allowed_models[0]
|
||||||
|
|
||||||
|
elif category == ToolModelCategory.FAST_RESPONSE:
|
||||||
|
# Prefer fast, cost-efficient models
|
||||||
|
preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
|
||||||
|
return preferred if preferred else allowed_models[0]
|
||||||
|
|
||||||
|
else: # BALANCED or default
|
||||||
|
# Prefer balanced performance/cost models
|
||||||
|
preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
|
||||||
|
return preferred if preferred else allowed_models[0]
|
||||||
251
providers/openrouter.py
Normal file
251
providers/openrouter.py
Normal file
|
|
@ -0,0 +1,251 @@
|
||||||
|
"""OpenRouter provider implementation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .openrouter_registry import OpenRouterModelRegistry
|
||||||
|
from .shared import (
|
||||||
|
ModelCapabilities,
|
||||||
|
ModelResponse,
|
||||||
|
ProviderType,
|
||||||
|
RangeTemperatureConstraint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterProvider(OpenAICompatibleProvider):
|
||||||
|
"""Client for OpenRouter's multi-model aggregation service.
|
||||||
|
|
||||||
|
Role
|
||||||
|
Surface OpenRouter’s dynamic catalogue through the same interface as
|
||||||
|
native providers so tools can reference OpenRouter models and aliases
|
||||||
|
without special cases.
|
||||||
|
|
||||||
|
Characteristics
|
||||||
|
* Pulls live model definitions from :class:`OpenRouterModelRegistry`
|
||||||
|
(aliases, provider-specific metadata, capability hints)
|
||||||
|
* Applies alias-aware restriction checks before exposing models to the
|
||||||
|
registry or tooling
|
||||||
|
* Reuses :class:`OpenAICompatibleProvider` infrastructure for request
|
||||||
|
execution so OpenRouter endpoints behave like standard OpenAI-style
|
||||||
|
APIs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
FRIENDLY_NAME = "OpenRouter"
|
||||||
|
|
||||||
|
# Custom headers required by OpenRouter
|
||||||
|
DEFAULT_HEADERS = {
|
||||||
|
"HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"),
|
||||||
|
"X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Model registry for managing configurations and aliases
|
||||||
|
_registry: Optional[OpenRouterModelRegistry] = None
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, **kwargs):
|
||||||
|
"""Initialize OpenRouter provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_key: OpenRouter API key
|
||||||
|
**kwargs: Additional configuration
|
||||||
|
"""
|
||||||
|
base_url = "https://openrouter.ai/api/v1"
|
||||||
|
super().__init__(api_key, base_url=base_url, **kwargs)
|
||||||
|
|
||||||
|
# Initialize model registry
|
||||||
|
if OpenRouterProvider._registry is None:
|
||||||
|
OpenRouterProvider._registry = OpenRouterModelRegistry()
|
||||||
|
# Log loaded models and aliases only on first load
|
||||||
|
models = self._registry.list_models()
|
||||||
|
aliases = self._registry.list_aliases()
|
||||||
|
logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Capability surface
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _lookup_capabilities(
|
||||||
|
self,
|
||||||
|
canonical_name: str,
|
||||||
|
requested_name: Optional[str] = None,
|
||||||
|
) -> Optional[ModelCapabilities]:
|
||||||
|
"""Fetch OpenRouter capabilities from the registry or build a generic fallback."""
|
||||||
|
|
||||||
|
capabilities = self._registry.get_capabilities(canonical_name)
|
||||||
|
if capabilities:
|
||||||
|
return capabilities
|
||||||
|
|
||||||
|
base_identifier = canonical_name.split(":", 1)[0]
|
||||||
|
if "/" in base_identifier:
|
||||||
|
logging.debug(
|
||||||
|
"Using generic OpenRouter capabilities for %s (provider/model format detected)", canonical_name
|
||||||
|
)
|
||||||
|
generic = ModelCapabilities(
|
||||||
|
provider=ProviderType.OPENROUTER,
|
||||||
|
model_name=canonical_name,
|
||||||
|
friendly_name=self.FRIENDLY_NAME,
|
||||||
|
context_window=32_768,
|
||||||
|
max_output_tokens=32_768,
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=False,
|
||||||
|
temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0),
|
||||||
|
)
|
||||||
|
generic._is_generic = True
|
||||||
|
return generic
|
||||||
|
|
||||||
|
logging.debug(
|
||||||
|
"Rejecting unknown OpenRouter model '%s' (no provider prefix); requires explicit configuration",
|
||||||
|
canonical_name,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Provider identity
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get_provider_type(self) -> ProviderType:
|
||||||
|
"""Identify this provider for restrictions and logging."""
|
||||||
|
return ProviderType.OPENROUTER
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Request execution
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.3,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using the OpenRouter API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: User prompt to send to the model
|
||||||
|
model_name: Name of the model (or alias) to use
|
||||||
|
system_prompt: Optional system prompt for model behavior
|
||||||
|
temperature: Sampling temperature
|
||||||
|
max_output_tokens: Maximum tokens to generate
|
||||||
|
**kwargs: Additional provider-specific parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelResponse with generated content and metadata
|
||||||
|
"""
|
||||||
|
# Resolve model alias to actual OpenRouter model name
|
||||||
|
resolved_model = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
|
# Always disable streaming for OpenRouter
|
||||||
|
# MCP doesn't use streaming, and this avoids issues with O3 model access
|
||||||
|
if "stream" not in kwargs:
|
||||||
|
kwargs["stream"] = False
|
||||||
|
|
||||||
|
# Call parent method with resolved model name
|
||||||
|
return super().generate_content(
|
||||||
|
prompt=prompt,
|
||||||
|
model_name=resolved_model,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
max_output_tokens=max_output_tokens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Registry helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def list_models(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
respect_restrictions: bool = True,
|
||||||
|
include_aliases: bool = True,
|
||||||
|
lowercase: bool = False,
|
||||||
|
unique: bool = False,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Return formatted OpenRouter model names, respecting alias-aware restrictions."""
|
||||||
|
|
||||||
|
if not self._registry:
|
||||||
|
return []
|
||||||
|
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
|
allowed_configs: dict[str, ModelCapabilities] = {}
|
||||||
|
|
||||||
|
for model_name in self._registry.list_models():
|
||||||
|
config = self._registry.resolve(model_name)
|
||||||
|
if not config:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Custom models belong to CustomProvider; skip them here so the two
|
||||||
|
# providers don't race over the same registrations (important for tests
|
||||||
|
# that stub the registry with minimal objects lacking attrs).
|
||||||
|
if hasattr(config, "is_custom") and config.is_custom is True:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if restriction_service:
|
||||||
|
allowed = restriction_service.is_allowed(self.get_provider_type(), model_name)
|
||||||
|
|
||||||
|
if not allowed and config.aliases:
|
||||||
|
for alias in config.aliases:
|
||||||
|
if restriction_service.is_allowed(self.get_provider_type(), alias):
|
||||||
|
allowed = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
continue
|
||||||
|
|
||||||
|
allowed_configs[model_name] = config
|
||||||
|
|
||||||
|
if not allowed_configs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# When restrictions are in place, don't include aliases to avoid confusion
|
||||||
|
# Only return the canonical model names that are actually allowed
|
||||||
|
actual_include_aliases = include_aliases and not respect_restrictions
|
||||||
|
|
||||||
|
return ModelCapabilities.collect_model_names(
|
||||||
|
allowed_configs,
|
||||||
|
include_aliases=actual_include_aliases,
|
||||||
|
lowercase=lowercase,
|
||||||
|
unique=unique,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Registry helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _resolve_model_name(self, model_name: str) -> str:
|
||||||
|
"""Resolve aliases defined in the OpenRouter registry."""
|
||||||
|
|
||||||
|
config = self._registry.resolve(model_name)
|
||||||
|
if config:
|
||||||
|
if config.model_name != model_name:
|
||||||
|
logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'")
|
||||||
|
return config.model_name
|
||||||
|
|
||||||
|
logging.debug(f"Model '{model_name}' not found in registry, using as-is")
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]:
|
||||||
|
"""Expose registry-backed OpenRouter capabilities."""
|
||||||
|
|
||||||
|
if not self._registry:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
capabilities: dict[str, ModelCapabilities] = {}
|
||||||
|
for model_name in self._registry.list_models():
|
||||||
|
config = self._registry.resolve(model_name)
|
||||||
|
if not config:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# See note in list_models: respect the CustomProvider boundary.
|
||||||
|
if hasattr(config, "is_custom") and config.is_custom is True:
|
||||||
|
continue
|
||||||
|
|
||||||
|
capabilities[model_name] = config
|
||||||
|
return capabilities
|
||||||
292
providers/openrouter_registry.py
Normal file
292
providers/openrouter_registry.py
Normal file
|
|
@ -0,0 +1,292 @@
|
||||||
|
"""OpenRouter model registry for managing model configurations and aliases."""
|
||||||
|
|
||||||
|
import importlib.resources
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
# Import handled via importlib.resources.files() calls directly
|
||||||
|
from utils.file_utils import read_json_file
|
||||||
|
|
||||||
|
from .shared import (
|
||||||
|
ModelCapabilities,
|
||||||
|
ProviderType,
|
||||||
|
TemperatureConstraint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OpenRouterModelRegistry:
|
||||||
|
"""In-memory view of OpenRouter and custom model metadata.
|
||||||
|
|
||||||
|
Role
|
||||||
|
Parse the packaged ``conf/custom_models.json`` (or user-specified
|
||||||
|
overrides), construct alias and capability maps, and serve those
|
||||||
|
structures to providers that rely on OpenRouter semantics (both the
|
||||||
|
OpenRouter provider itself and the Custom provider).
|
||||||
|
|
||||||
|
Key duties
|
||||||
|
* Load :class:`ModelCapabilities` definitions from configuration files
|
||||||
|
* Maintain a case-insensitive alias → canonical name map for fast
|
||||||
|
resolution
|
||||||
|
* Provide helpers to list models, list aliases, and resolve an arbitrary
|
||||||
|
name to its capability object without repeatedly touching the file
|
||||||
|
system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config_path: Optional[str] = None):
|
||||||
|
"""Initialize the registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to config file. If None, uses default locations.
|
||||||
|
"""
|
||||||
|
self.alias_map: dict[str, str] = {} # alias -> model_name
|
||||||
|
self.model_map: dict[str, ModelCapabilities] = {} # model_name -> config
|
||||||
|
|
||||||
|
# Determine config path and loading strategy
|
||||||
|
self.use_resources = False
|
||||||
|
if config_path:
|
||||||
|
# Direct config_path parameter
|
||||||
|
self.config_path = Path(config_path)
|
||||||
|
else:
|
||||||
|
# Check environment variable first
|
||||||
|
env_path = os.getenv("CUSTOM_MODELS_CONFIG_PATH")
|
||||||
|
if env_path:
|
||||||
|
# Environment variable path
|
||||||
|
self.config_path = Path(env_path)
|
||||||
|
else:
|
||||||
|
# Try importlib.resources for robust packaging support
|
||||||
|
self.config_path = None
|
||||||
|
self.use_resources = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
resource_traversable = importlib.resources.files("conf").joinpath("custom_models.json")
|
||||||
|
if hasattr(resource_traversable, "read_text"):
|
||||||
|
self.use_resources = True
|
||||||
|
else:
|
||||||
|
raise AttributeError("read_text not available")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not self.use_resources:
|
||||||
|
# Fallback to file system paths
|
||||||
|
potential_paths = [
|
||||||
|
Path(__file__).parent.parent / "conf" / "custom_models.json",
|
||||||
|
Path.cwd() / "conf" / "custom_models.json",
|
||||||
|
]
|
||||||
|
|
||||||
|
for path in potential_paths:
|
||||||
|
if path.exists():
|
||||||
|
self.config_path = path
|
||||||
|
break
|
||||||
|
|
||||||
|
if self.config_path is None:
|
||||||
|
self.config_path = potential_paths[0]
|
||||||
|
|
||||||
|
# Load configuration
|
||||||
|
self.reload()
|
||||||
|
|
||||||
|
def reload(self) -> None:
|
||||||
|
"""Reload configuration from disk."""
|
||||||
|
try:
|
||||||
|
configs = self._read_config()
|
||||||
|
self._build_maps(configs)
|
||||||
|
caller_info = ""
|
||||||
|
try:
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
caller_frame = inspect.currentframe().f_back
|
||||||
|
if caller_frame:
|
||||||
|
caller_name = caller_frame.f_code.co_name
|
||||||
|
caller_file = (
|
||||||
|
caller_frame.f_code.co_filename.split("/")[-1] if caller_frame.f_code.co_filename else "unknown"
|
||||||
|
)
|
||||||
|
# Look for tool context
|
||||||
|
while caller_frame:
|
||||||
|
frame_locals = caller_frame.f_locals
|
||||||
|
if "self" in frame_locals and hasattr(frame_locals["self"], "get_name"):
|
||||||
|
tool_name = frame_locals["self"].get_name()
|
||||||
|
caller_info = f" (called from {tool_name} tool)"
|
||||||
|
break
|
||||||
|
caller_frame = caller_frame.f_back
|
||||||
|
if not caller_info:
|
||||||
|
caller_info = f" (called from {caller_name} in {caller_file})"
|
||||||
|
except Exception:
|
||||||
|
# If frame inspection fails, just continue without caller info
|
||||||
|
pass
|
||||||
|
|
||||||
|
logging.debug(
|
||||||
|
f"Loaded {len(self.model_map)} OpenRouter models with {len(self.alias_map)} aliases{caller_info}"
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
# Re-raise ValueError only for duplicate aliases (critical config errors)
|
||||||
|
logging.error(f"Failed to load OpenRouter model configuration: {e}")
|
||||||
|
# Initialize with empty maps on failure
|
||||||
|
self.alias_map = {}
|
||||||
|
self.model_map = {}
|
||||||
|
if "Duplicate alias" in str(e):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to load OpenRouter model configuration: {e}")
|
||||||
|
# Initialize with empty maps on failure
|
||||||
|
self.alias_map = {}
|
||||||
|
self.model_map = {}
|
||||||
|
|
||||||
|
def _read_config(self) -> list[ModelCapabilities]:
|
||||||
|
"""Read configuration from file or package resources.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model configurations
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self.use_resources:
|
||||||
|
# Use importlib.resources for packaged environments
|
||||||
|
try:
|
||||||
|
resource_path = importlib.resources.files("conf").joinpath("custom_models.json")
|
||||||
|
if hasattr(resource_path, "read_text"):
|
||||||
|
# Python 3.9+
|
||||||
|
config_text = resource_path.read_text(encoding="utf-8")
|
||||||
|
else:
|
||||||
|
# Python 3.8 fallback
|
||||||
|
with resource_path.open("r", encoding="utf-8") as f:
|
||||||
|
config_text = f.read()
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = json.loads(config_text)
|
||||||
|
logging.debug("Loaded OpenRouter config from package resources")
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to load config from resources: {e}")
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
# Use file path loading
|
||||||
|
if not self.config_path.exists():
|
||||||
|
logging.warning(f"OpenRouter model config not found at {self.config_path}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Use centralized JSON reading utility
|
||||||
|
data = read_json_file(str(self.config_path))
|
||||||
|
logging.debug(f"Loaded OpenRouter config from file: {self.config_path}")
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
location = "resources" if self.use_resources else str(self.config_path)
|
||||||
|
raise ValueError(f"Could not read or parse JSON from {location}")
|
||||||
|
|
||||||
|
# Parse models
|
||||||
|
configs = []
|
||||||
|
for model_data in data.get("models", []):
|
||||||
|
# Create ModelCapabilities directly from JSON data
|
||||||
|
# Handle temperature_constraint conversion
|
||||||
|
temp_constraint_str = model_data.get("temperature_constraint")
|
||||||
|
temp_constraint = TemperatureConstraint.create(temp_constraint_str or "range")
|
||||||
|
|
||||||
|
# Set provider-specific defaults based on is_custom flag
|
||||||
|
is_custom = model_data.get("is_custom", False)
|
||||||
|
if is_custom:
|
||||||
|
model_data.setdefault("provider", ProviderType.CUSTOM)
|
||||||
|
model_data.setdefault("friendly_name", f"Custom ({model_data.get('model_name', 'Unknown')})")
|
||||||
|
else:
|
||||||
|
model_data.setdefault("provider", ProviderType.OPENROUTER)
|
||||||
|
model_data.setdefault("friendly_name", f"OpenRouter ({model_data.get('model_name', 'Unknown')})")
|
||||||
|
model_data["temperature_constraint"] = temp_constraint
|
||||||
|
|
||||||
|
# Remove the string version of temperature_constraint before creating ModelCapabilities
|
||||||
|
if "temperature_constraint" in model_data and isinstance(model_data["temperature_constraint"], str):
|
||||||
|
del model_data["temperature_constraint"]
|
||||||
|
model_data["temperature_constraint"] = temp_constraint
|
||||||
|
|
||||||
|
config = ModelCapabilities(**model_data)
|
||||||
|
configs.append(config)
|
||||||
|
|
||||||
|
return configs
|
||||||
|
except ValueError:
|
||||||
|
# Re-raise ValueError for specific config errors
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
location = "resources" if self.use_resources else str(self.config_path)
|
||||||
|
raise ValueError(f"Error reading config from {location}: {e}")
|
||||||
|
|
||||||
|
def _build_maps(self, configs: list[ModelCapabilities]) -> None:
|
||||||
|
"""Build alias and model maps from configurations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
configs: List of model configurations
|
||||||
|
"""
|
||||||
|
alias_map = {}
|
||||||
|
model_map = {}
|
||||||
|
|
||||||
|
for config in configs:
|
||||||
|
# Add to model map
|
||||||
|
model_map[config.model_name] = config
|
||||||
|
|
||||||
|
# Add the model_name itself as an alias for case-insensitive lookup
|
||||||
|
# But only if it's not already in the aliases list
|
||||||
|
model_name_lower = config.model_name.lower()
|
||||||
|
aliases_lower = [alias.lower() for alias in config.aliases]
|
||||||
|
|
||||||
|
if model_name_lower not in aliases_lower:
|
||||||
|
if model_name_lower in alias_map:
|
||||||
|
existing_model = alias_map[model_name_lower]
|
||||||
|
if existing_model != config.model_name:
|
||||||
|
raise ValueError(
|
||||||
|
f"Duplicate model name '{config.model_name}' (case-insensitive) found for models "
|
||||||
|
f"'{existing_model}' and '{config.model_name}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
alias_map[model_name_lower] = config.model_name
|
||||||
|
|
||||||
|
# Add aliases
|
||||||
|
for alias in config.aliases:
|
||||||
|
alias_lower = alias.lower()
|
||||||
|
if alias_lower in alias_map:
|
||||||
|
existing_model = alias_map[alias_lower]
|
||||||
|
raise ValueError(
|
||||||
|
f"Duplicate alias '{alias}' found for models '{existing_model}' and '{config.model_name}'"
|
||||||
|
)
|
||||||
|
alias_map[alias_lower] = config.model_name
|
||||||
|
|
||||||
|
# Atomic update
|
||||||
|
self.alias_map = alias_map
|
||||||
|
self.model_map = model_map
|
||||||
|
|
||||||
|
def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]:
|
||||||
|
"""Resolve a model name or alias to configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name_or_alias: Model name or alias to resolve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model configuration if found, None otherwise
|
||||||
|
"""
|
||||||
|
# Try alias lookup (case-insensitive) - this now includes model names too
|
||||||
|
alias_lower = name_or_alias.lower()
|
||||||
|
if alias_lower in self.alias_map:
|
||||||
|
model_name = self.alias_map[alias_lower]
|
||||||
|
return self.model_map.get(model_name)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
|
||||||
|
"""Get model capabilities for a name or alias.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name_or_alias: Model name or alias
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelCapabilities if found, None otherwise
|
||||||
|
"""
|
||||||
|
# Registry now returns ModelCapabilities directly
|
||||||
|
return self.resolve(name_or_alias)
|
||||||
|
|
||||||
|
def get_model_config(self, name_or_alias: str) -> Optional[ModelCapabilities]:
|
||||||
|
"""Backward-compatible wrapper used by providers and older tests."""
|
||||||
|
|
||||||
|
return self.resolve(name_or_alias)
|
||||||
|
|
||||||
|
def list_models(self) -> list[str]:
|
||||||
|
"""List all available model names."""
|
||||||
|
return list(self.model_map.keys())
|
||||||
|
|
||||||
|
def list_aliases(self) -> list[str]:
|
||||||
|
"""List all available aliases."""
|
||||||
|
return list(self.alias_map.keys())
|
||||||
397
providers/registry.py
Normal file
397
providers/registry.py
Normal file
|
|
@ -0,0 +1,397 @@
|
||||||
|
"""Model provider registry for managing available providers."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from .base import ModelProvider
|
||||||
|
from .shared import ProviderType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProviderRegistry:
|
||||||
|
"""Central catalogue of provider implementations used by the MCP server.
|
||||||
|
|
||||||
|
Role
|
||||||
|
Holds the mapping between :class:`ProviderType` values and concrete
|
||||||
|
:class:`ModelProvider` subclasses/factories. At runtime the registry
|
||||||
|
is responsible for instantiating providers, caching them for reuse, and
|
||||||
|
mediating lookup of providers and model names in provider priority
|
||||||
|
order.
|
||||||
|
|
||||||
|
Core responsibilities
|
||||||
|
* Resolve API keys and other runtime configuration for each provider
|
||||||
|
* Lazily create provider instances so unused backends incur no cost
|
||||||
|
* Expose convenience methods for enumerating available models and
|
||||||
|
locating which provider can service a requested model name or alias
|
||||||
|
* Honour the project-wide provider priority policy so namespaces (or
|
||||||
|
alias collisions) are resolved deterministically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
# Provider priority order for model selection
|
||||||
|
# Native APIs first, then custom endpoints, then catch-all providers
|
||||||
|
PROVIDER_PRIORITY_ORDER = [
|
||||||
|
ProviderType.GOOGLE, # Direct Gemini access
|
||||||
|
ProviderType.OPENAI, # Direct OpenAI access
|
||||||
|
ProviderType.XAI, # Direct X.AI GROK access
|
||||||
|
ProviderType.DIAL, # DIAL unified API access
|
||||||
|
ProviderType.CUSTOM, # Local/self-hosted models
|
||||||
|
ProviderType.OPENROUTER, # Catch-all for cloud models
|
||||||
|
]
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
"""Singleton pattern for registry."""
|
||||||
|
if cls._instance is None:
|
||||||
|
logging.debug("REGISTRY: Creating new registry instance")
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
# Initialize instance dictionaries on first creation
|
||||||
|
cls._instance._providers = {}
|
||||||
|
cls._instance._initialized_providers = {}
|
||||||
|
logging.debug(f"REGISTRY: Created instance {cls._instance}")
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_provider(cls, provider_type: ProviderType, provider_class: type[ModelProvider]) -> None:
|
||||||
|
"""Register a new provider class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: Type of the provider (e.g., ProviderType.GOOGLE)
|
||||||
|
provider_class: Class that implements ModelProvider interface
|
||||||
|
"""
|
||||||
|
instance = cls()
|
||||||
|
instance._providers[provider_type] = provider_class
|
||||||
|
# Invalidate any cached instance so subsequent lookups use the new registration
|
||||||
|
instance._initialized_providers.pop(provider_type, None)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]:
|
||||||
|
"""Get an initialized provider instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: Type of provider to get
|
||||||
|
force_new: Force creation of new instance instead of using cached
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Initialized ModelProvider instance or None if not available
|
||||||
|
"""
|
||||||
|
instance = cls()
|
||||||
|
|
||||||
|
# Return cached instance if available and not forcing new
|
||||||
|
if not force_new and provider_type in instance._initialized_providers:
|
||||||
|
return instance._initialized_providers[provider_type]
|
||||||
|
|
||||||
|
# Check if provider class is registered
|
||||||
|
if provider_type not in instance._providers:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get API key from environment
|
||||||
|
api_key = cls._get_api_key_for_provider(provider_type)
|
||||||
|
|
||||||
|
# Get provider class or factory function
|
||||||
|
provider_class = instance._providers[provider_type]
|
||||||
|
|
||||||
|
# For custom providers, handle special initialization requirements
|
||||||
|
if provider_type == ProviderType.CUSTOM:
|
||||||
|
# Check if it's a factory function (callable but not a class)
|
||||||
|
if callable(provider_class) and not isinstance(provider_class, type):
|
||||||
|
# Factory function - call it with api_key parameter
|
||||||
|
provider = provider_class(api_key=api_key)
|
||||||
|
else:
|
||||||
|
# Regular class - need to handle URL requirement
|
||||||
|
custom_url = os.getenv("CUSTOM_API_URL", "")
|
||||||
|
if not custom_url:
|
||||||
|
if api_key: # Key is set but URL is missing
|
||||||
|
logging.warning("CUSTOM_API_KEY set but CUSTOM_API_URL missing – skipping Custom provider")
|
||||||
|
return None
|
||||||
|
# Use empty string as API key for custom providers that don't need auth (e.g., Ollama)
|
||||||
|
# This allows the provider to be created even without CUSTOM_API_KEY being set
|
||||||
|
api_key = api_key or ""
|
||||||
|
# Initialize custom provider with both API key and base URL
|
||||||
|
provider = provider_class(api_key=api_key, base_url=custom_url)
|
||||||
|
elif provider_type == ProviderType.GOOGLE:
|
||||||
|
# For Gemini, check if custom base URL is configured
|
||||||
|
if not api_key:
|
||||||
|
return None
|
||||||
|
gemini_base_url = os.getenv("GEMINI_BASE_URL")
|
||||||
|
provider_kwargs = {"api_key": api_key}
|
||||||
|
if gemini_base_url:
|
||||||
|
provider_kwargs["base_url"] = gemini_base_url
|
||||||
|
logging.info(f"Initialized Gemini provider with custom endpoint: {gemini_base_url}")
|
||||||
|
provider = provider_class(**provider_kwargs)
|
||||||
|
else:
|
||||||
|
if not api_key:
|
||||||
|
return None
|
||||||
|
# Initialize non-custom provider with just API key
|
||||||
|
provider = provider_class(api_key=api_key)
|
||||||
|
|
||||||
|
# Cache the instance
|
||||||
|
instance._initialized_providers[provider_type] = provider
|
||||||
|
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]:
|
||||||
|
"""Get provider instance for a specific model name.
|
||||||
|
|
||||||
|
Provider priority order:
|
||||||
|
1. Native APIs (GOOGLE, OPENAI) - Most direct and efficient
|
||||||
|
2. CUSTOM - For local/private models with specific endpoints
|
||||||
|
3. OPENROUTER - Catch-all for cloud models via unified API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Name of the model (e.g., "gemini-2.5-flash", "gpt5")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ModelProvider instance that supports this model
|
||||||
|
"""
|
||||||
|
logging.debug(f"get_provider_for_model called with model_name='{model_name}'")
|
||||||
|
|
||||||
|
# Check providers in priority order
|
||||||
|
instance = cls()
|
||||||
|
logging.debug(f"Registry instance: {instance}")
|
||||||
|
logging.debug(f"Available providers in registry: {list(instance._providers.keys())}")
|
||||||
|
|
||||||
|
for provider_type in cls.PROVIDER_PRIORITY_ORDER:
|
||||||
|
if provider_type in instance._providers:
|
||||||
|
logging.debug(f"Found {provider_type} in registry")
|
||||||
|
# Get or create provider instance
|
||||||
|
provider = cls.get_provider(provider_type)
|
||||||
|
if provider and provider.validate_model_name(model_name):
|
||||||
|
logging.debug(f"{provider_type} validates model {model_name}")
|
||||||
|
return provider
|
||||||
|
else:
|
||||||
|
logging.debug(f"{provider_type} does not validate model {model_name}")
|
||||||
|
else:
|
||||||
|
logging.debug(f"{provider_type} not found in registry")
|
||||||
|
|
||||||
|
logging.debug(f"No provider found for model {model_name}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_available_providers(cls) -> list[ProviderType]:
|
||||||
|
"""Get list of registered provider types."""
|
||||||
|
instance = cls()
|
||||||
|
return list(instance._providers.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_available_models(cls, respect_restrictions: bool = True) -> dict[str, ProviderType]:
|
||||||
|
"""Get mapping of all available models to their providers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
respect_restrictions: If True, filter out models not allowed by restrictions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping model names to provider types
|
||||||
|
"""
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service() if respect_restrictions else None
|
||||||
|
models: dict[str, ProviderType] = {}
|
||||||
|
instance = cls()
|
||||||
|
|
||||||
|
for provider_type in instance._providers:
|
||||||
|
provider = cls.get_provider(provider_type)
|
||||||
|
if not provider:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
available = provider.list_models(respect_restrictions=respect_restrictions)
|
||||||
|
except NotImplementedError:
|
||||||
|
logging.warning("Provider %s does not implement list_models", provider_type)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for model_name in available:
|
||||||
|
# =====================================================================================
|
||||||
|
# CRITICAL: Prevent double restriction filtering (Fixed Issue #98)
|
||||||
|
# =====================================================================================
|
||||||
|
# Previously, both the provider AND registry applied restrictions, causing
|
||||||
|
# double-filtering that resulted in "no models available" errors.
|
||||||
|
#
|
||||||
|
# Logic: If respect_restrictions=True, provider already filtered models,
|
||||||
|
# so registry should NOT filter them again.
|
||||||
|
# TEST COVERAGE: tests/test_provider_routing_bugs.py::TestOpenRouterAliasRestrictions
|
||||||
|
# =====================================================================================
|
||||||
|
if (
|
||||||
|
restriction_service
|
||||||
|
and not respect_restrictions # Only filter if provider didn't already filter
|
||||||
|
and not restriction_service.is_allowed(provider_type, model_name)
|
||||||
|
):
|
||||||
|
logging.debug("Model %s filtered by restrictions", model_name)
|
||||||
|
continue
|
||||||
|
models[model_name] = provider_type
|
||||||
|
|
||||||
|
return models
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]:
|
||||||
|
"""Get list of available model names, optionally filtered by provider.
|
||||||
|
|
||||||
|
This respects model restrictions automatically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: Optional provider to filter by
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of available model names
|
||||||
|
"""
|
||||||
|
available_models = cls.get_available_models(respect_restrictions=True)
|
||||||
|
|
||||||
|
if provider_type:
|
||||||
|
# Filter by specific provider
|
||||||
|
return [name for name, ptype in available_models.items() if ptype == provider_type]
|
||||||
|
else:
|
||||||
|
# Return all available models
|
||||||
|
return list(available_models.keys())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]:
|
||||||
|
"""Get API key for a provider from environment variables.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: Provider type to get API key for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API key string or None if not found
|
||||||
|
"""
|
||||||
|
key_mapping = {
|
||||||
|
ProviderType.GOOGLE: "GEMINI_API_KEY",
|
||||||
|
ProviderType.OPENAI: "OPENAI_API_KEY",
|
||||||
|
ProviderType.XAI: "XAI_API_KEY",
|
||||||
|
ProviderType.OPENROUTER: "OPENROUTER_API_KEY",
|
||||||
|
ProviderType.CUSTOM: "CUSTOM_API_KEY", # Can be empty for providers that don't need auth
|
||||||
|
ProviderType.DIAL: "DIAL_API_KEY",
|
||||||
|
}
|
||||||
|
|
||||||
|
env_var = key_mapping.get(provider_type)
|
||||||
|
if not env_var:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return os.getenv(env_var)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_allowed_models_for_provider(cls, provider: ModelProvider, provider_type: ProviderType) -> list[str]:
|
||||||
|
"""Get a list of allowed canonical model names for a given provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider: The provider instance to get models for
|
||||||
|
provider_type: The provider type for restriction checking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of model names that are both supported and allowed
|
||||||
|
"""
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service()
|
||||||
|
|
||||||
|
allowed_models = []
|
||||||
|
|
||||||
|
# Get the provider's supported models
|
||||||
|
try:
|
||||||
|
# Use list_models to get all supported models (handles both regular and custom providers)
|
||||||
|
supported_models = provider.list_models(respect_restrictions=False)
|
||||||
|
except (NotImplementedError, AttributeError):
|
||||||
|
# Fallback to provider-declared capability maps if list_models not implemented
|
||||||
|
model_map = getattr(provider, "MODEL_CAPABILITIES", None)
|
||||||
|
supported_models = list(model_map.keys()) if isinstance(model_map, dict) else []
|
||||||
|
|
||||||
|
# Filter by restrictions
|
||||||
|
for model_name in supported_models:
|
||||||
|
if restriction_service.is_allowed(provider_type, model_name):
|
||||||
|
allowed_models.append(model_name)
|
||||||
|
|
||||||
|
return allowed_models
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str:
|
||||||
|
"""Get the preferred fallback model based on provider priority and tool category.
|
||||||
|
|
||||||
|
This method orchestrates model selection by:
|
||||||
|
1. Getting allowed models for each provider (respecting restrictions)
|
||||||
|
2. Asking providers for their preference from the allowed list
|
||||||
|
3. Falling back to first available model if no preference given
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_category: Optional category to influence model selection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Model name string for fallback use
|
||||||
|
"""
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
effective_category = tool_category or ToolModelCategory.BALANCED
|
||||||
|
first_available_model = None
|
||||||
|
|
||||||
|
# Ask each provider for their preference in priority order
|
||||||
|
for provider_type in cls.PROVIDER_PRIORITY_ORDER:
|
||||||
|
provider = cls.get_provider(provider_type)
|
||||||
|
if provider:
|
||||||
|
# 1. Registry filters the models first
|
||||||
|
allowed_models = cls._get_allowed_models_for_provider(provider, provider_type)
|
||||||
|
|
||||||
|
if not allowed_models:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 2. Keep track of the first available model as fallback
|
||||||
|
if not first_available_model:
|
||||||
|
first_available_model = sorted(allowed_models)[0]
|
||||||
|
|
||||||
|
# 3. Ask provider to pick from allowed list
|
||||||
|
preferred_model = provider.get_preferred_model(effective_category, allowed_models)
|
||||||
|
|
||||||
|
if preferred_model:
|
||||||
|
logging.debug(
|
||||||
|
f"Provider {provider_type.value} selected '{preferred_model}' for category '{effective_category.value}'"
|
||||||
|
)
|
||||||
|
return preferred_model
|
||||||
|
|
||||||
|
# If no provider returned a preference, use first available model
|
||||||
|
if first_available_model:
|
||||||
|
logging.debug(f"No provider preference, using first available: {first_available_model}")
|
||||||
|
return first_available_model
|
||||||
|
|
||||||
|
# Ultimate fallback if no providers have models
|
||||||
|
logging.warning("No models available from any provider, using default fallback")
|
||||||
|
return "gemini-2.5-flash"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_available_providers_with_keys(cls) -> list[ProviderType]:
|
||||||
|
"""Get list of provider types that have valid API keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ProviderType values for providers with valid API keys
|
||||||
|
"""
|
||||||
|
available = []
|
||||||
|
instance = cls()
|
||||||
|
for provider_type in instance._providers:
|
||||||
|
if cls.get_provider(provider_type) is not None:
|
||||||
|
available.append(provider_type)
|
||||||
|
return available
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def clear_cache(cls) -> None:
|
||||||
|
"""Clear cached provider instances."""
|
||||||
|
instance = cls()
|
||||||
|
instance._initialized_providers.clear()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reset_for_testing(cls) -> None:
|
||||||
|
"""Reset the registry to a clean state for testing.
|
||||||
|
|
||||||
|
This provides a safe, public API for tests to clean up registry state
|
||||||
|
without directly manipulating private attributes.
|
||||||
|
"""
|
||||||
|
cls._instance = None
|
||||||
|
if hasattr(cls, "_providers"):
|
||||||
|
cls._providers = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def unregister_provider(cls, provider_type: ProviderType) -> None:
|
||||||
|
"""Unregister a provider (mainly for testing)."""
|
||||||
|
instance = cls()
|
||||||
|
instance._providers.pop(provider_type, None)
|
||||||
|
instance._initialized_providers.pop(provider_type, None)
|
||||||
21
providers/shared/__init__.py
Normal file
21
providers/shared/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
"""Shared data structures and helpers for model providers."""
|
||||||
|
|
||||||
|
from .model_capabilities import ModelCapabilities
|
||||||
|
from .model_response import ModelResponse
|
||||||
|
from .provider_type import ProviderType
|
||||||
|
from .temperature import (
|
||||||
|
DiscreteTemperatureConstraint,
|
||||||
|
FixedTemperatureConstraint,
|
||||||
|
RangeTemperatureConstraint,
|
||||||
|
TemperatureConstraint,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ModelCapabilities",
|
||||||
|
"ModelResponse",
|
||||||
|
"ProviderType",
|
||||||
|
"TemperatureConstraint",
|
||||||
|
"FixedTemperatureConstraint",
|
||||||
|
"RangeTemperatureConstraint",
|
||||||
|
"DiscreteTemperatureConstraint",
|
||||||
|
]
|
||||||
122
providers/shared/model_capabilities.py
Normal file
122
providers/shared/model_capabilities.py
Normal file
|
|
@ -0,0 +1,122 @@
|
||||||
|
"""Dataclass describing the feature set of a model exposed by a provider."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .provider_type import ProviderType
|
||||||
|
from .temperature import RangeTemperatureConstraint, TemperatureConstraint
|
||||||
|
|
||||||
|
__all__ = ["ModelCapabilities"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelCapabilities:
|
||||||
|
"""Static description of what a model can do within a provider.
|
||||||
|
|
||||||
|
Role
|
||||||
|
Acts as the canonical record for everything the server needs to know
|
||||||
|
about a model—its provider, token limits, feature switches, aliases,
|
||||||
|
and temperature rules. Providers populate these objects so tools and
|
||||||
|
higher-level services can rely on a consistent schema.
|
||||||
|
|
||||||
|
Typical usage
|
||||||
|
* Provider subclasses declare `MODEL_CAPABILITIES` maps containing these
|
||||||
|
objects (for example ``OpenAIModelProvider``)
|
||||||
|
* Helper utilities (e.g. restriction validation, alias expansion) read
|
||||||
|
these objects to build model lists for tooling and policy enforcement
|
||||||
|
* Tool selection logic inspects attributes such as
|
||||||
|
``supports_extended_thinking`` or ``context_window`` to choose an
|
||||||
|
appropriate model for a task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider: ProviderType
|
||||||
|
model_name: str
|
||||||
|
friendly_name: str
|
||||||
|
description: str = ""
|
||||||
|
aliases: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Capacity limits / resource budgets
|
||||||
|
context_window: int = 0
|
||||||
|
max_output_tokens: int = 0
|
||||||
|
max_thinking_tokens: int = 0
|
||||||
|
|
||||||
|
# Capability flags
|
||||||
|
supports_extended_thinking: bool = False
|
||||||
|
supports_system_prompts: bool = True
|
||||||
|
supports_streaming: bool = True
|
||||||
|
supports_function_calling: bool = False
|
||||||
|
supports_images: bool = False
|
||||||
|
supports_json_mode: bool = False
|
||||||
|
supports_temperature: bool = True
|
||||||
|
|
||||||
|
# Additional attributes
|
||||||
|
max_image_size_mb: float = 0.0
|
||||||
|
is_custom: bool = False
|
||||||
|
temperature_constraint: TemperatureConstraint = field(
|
||||||
|
default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_effective_temperature(self, requested_temperature: float) -> Optional[float]:
|
||||||
|
"""Return the temperature that should be sent to the provider.
|
||||||
|
|
||||||
|
Models that do not support temperature return ``None`` so that callers
|
||||||
|
can omit the parameter entirely. For supported models, the configured
|
||||||
|
constraint clamps the requested value into a provider-safe range.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not self.supports_temperature:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.temperature_constraint.get_corrected_value(requested_temperature)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def collect_aliases(model_configs: dict[str, "ModelCapabilities"]) -> dict[str, list[str]]:
|
||||||
|
"""Build a mapping of model name to aliases from capability configs."""
|
||||||
|
|
||||||
|
return {
|
||||||
|
base_model: capabilities.aliases
|
||||||
|
for base_model, capabilities in model_configs.items()
|
||||||
|
if capabilities.aliases
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def collect_model_names(
|
||||||
|
model_configs: dict[str, "ModelCapabilities"],
|
||||||
|
*,
|
||||||
|
include_aliases: bool = True,
|
||||||
|
lowercase: bool = False,
|
||||||
|
unique: bool = False,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Build an ordered list of model names and aliases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_configs: Mapping of canonical model names to capabilities.
|
||||||
|
include_aliases: When True, include aliases for each model.
|
||||||
|
lowercase: When True, normalize names to lowercase.
|
||||||
|
unique: When True, ensure each returned name appears once (after formatting).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Ordered list of model names (and optionally aliases) formatted per options.
|
||||||
|
"""
|
||||||
|
|
||||||
|
formatted_names: list[str] = []
|
||||||
|
seen: set[str] | None = set() if unique else None
|
||||||
|
|
||||||
|
def append_name(name: str) -> None:
|
||||||
|
formatted = name.lower() if lowercase else name
|
||||||
|
|
||||||
|
if seen is not None:
|
||||||
|
if formatted in seen:
|
||||||
|
return
|
||||||
|
seen.add(formatted)
|
||||||
|
|
||||||
|
formatted_names.append(formatted)
|
||||||
|
|
||||||
|
for base_model, capabilities in model_configs.items():
|
||||||
|
append_name(base_model)
|
||||||
|
|
||||||
|
if include_aliases and capabilities.aliases:
|
||||||
|
for alias in capabilities.aliases:
|
||||||
|
append_name(alias)
|
||||||
|
|
||||||
|
return formatted_names
|
||||||
26
providers/shared/model_response.py
Normal file
26
providers/shared/model_response.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
"""Dataclass used to normalise provider SDK responses."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .provider_type import ProviderType
|
||||||
|
|
||||||
|
__all__ = ["ModelResponse"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelResponse:
|
||||||
|
"""Portable representation of a provider completion."""
|
||||||
|
|
||||||
|
content: str
|
||||||
|
usage: dict[str, int] = field(default_factory=dict)
|
||||||
|
model_name: str = ""
|
||||||
|
friendly_name: str = ""
|
||||||
|
provider: ProviderType = ProviderType.GOOGLE
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_tokens(self) -> int:
|
||||||
|
"""Return the total token count if the provider reported usage data."""
|
||||||
|
|
||||||
|
return self.usage.get("total_tokens", 0)
|
||||||
16
providers/shared/provider_type.py
Normal file
16
providers/shared/provider_type.py
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
"""Enumeration describing which backend owns a given model."""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
__all__ = ["ProviderType"]
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderType(Enum):
|
||||||
|
"""Canonical identifiers for every supported provider backend."""
|
||||||
|
|
||||||
|
GOOGLE = "google"
|
||||||
|
OPENAI = "openai"
|
||||||
|
XAI = "xai"
|
||||||
|
OPENROUTER = "openrouter"
|
||||||
|
CUSTOM = "custom"
|
||||||
|
DIAL = "dial"
|
||||||
188
providers/shared/temperature.py
Normal file
188
providers/shared/temperature.py
Normal file
|
|
@ -0,0 +1,188 @@
|
||||||
|
"""Helper types for validating model temperature parameters."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TemperatureConstraint",
|
||||||
|
"FixedTemperatureConstraint",
|
||||||
|
"RangeTemperatureConstraint",
|
||||||
|
"DiscreteTemperatureConstraint",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Common heuristics for determining temperature support when explicit
|
||||||
|
# capabilities are unavailable (e.g., custom/local models).
|
||||||
|
_TEMP_UNSUPPORTED_PATTERNS = {
|
||||||
|
"o1",
|
||||||
|
"o3",
|
||||||
|
"o4", # OpenAI O-series reasoning models
|
||||||
|
"deepseek-reasoner",
|
||||||
|
"deepseek-r1",
|
||||||
|
"r1", # DeepSeek reasoner variants
|
||||||
|
}
|
||||||
|
|
||||||
|
_TEMP_UNSUPPORTED_KEYWORDS = {
|
||||||
|
"reasoner", # Catch additional DeepSeek-style naming patterns
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TemperatureConstraint(ABC):
|
||||||
|
"""Contract for temperature validation used by `ModelCapabilities`.
|
||||||
|
|
||||||
|
Concrete providers describe their temperature behaviour by creating
|
||||||
|
subclasses that expose three operations:
|
||||||
|
* `validate` – decide whether a requested temperature is acceptable.
|
||||||
|
* `get_corrected_value` – coerce out-of-range values into a safe default.
|
||||||
|
* `get_description` – provide a human readable error message for users.
|
||||||
|
|
||||||
|
Providers call these hooks before sending traffic to the underlying API so
|
||||||
|
that unsupported temperatures never reach the remote service.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate(self, temperature: float) -> bool:
|
||||||
|
"""Return ``True`` when the temperature may be sent to the backend."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_corrected_value(self, temperature: float) -> float:
|
||||||
|
"""Return a valid substitute for an out-of-range temperature."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_description(self) -> str:
|
||||||
|
"""Describe the acceptable range to include in error messages."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_default(self) -> float:
|
||||||
|
"""Return the default temperature for the model."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def infer_support(model_name: str) -> tuple[bool, str]:
|
||||||
|
"""Heuristically determine whether a model supports temperature."""
|
||||||
|
|
||||||
|
model_lower = model_name.lower()
|
||||||
|
|
||||||
|
for pattern in _TEMP_UNSUPPORTED_PATTERNS:
|
||||||
|
conditions = (
|
||||||
|
pattern == model_lower,
|
||||||
|
model_lower.startswith(f"{pattern}-"),
|
||||||
|
model_lower.startswith(f"openai/{pattern}"),
|
||||||
|
model_lower.startswith(f"deepseek/{pattern}"),
|
||||||
|
model_lower.endswith(f"-{pattern}"),
|
||||||
|
f"/{pattern}" in model_lower,
|
||||||
|
f"-{pattern}-" in model_lower,
|
||||||
|
)
|
||||||
|
if any(conditions):
|
||||||
|
return False, f"detected pattern '{pattern}'"
|
||||||
|
|
||||||
|
for keyword in _TEMP_UNSUPPORTED_KEYWORDS:
|
||||||
|
if keyword in model_lower:
|
||||||
|
return False, f"detected keyword '{keyword}'"
|
||||||
|
|
||||||
|
return True, "default assumption for models without explicit metadata"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def resolve_settings(
|
||||||
|
model_name: str,
|
||||||
|
constraint_hint: Optional[str] = None,
|
||||||
|
) -> tuple[bool, "TemperatureConstraint", str]:
|
||||||
|
"""Derive temperature support and constraint for a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_name: Canonical model identifier or alias.
|
||||||
|
constraint_hint: Optional configuration hint (``"fixed"``,
|
||||||
|
``"range"``, ``"discrete"``). When provided, the hint is
|
||||||
|
honoured directly.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple ``(supports_temperature, constraint, diagnosis)`` describing
|
||||||
|
whether temperature may be tuned, the constraint object that should
|
||||||
|
be attached to :class:`ModelCapabilities`, and the reasoning behind
|
||||||
|
the decision.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if constraint_hint:
|
||||||
|
constraint = TemperatureConstraint.create(constraint_hint)
|
||||||
|
supports_temperature = constraint_hint != "fixed"
|
||||||
|
reason = f"constraint hint '{constraint_hint}'"
|
||||||
|
return supports_temperature, constraint, reason
|
||||||
|
|
||||||
|
supports_temperature, reason = TemperatureConstraint.infer_support(model_name)
|
||||||
|
if supports_temperature:
|
||||||
|
constraint: TemperatureConstraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
|
||||||
|
else:
|
||||||
|
constraint = FixedTemperatureConstraint(1.0)
|
||||||
|
|
||||||
|
return supports_temperature, constraint, reason
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(constraint_type: str) -> "TemperatureConstraint":
|
||||||
|
"""Factory that yields the appropriate constraint for a configuration hint."""
|
||||||
|
|
||||||
|
if constraint_type == "fixed":
|
||||||
|
# Fixed temperature models (O3/O4) only support temperature=1.0
|
||||||
|
return FixedTemperatureConstraint(1.0)
|
||||||
|
if constraint_type == "discrete":
|
||||||
|
# For models with specific allowed values - using common OpenAI values as default
|
||||||
|
return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
|
||||||
|
# Default range constraint (for "range" or None)
|
||||||
|
return RangeTemperatureConstraint(0.0, 2.0, 0.3)
|
||||||
|
|
||||||
|
|
||||||
|
class FixedTemperatureConstraint(TemperatureConstraint):
|
||||||
|
"""Constraint for models that enforce an exact temperature (for example O3)."""
|
||||||
|
|
||||||
|
def __init__(self, value: float):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def validate(self, temperature: float) -> bool:
|
||||||
|
return abs(temperature - self.value) < 1e-6 # Handle floating point precision
|
||||||
|
|
||||||
|
def get_corrected_value(self, temperature: float) -> float:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return f"Only supports temperature={self.value}"
|
||||||
|
|
||||||
|
def get_default(self) -> float:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
class RangeTemperatureConstraint(TemperatureConstraint):
|
||||||
|
"""Constraint for providers that expose a continuous min/max temperature range."""
|
||||||
|
|
||||||
|
def __init__(self, min_temp: float, max_temp: float, default: Optional[float] = None):
|
||||||
|
self.min_temp = min_temp
|
||||||
|
self.max_temp = max_temp
|
||||||
|
self.default_temp = default or (min_temp + max_temp) / 2
|
||||||
|
|
||||||
|
def validate(self, temperature: float) -> bool:
|
||||||
|
return self.min_temp <= temperature <= self.max_temp
|
||||||
|
|
||||||
|
def get_corrected_value(self, temperature: float) -> float:
|
||||||
|
return max(self.min_temp, min(self.max_temp, temperature))
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"
|
||||||
|
|
||||||
|
def get_default(self) -> float:
|
||||||
|
return self.default_temp
|
||||||
|
|
||||||
|
|
||||||
|
class DiscreteTemperatureConstraint(TemperatureConstraint):
|
||||||
|
"""Constraint for models that permit a discrete list of temperature values."""
|
||||||
|
|
||||||
|
def __init__(self, allowed_values: list[float], default: Optional[float] = None):
|
||||||
|
self.allowed_values = sorted(allowed_values)
|
||||||
|
self.default_temp = default or allowed_values[len(allowed_values) // 2]
|
||||||
|
|
||||||
|
def validate(self, temperature: float) -> bool:
|
||||||
|
return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)
|
||||||
|
|
||||||
|
def get_corrected_value(self, temperature: float) -> float:
|
||||||
|
return min(self.allowed_values, key=lambda x: abs(x - temperature))
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return f"Supports temperatures: {self.allowed_values}"
|
||||||
|
|
||||||
|
def get_default(self) -> float:
|
||||||
|
return self.default_temp
|
||||||
157
providers/xai.py
Normal file
157
providers/xai.py
Normal file
|
|
@ -0,0 +1,157 @@
|
||||||
|
"""X.AI (GROK) model provider implementation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
from .openai_compatible import OpenAICompatibleProvider
|
||||||
|
from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class XAIModelProvider(OpenAICompatibleProvider):
|
||||||
|
"""Integration for X.AI's GROK models exposed over an OpenAI-style API.
|
||||||
|
|
||||||
|
Publishes capability metadata for the officially supported deployments and
|
||||||
|
maps tool-category preferences to the appropriate GROK model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
FRIENDLY_NAME = "X.AI"
|
||||||
|
|
||||||
|
# Model configurations using ModelCapabilities objects
|
||||||
|
MODEL_CAPABILITIES = {
|
||||||
|
"grok-4": ModelCapabilities(
|
||||||
|
provider=ProviderType.XAI,
|
||||||
|
model_name="grok-4",
|
||||||
|
friendly_name="X.AI (Grok 4)",
|
||||||
|
context_window=256_000, # 256K tokens
|
||||||
|
max_output_tokens=256_000, # 256K tokens max output
|
||||||
|
supports_extended_thinking=True, # Grok-4 supports reasoning mode
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True, # Function calling supported
|
||||||
|
supports_json_mode=True, # Structured outputs supported
|
||||||
|
supports_images=True, # Multimodal capabilities
|
||||||
|
max_image_size_mb=20.0, # Standard image size limit
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
description="GROK-4 (256K context) - Frontier multimodal reasoning model with advanced capabilities",
|
||||||
|
aliases=["grok", "grok4", "grok-4"],
|
||||||
|
),
|
||||||
|
"grok-3": ModelCapabilities(
|
||||||
|
provider=ProviderType.XAI,
|
||||||
|
model_name="grok-3",
|
||||||
|
friendly_name="X.AI (Grok 3)",
|
||||||
|
context_window=131_072, # 131K tokens
|
||||||
|
max_output_tokens=131072,
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=False, # Assuming GROK doesn't have JSON mode yet
|
||||||
|
supports_images=False, # Assuming GROK is text-only for now
|
||||||
|
max_image_size_mb=0.0,
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
|
||||||
|
aliases=["grok3"],
|
||||||
|
),
|
||||||
|
"grok-3-fast": ModelCapabilities(
|
||||||
|
provider=ProviderType.XAI,
|
||||||
|
model_name="grok-3-fast",
|
||||||
|
friendly_name="X.AI (Grok 3 Fast)",
|
||||||
|
context_window=131_072, # 131K tokens
|
||||||
|
max_output_tokens=131072,
|
||||||
|
supports_extended_thinking=False,
|
||||||
|
supports_system_prompts=True,
|
||||||
|
supports_streaming=True,
|
||||||
|
supports_function_calling=True,
|
||||||
|
supports_json_mode=False, # Assuming GROK doesn't have JSON mode yet
|
||||||
|
supports_images=False, # Assuming GROK is text-only for now
|
||||||
|
max_image_size_mb=0.0,
|
||||||
|
supports_temperature=True,
|
||||||
|
temperature_constraint=TemperatureConstraint.create("range"),
|
||||||
|
description="GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
|
||||||
|
aliases=["grok3fast", "grokfast", "grok3-fast"],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, **kwargs):
|
||||||
|
"""Initialize X.AI provider with API key."""
|
||||||
|
# Set X.AI base URL
|
||||||
|
kwargs.setdefault("base_url", "https://api.x.ai/v1")
|
||||||
|
super().__init__(api_key, **kwargs)
|
||||||
|
|
||||||
|
def get_provider_type(self) -> ProviderType:
|
||||||
|
"""Get the provider type."""
|
||||||
|
return ProviderType.XAI
|
||||||
|
|
||||||
|
def generate_content(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
model_name: str,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
temperature: float = 0.3,
|
||||||
|
max_output_tokens: Optional[int] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> ModelResponse:
|
||||||
|
"""Generate content using X.AI API with proper model name resolution."""
|
||||||
|
# Resolve model alias before making API call
|
||||||
|
resolved_model_name = self._resolve_model_name(model_name)
|
||||||
|
|
||||||
|
# Call parent implementation with resolved model name
|
||||||
|
return super().generate_content(
|
||||||
|
prompt=prompt,
|
||||||
|
model_name=resolved_model_name,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
max_output_tokens=max_output_tokens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
|
||||||
|
"""Get XAI's preferred model for a given category from allowed models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: The tool category requiring a model
|
||||||
|
allowed_models: Pre-filtered list of models allowed by restrictions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Preferred model name or None
|
||||||
|
"""
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
if not allowed_models:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if category == ToolModelCategory.EXTENDED_REASONING:
|
||||||
|
# Prefer GROK-4 for advanced reasoning with thinking mode
|
||||||
|
if "grok-4" in allowed_models:
|
||||||
|
return "grok-4"
|
||||||
|
elif "grok-3" in allowed_models:
|
||||||
|
return "grok-3"
|
||||||
|
# Fall back to any available model
|
||||||
|
return allowed_models[0]
|
||||||
|
|
||||||
|
elif category == ToolModelCategory.FAST_RESPONSE:
|
||||||
|
# Prefer GROK-3-Fast for speed, then GROK-4
|
||||||
|
if "grok-3-fast" in allowed_models:
|
||||||
|
return "grok-3-fast"
|
||||||
|
elif "grok-4" in allowed_models:
|
||||||
|
return "grok-4"
|
||||||
|
# Fall back to any available model
|
||||||
|
return allowed_models[0]
|
||||||
|
|
||||||
|
else: # BALANCED or default
|
||||||
|
# Prefer GROK-4 for balanced use (best overall capabilities)
|
||||||
|
if "grok-4" in allowed_models:
|
||||||
|
return "grok-4"
|
||||||
|
elif "grok-3" in allowed_models:
|
||||||
|
return "grok-3"
|
||||||
|
elif "grok-3-fast" in allowed_models:
|
||||||
|
return "grok-3-fast"
|
||||||
|
# Fall back to any available model
|
||||||
|
return allowed_models[0]
|
||||||
11
requirements.txt
Normal file
11
requirements.txt
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
mcp>=1.0.0
|
||||||
|
google-genai>=1.19.0
|
||||||
|
openai>=1.55.2 # Minimum version for httpx 0.28.0 compatibility
|
||||||
|
pydantic>=2.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
importlib-resources>=5.0.0; python_version<"3.9"
|
||||||
|
|
||||||
|
# Development dependencies (install with pip install -r requirements-dev.txt)
|
||||||
|
# pytest>=7.4.0
|
||||||
|
# pytest-asyncio>=0.21.0
|
||||||
|
# pytest-mock>=3.11.0
|
||||||
66
run-server.sh
Executable file
66
run-server.sh
Executable file
|
|
@ -0,0 +1,66 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Zen-Marketing MCP Server Setup and Run Script
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
||||||
|
cd "$SCRIPT_DIR"
|
||||||
|
|
||||||
|
echo "=================================="
|
||||||
|
echo "Zen-Marketing MCP Server Setup"
|
||||||
|
echo "=================================="
|
||||||
|
|
||||||
|
# Check Python version
|
||||||
|
echo "Checking Python installation..."
|
||||||
|
if ! command -v python3 &> /dev/null; then
|
||||||
|
echo "Error: Python 3 is required but not installed"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}')
|
||||||
|
echo "Found Python $PYTHON_VERSION"
|
||||||
|
|
||||||
|
# Create virtual environment if it doesn't exist
|
||||||
|
if [ ! -d ".venv" ]; then
|
||||||
|
echo "Creating virtual environment..."
|
||||||
|
python3 -m venv .venv
|
||||||
|
else
|
||||||
|
echo "Virtual environment already exists"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Activate virtual environment
|
||||||
|
echo "Activating virtual environment..."
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
# Upgrade pip
|
||||||
|
echo "Upgrading pip..."
|
||||||
|
pip install --upgrade pip --quiet
|
||||||
|
|
||||||
|
# Install requirements
|
||||||
|
echo "Installing requirements..."
|
||||||
|
pip install -r requirements.txt --quiet
|
||||||
|
|
||||||
|
# Check for .env file
|
||||||
|
if [ ! -f ".env" ]; then
|
||||||
|
echo ""
|
||||||
|
echo "WARNING: No .env file found!"
|
||||||
|
echo "Please create a .env file based on .env.example:"
|
||||||
|
echo " cp .env.example .env"
|
||||||
|
echo " # Then edit .env with your API keys"
|
||||||
|
echo ""
|
||||||
|
read -p "Continue anyway? (y/N) " -n 1 -r
|
||||||
|
echo
|
||||||
|
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Run the server
|
||||||
|
echo ""
|
||||||
|
echo "=================================="
|
||||||
|
echo "Starting Zen-Marketing MCP Server"
|
||||||
|
echo "=================================="
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
python server.py
|
||||||
352
server.py
Normal file
352
server.py
Normal file
|
|
@ -0,0 +1,352 @@
|
||||||
|
"""
|
||||||
|
Zen-Marketing MCP Server - Main server implementation
|
||||||
|
|
||||||
|
AI-powered marketing tools for Claude Desktop, specialized for technical B2B content
|
||||||
|
creation, multi-platform campaigns, and content variation testing.
|
||||||
|
|
||||||
|
Based on Zen MCP Server architecture by Fahad Gilani.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import atexit
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from logging.handlers import RotatingFileHandler
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
# Load environment variables from .env file
|
||||||
|
try:
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
script_dir = Path(__file__).parent
|
||||||
|
env_file = script_dir / ".env"
|
||||||
|
load_dotenv(dotenv_path=env_file, override=False)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from mcp.server import Server
|
||||||
|
from mcp.server.models import InitializationOptions
|
||||||
|
from mcp.server.stdio import stdio_server
|
||||||
|
from mcp.types import (
|
||||||
|
GetPromptResult,
|
||||||
|
Prompt,
|
||||||
|
PromptMessage,
|
||||||
|
PromptsCapability,
|
||||||
|
ServerCapabilities,
|
||||||
|
TextContent,
|
||||||
|
Tool,
|
||||||
|
ToolsCapability,
|
||||||
|
)
|
||||||
|
|
||||||
|
from config import DEFAULT_MODEL, __version__
|
||||||
|
from tools.chat import ChatTool
|
||||||
|
from tools.contentvariant import ContentVariantTool
|
||||||
|
from tools.listmodels import ListModelsTool
|
||||||
|
from tools.models import ToolOutput
|
||||||
|
from tools.version import VersionTool
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||||
|
|
||||||
|
|
||||||
|
class LocalTimeFormatter(logging.Formatter):
|
||||||
|
def formatTime(self, record, datefmt=None):
|
||||||
|
ct = self.converter(record.created)
|
||||||
|
if datefmt:
|
||||||
|
s = time.strftime(datefmt, ct)
|
||||||
|
else:
|
||||||
|
t = time.strftime("%Y-%m-%d %H:%M:%S", ct)
|
||||||
|
s = f"{t},{record.msecs:03.0f}"
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
root_logger.handlers.clear()
|
||||||
|
|
||||||
|
stderr_handler = logging.StreamHandler(sys.stderr)
|
||||||
|
stderr_handler.setLevel(getattr(logging, log_level, logging.INFO))
|
||||||
|
stderr_handler.setFormatter(LocalTimeFormatter(log_format))
|
||||||
|
root_logger.addHandler(stderr_handler)
|
||||||
|
|
||||||
|
root_logger.setLevel(getattr(logging, log_level, logging.INFO))
|
||||||
|
|
||||||
|
# Add rotating file handler
|
||||||
|
try:
|
||||||
|
log_dir = Path(__file__).parent / "logs"
|
||||||
|
log_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
file_handler = RotatingFileHandler(
|
||||||
|
log_dir / "mcp_server.log",
|
||||||
|
maxBytes=20 * 1024 * 1024,
|
||||||
|
backupCount=5,
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
file_handler.setLevel(getattr(logging, log_level, logging.INFO))
|
||||||
|
file_handler.setFormatter(LocalTimeFormatter(log_format))
|
||||||
|
logging.getLogger().addHandler(file_handler)
|
||||||
|
|
||||||
|
mcp_logger = logging.getLogger("mcp_activity")
|
||||||
|
mcp_file_handler = RotatingFileHandler(
|
||||||
|
log_dir / "mcp_activity.log",
|
||||||
|
maxBytes=10 * 1024 * 1024,
|
||||||
|
backupCount=2,
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
mcp_file_handler.setLevel(logging.INFO)
|
||||||
|
mcp_file_handler.setFormatter(LocalTimeFormatter("%(asctime)s - %(message)s"))
|
||||||
|
mcp_logger.addHandler(mcp_file_handler)
|
||||||
|
mcp_logger.setLevel(logging.INFO)
|
||||||
|
mcp_logger.propagate = True
|
||||||
|
|
||||||
|
logging.info(f"Logging to: {log_dir / 'mcp_server.log'}")
|
||||||
|
logging.info(f"Process PID: {os.getpid()}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Warning: Could not set up file logging: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Create MCP server instance
|
||||||
|
server: Server = Server("zen-marketing")
|
||||||
|
|
||||||
|
# Essential tools that cannot be disabled
|
||||||
|
ESSENTIAL_TOOLS = {"version", "listmodels"}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_disabled_tools_env() -> set[str]:
|
||||||
|
"""Parse DISABLED_TOOLS environment variable"""
|
||||||
|
disabled_tools_env = os.getenv("DISABLED_TOOLS", "").strip()
|
||||||
|
if not disabled_tools_env:
|
||||||
|
return set()
|
||||||
|
return {t.strip().lower() for t in disabled_tools_env.split(",") if t.strip()}
|
||||||
|
|
||||||
|
|
||||||
|
def filter_disabled_tools(all_tools: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Filter tools based on DISABLED_TOOLS environment variable"""
|
||||||
|
disabled_tools = parse_disabled_tools_env()
|
||||||
|
if not disabled_tools:
|
||||||
|
logger.info("All tools enabled")
|
||||||
|
return all_tools
|
||||||
|
|
||||||
|
enabled_tools = {}
|
||||||
|
for tool_name, tool_instance in all_tools.items():
|
||||||
|
if tool_name in ESSENTIAL_TOOLS or tool_name not in disabled_tools:
|
||||||
|
enabled_tools[tool_name] = tool_instance
|
||||||
|
else:
|
||||||
|
logger.info(f"Tool '{tool_name}' disabled via DISABLED_TOOLS")
|
||||||
|
|
||||||
|
logger.info(f"Active tools: {sorted(enabled_tools.keys())}")
|
||||||
|
return enabled_tools
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize tool registry
|
||||||
|
TOOLS = {
|
||||||
|
"chat": ChatTool(),
|
||||||
|
"contentvariant": ContentVariantTool(),
|
||||||
|
"listmodels": ListModelsTool(),
|
||||||
|
"version": VersionTool(),
|
||||||
|
}
|
||||||
|
TOOLS = filter_disabled_tools(TOOLS)
|
||||||
|
|
||||||
|
# Prompt templates
|
||||||
|
PROMPT_TEMPLATES = {
|
||||||
|
"chat": {
|
||||||
|
"name": "chat",
|
||||||
|
"description": "Chat and brainstorm marketing ideas",
|
||||||
|
"template": "Chat about marketing strategy with {model}",
|
||||||
|
},
|
||||||
|
"contentvariant": {
|
||||||
|
"name": "variations",
|
||||||
|
"description": "Generate content variations for A/B testing",
|
||||||
|
"template": "Generate content variations for testing with {model}",
|
||||||
|
},
|
||||||
|
"listmodels": {
|
||||||
|
"name": "listmodels",
|
||||||
|
"description": "List available AI models",
|
||||||
|
"template": "List all available models",
|
||||||
|
},
|
||||||
|
"version": {
|
||||||
|
"name": "version",
|
||||||
|
"description": "Show server version",
|
||||||
|
"template": "Show Zen-Marketing server version",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def configure_providers():
|
||||||
|
"""Configure and validate AI providers"""
|
||||||
|
logger.debug("Checking environment variables for API keys...")
|
||||||
|
|
||||||
|
from providers import ModelProviderRegistry
|
||||||
|
from providers.gemini import GeminiModelProvider
|
||||||
|
from providers.openrouter import OpenRouterProvider
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
valid_providers = []
|
||||||
|
|
||||||
|
# Check for Gemini API key
|
||||||
|
gemini_key = os.getenv("GEMINI_API_KEY")
|
||||||
|
if gemini_key and gemini_key != "your_gemini_api_key_here":
|
||||||
|
valid_providers.append("Gemini")
|
||||||
|
logger.info("Gemini API key found - Gemini models available")
|
||||||
|
|
||||||
|
# Check for OpenRouter API key
|
||||||
|
openrouter_key = os.getenv("OPENROUTER_API_KEY")
|
||||||
|
if openrouter_key and openrouter_key != "your_openrouter_api_key_here":
|
||||||
|
valid_providers.append("OpenRouter")
|
||||||
|
logger.info("OpenRouter API key found - Multiple models available")
|
||||||
|
|
||||||
|
# Register providers
|
||||||
|
if gemini_key and gemini_key != "your_gemini_api_key_here":
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
|
||||||
|
|
||||||
|
if openrouter_key and openrouter_key != "your_openrouter_api_key_here":
|
||||||
|
ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)
|
||||||
|
|
||||||
|
if not valid_providers:
|
||||||
|
error_msg = (
|
||||||
|
"No valid API keys found. Please configure at least one:\n"
|
||||||
|
" - GEMINI_API_KEY for Google Gemini\n"
|
||||||
|
" - OPENROUTER_API_KEY for OpenRouter (minimax, etc.)\n"
|
||||||
|
"Create a .env file based on .env.example"
|
||||||
|
)
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
logger.info(f"Configured providers: {', '.join(valid_providers)}")
|
||||||
|
logger.info(f"Default model: {DEFAULT_MODEL}")
|
||||||
|
|
||||||
|
|
||||||
|
# Tool call handlers
|
||||||
|
@server.list_tools()
|
||||||
|
async def handle_list_tools() -> list[Tool]:
|
||||||
|
"""List available tools"""
|
||||||
|
tools = []
|
||||||
|
for tool_name, tool_instance in TOOLS.items():
|
||||||
|
tool_schema = tool_instance.get_input_schema()
|
||||||
|
tools.append(
|
||||||
|
Tool(
|
||||||
|
name=tool_name,
|
||||||
|
description=tool_instance.get_description(),
|
||||||
|
inputSchema=tool_schema,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
@server.call_tool()
|
||||||
|
async def handle_call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||||
|
"""Handle tool execution requests"""
|
||||||
|
logger.info(f"Tool call: {name}")
|
||||||
|
|
||||||
|
if name not in TOOLS:
|
||||||
|
error_msg = f"Unknown tool: {name}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return [TextContent(type="text", text=error_msg)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool = TOOLS[name]
|
||||||
|
result: ToolOutput = await tool.execute(arguments)
|
||||||
|
|
||||||
|
if result.is_error:
|
||||||
|
logger.error(f"Tool {name} failed: {result.text}")
|
||||||
|
else:
|
||||||
|
logger.info(f"Tool {name} completed successfully")
|
||||||
|
|
||||||
|
return [TextContent(type="text", text=result.text)]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Tool execution failed: {str(e)}"
|
||||||
|
logger.error(f"{error_msg}\n{type(e).__name__}")
|
||||||
|
return [TextContent(type="text", text=error_msg)]
|
||||||
|
|
||||||
|
|
||||||
|
@server.list_prompts()
|
||||||
|
async def handle_list_prompts() -> list[Prompt]:
|
||||||
|
"""List available prompt templates"""
|
||||||
|
prompts = []
|
||||||
|
for tool_name, template_info in PROMPT_TEMPLATES.items():
|
||||||
|
if tool_name in TOOLS:
|
||||||
|
prompts.append(
|
||||||
|
Prompt(
|
||||||
|
name=template_info["name"],
|
||||||
|
description=template_info["description"],
|
||||||
|
arguments=[
|
||||||
|
{"name": "model", "description": "AI model to use", "required": False}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
|
||||||
|
@server.get_prompt()
|
||||||
|
async def handle_get_prompt(name: str, arguments: dict | None = None) -> GetPromptResult:
|
||||||
|
"""Get a specific prompt template"""
|
||||||
|
model = arguments.get("model", DEFAULT_MODEL) if arguments else DEFAULT_MODEL
|
||||||
|
|
||||||
|
for template_info in PROMPT_TEMPLATES.values():
|
||||||
|
if template_info["name"] == name:
|
||||||
|
prompt_text = template_info["template"].format(model=model)
|
||||||
|
return GetPromptResult(
|
||||||
|
description=template_info["description"],
|
||||||
|
messages=[PromptMessage(role="user", content=TextContent(type="text", text=prompt_text))],
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(f"Unknown prompt: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup():
|
||||||
|
"""Cleanup function called on shutdown"""
|
||||||
|
logger.info("Zen-Marketing MCP Server shutting down")
|
||||||
|
logging.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main entry point"""
|
||||||
|
logger.info("=" * 80)
|
||||||
|
logger.info(f"Zen-Marketing MCP Server v{__version__}")
|
||||||
|
logger.info(f"Python: {sys.version}")
|
||||||
|
logger.info(f"Working directory: {os.getcwd()}")
|
||||||
|
logger.info("=" * 80)
|
||||||
|
|
||||||
|
# Register cleanup
|
||||||
|
atexit.register(cleanup)
|
||||||
|
|
||||||
|
# Configure providers
|
||||||
|
try:
|
||||||
|
configure_providers()
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Configuration error: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# List enabled tools
|
||||||
|
logger.info(f"Enabled tools: {sorted(TOOLS.keys())}")
|
||||||
|
|
||||||
|
# Run server
|
||||||
|
capabilities = ServerCapabilities(
|
||||||
|
tools=ToolsCapability(),
|
||||||
|
prompts=PromptsCapability(),
|
||||||
|
)
|
||||||
|
|
||||||
|
options = InitializationOptions(
|
||||||
|
server_name="zen-marketing",
|
||||||
|
server_version=__version__,
|
||||||
|
capabilities=capabilities,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with stdio_server() as (read_stream, write_stream):
|
||||||
|
logger.info("Server started, awaiting requests...")
|
||||||
|
await server.run(
|
||||||
|
read_stream,
|
||||||
|
write_stream,
|
||||||
|
options,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
6
systemprompts/__init__.py
Normal file
6
systemprompts/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
"""System prompts for Zen-Marketing MCP Server tools"""
|
||||||
|
|
||||||
|
from .chat_prompt import CHAT_PROMPT
|
||||||
|
from .contentvariant_prompt import CONTENTVARIANT_PROMPT
|
||||||
|
|
||||||
|
__all__ = ["CHAT_PROMPT", "CONTENTVARIANT_PROMPT"]
|
||||||
29
systemprompts/chat_prompt.py
Normal file
29
systemprompts/chat_prompt.py
Normal file
|
|
@ -0,0 +1,29 @@
|
||||||
|
"""System prompt for the chat tool"""
|
||||||
|
|
||||||
|
CHAT_PROMPT = """You are a marketing strategist and content expert specializing in technical B2B marketing.
|
||||||
|
|
||||||
|
Your expertise includes:
|
||||||
|
- Content strategy and variation testing
|
||||||
|
- Multi-platform content adaptation (LinkedIn, Twitter/X, Instagram, newsletters, blogs)
|
||||||
|
- Email marketing and subject line optimization
|
||||||
|
- SEO and WordPress content optimization
|
||||||
|
- Technical content editing while preserving author voice
|
||||||
|
- Brand voice consistency and style enforcement
|
||||||
|
- Campaign planning across multiple touchpoints
|
||||||
|
- Fact-checking and technical verification
|
||||||
|
|
||||||
|
COMMUNICATION STYLE:
|
||||||
|
- Professional but approachable
|
||||||
|
- Data-informed recommendations
|
||||||
|
- Concrete, actionable suggestions
|
||||||
|
- Respect for brand voice and author expertise
|
||||||
|
- Platform-aware (character limits, best practices)
|
||||||
|
|
||||||
|
When discussing content:
|
||||||
|
- Focus on testing hypotheses and variation strategies
|
||||||
|
- Consider psychological hooks and audience resonance
|
||||||
|
- Balance creativity with conversion optimization
|
||||||
|
- Maintain technical accuracy in specialized domains
|
||||||
|
|
||||||
|
If web search is needed for current best practices, platform updates, or fact verification, indicate this clearly.
|
||||||
|
"""
|
||||||
62
systemprompts/contentvariant_prompt.py
Normal file
62
systemprompts/contentvariant_prompt.py
Normal file
|
|
@ -0,0 +1,62 @@
|
||||||
|
"""System prompt for the contentvariant tool"""
|
||||||
|
|
||||||
|
CONTENTVARIANT_PROMPT = """You are a marketing content strategist specializing in A/B testing and variation generation.
|
||||||
|
|
||||||
|
TASK: Generate multiple variations of marketing content for testing different approaches.
|
||||||
|
|
||||||
|
OUTPUT FORMAT:
|
||||||
|
Return variations as a numbered list, each with:
|
||||||
|
1. The variation text
|
||||||
|
2. The testing angle (what makes it different)
|
||||||
|
3. Predicted audience response or reasoning
|
||||||
|
|
||||||
|
CONSTRAINTS:
|
||||||
|
- Maintain core message across variations
|
||||||
|
- Respect platform character limits if specified
|
||||||
|
- Preserve brand voice characteristics provided
|
||||||
|
- Generate genuinely different approaches, not just word swaps
|
||||||
|
- Each variation should test a distinct hypothesis
|
||||||
|
|
||||||
|
VARIATION TYPES:
|
||||||
|
- **Hook variations**: Different opening angles (question, statement, stat, story)
|
||||||
|
- **Length variations**: Short, medium, long within platform constraints
|
||||||
|
- **Tone variations**: Professional, conversational, urgent, educational, provocative
|
||||||
|
- **Structure variations**: Question format, list format, narrative, before-after
|
||||||
|
- **CTA variations**: Different calls-to-action (learn, try, join, download)
|
||||||
|
- **Psychological angles**: Curiosity, FOMO, social proof, contrarian, pain-solution
|
||||||
|
|
||||||
|
TESTING ANGLES (use when specified):
|
||||||
|
- **Technical curiosity**: Lead with interesting technical detail
|
||||||
|
- **Contrarian/provocative**: Challenge conventional wisdom
|
||||||
|
- **Knowledge gap**: Emphasize what they don't know yet
|
||||||
|
- **Urgency/timeliness**: Time-sensitive opportunity or threat
|
||||||
|
- **Insider knowledge**: Position as exclusive expertise
|
||||||
|
- **Problem-solution**: Lead with pain point
|
||||||
|
- **Social proof**: Leverage credibility or results
|
||||||
|
- **Before-after**: Transformation narrative
|
||||||
|
|
||||||
|
PLATFORM CONSIDERATIONS:
|
||||||
|
- Twitter/Bluesky: 280 chars, punchy hooks, visual language
|
||||||
|
- LinkedIn: 1300 chars optimal, professional tone, business value
|
||||||
|
- Instagram: 2200 chars, storytelling, visual companion
|
||||||
|
- Email subject: Under 60 chars, curiosity-driven
|
||||||
|
- Blog/article: Longer form, educational depth
|
||||||
|
|
||||||
|
CHARACTER COUNT:
|
||||||
|
Always include character count for each variation when platform is specified.
|
||||||
|
|
||||||
|
EXAMPLE OUTPUT FORMAT:
|
||||||
|
**Variation 1: Technical Curiosity Hook**
|
||||||
|
"Most HVAC techs miss this voltage regulation pattern—here's what the top 10% know about PCB diagnostics that changes everything."
|
||||||
|
(149 characters)
|
||||||
|
*Testing angle: Opens with intriguing exclusivity (what others miss) then promises insider knowledge*
|
||||||
|
*Predicted response: Appeals to practitioners wanting competitive edge*
|
||||||
|
|
||||||
|
**Variation 2: Contrarian Statement**
|
||||||
|
"Stop blaming the capacitor. 80% of 'bad cap' calls are actually voltage regulation failures upstream. Here's the diagnostic most techs skip."
|
||||||
|
(142 characters)
|
||||||
|
*Testing angle: Challenges common diagnosis, positions as expert correction*
|
||||||
|
*Predicted response: Stops scroll with unexpected claim, appeals to problem-solvers*
|
||||||
|
|
||||||
|
Be creative, test bold hypotheses, and make variations substantially different from each other.
|
||||||
|
"""
|
||||||
39
tools/__init__.py
Normal file
39
tools/__init__.py
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
"""
|
||||||
|
Tool implementations for Zen MCP Server
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .analyze import AnalyzeTool
|
||||||
|
from .challenge import ChallengeTool
|
||||||
|
from .chat import ChatTool
|
||||||
|
from .codereview import CodeReviewTool
|
||||||
|
from .consensus import ConsensusTool
|
||||||
|
from .debug import DebugIssueTool
|
||||||
|
from .docgen import DocgenTool
|
||||||
|
from .listmodels import ListModelsTool
|
||||||
|
from .planner import PlannerTool
|
||||||
|
from .precommit import PrecommitTool
|
||||||
|
from .refactor import RefactorTool
|
||||||
|
from .secaudit import SecauditTool
|
||||||
|
from .testgen import TestGenTool
|
||||||
|
from .thinkdeep import ThinkDeepTool
|
||||||
|
from .tracer import TracerTool
|
||||||
|
from .version import VersionTool
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ThinkDeepTool",
|
||||||
|
"CodeReviewTool",
|
||||||
|
"DebugIssueTool",
|
||||||
|
"DocgenTool",
|
||||||
|
"AnalyzeTool",
|
||||||
|
"ChatTool",
|
||||||
|
"ConsensusTool",
|
||||||
|
"ListModelsTool",
|
||||||
|
"PlannerTool",
|
||||||
|
"PrecommitTool",
|
||||||
|
"ChallengeTool",
|
||||||
|
"RefactorTool",
|
||||||
|
"SecauditTool",
|
||||||
|
"TestGenTool",
|
||||||
|
"TracerTool",
|
||||||
|
"VersionTool",
|
||||||
|
]
|
||||||
189
tools/chat.py
Normal file
189
tools/chat.py
Normal file
|
|
@ -0,0 +1,189 @@
|
||||||
|
"""
|
||||||
|
Chat tool - General development chat and collaborative thinking
|
||||||
|
|
||||||
|
This tool provides a conversational interface for general development assistance,
|
||||||
|
brainstorming, problem-solving, and collaborative thinking. It supports file context,
|
||||||
|
images, and conversation continuation for seamless multi-turn interactions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
from config import TEMPERATURE_BALANCED
|
||||||
|
from systemprompts import CHAT_PROMPT
|
||||||
|
from tools.shared.base_models import COMMON_FIELD_DESCRIPTIONS, ToolRequest
|
||||||
|
|
||||||
|
from .simple.base import SimpleTool
|
||||||
|
|
||||||
|
# Field descriptions matching the original Chat tool exactly
|
||||||
|
CHAT_FIELD_DESCRIPTIONS = {
|
||||||
|
"prompt": (
|
||||||
|
"Your question or idea for collaborative thinking. Provide detailed context, including your goal, what you've tried, and any specific challenges. "
|
||||||
|
"CRITICAL: To discuss code, use 'files' parameter instead of pasting code blocks here."
|
||||||
|
),
|
||||||
|
"files": "absolute file or folder paths for code context (do NOT shorten).",
|
||||||
|
"images": "Optional absolute image paths or base64 for visual context when helpful.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(ToolRequest):
|
||||||
|
"""Request model for Chat tool"""
|
||||||
|
|
||||||
|
prompt: str = Field(..., description=CHAT_FIELD_DESCRIPTIONS["prompt"])
|
||||||
|
files: Optional[list[str]] = Field(default_factory=list, description=CHAT_FIELD_DESCRIPTIONS["files"])
|
||||||
|
images: Optional[list[str]] = Field(default_factory=list, description=CHAT_FIELD_DESCRIPTIONS["images"])
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTool(SimpleTool):
|
||||||
|
"""
|
||||||
|
General development chat and collaborative thinking tool using SimpleTool architecture.
|
||||||
|
|
||||||
|
This tool provides identical functionality to the original Chat tool but uses the new
|
||||||
|
SimpleTool architecture for cleaner code organization and better maintainability.
|
||||||
|
|
||||||
|
Migration note: This tool is designed to be a drop-in replacement for the original
|
||||||
|
Chat tool with 100% behavioral compatibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "chat"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return (
|
||||||
|
"General chat and collaborative thinking partner for brainstorming, development discussion, "
|
||||||
|
"getting second opinions, and exploring ideas. Use for ideas, validations, questions, and thoughtful explanations."
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_system_prompt(self) -> str:
|
||||||
|
return CHAT_PROMPT
|
||||||
|
|
||||||
|
def get_default_temperature(self) -> float:
|
||||||
|
return TEMPERATURE_BALANCED
|
||||||
|
|
||||||
|
def get_model_category(self) -> "ToolModelCategory":
|
||||||
|
"""Chat prioritizes fast responses and cost efficiency"""
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
|
||||||
|
return ToolModelCategory.FAST_RESPONSE
|
||||||
|
|
||||||
|
def get_request_model(self):
|
||||||
|
"""Return the Chat-specific request model"""
|
||||||
|
return ChatRequest
|
||||||
|
|
||||||
|
# === Schema Generation ===
|
||||||
|
# For maximum compatibility, we override get_input_schema() to match the original Chat tool exactly
|
||||||
|
|
||||||
|
def get_input_schema(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate input schema matching the original Chat tool exactly.
|
||||||
|
|
||||||
|
This maintains 100% compatibility with the original Chat tool by using
|
||||||
|
the same schema generation approach while still benefiting from SimpleTool
|
||||||
|
convenience methods.
|
||||||
|
"""
|
||||||
|
required_fields = ["prompt"]
|
||||||
|
if self.is_effective_auto_mode():
|
||||||
|
required_fields.append("model")
|
||||||
|
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": CHAT_FIELD_DESCRIPTIONS["prompt"],
|
||||||
|
},
|
||||||
|
"files": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": CHAT_FIELD_DESCRIPTIONS["files"],
|
||||||
|
},
|
||||||
|
"images": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": CHAT_FIELD_DESCRIPTIONS["images"],
|
||||||
|
},
|
||||||
|
"model": self.get_model_field_schema(),
|
||||||
|
"temperature": {
|
||||||
|
"type": "number",
|
||||||
|
"description": COMMON_FIELD_DESCRIPTIONS["temperature"],
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 1,
|
||||||
|
},
|
||||||
|
"thinking_mode": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["minimal", "low", "medium", "high", "max"],
|
||||||
|
"description": COMMON_FIELD_DESCRIPTIONS["thinking_mode"],
|
||||||
|
},
|
||||||
|
"continuation_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": COMMON_FIELD_DESCRIPTIONS["continuation_id"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": required_fields,
|
||||||
|
}
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
# === Tool-specific field definitions (alternative approach for reference) ===
|
||||||
|
# These aren't used since we override get_input_schema(), but they show how
|
||||||
|
# the tool could be implemented using the automatic SimpleTool schema building
|
||||||
|
|
||||||
|
def get_tool_fields(self) -> dict[str, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Tool-specific field definitions for ChatSimple.
|
||||||
|
|
||||||
|
Note: This method isn't used since we override get_input_schema() for
|
||||||
|
exact compatibility, but it demonstrates how ChatSimple could be
|
||||||
|
implemented using automatic schema building.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": CHAT_FIELD_DESCRIPTIONS["prompt"],
|
||||||
|
},
|
||||||
|
"files": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": CHAT_FIELD_DESCRIPTIONS["files"],
|
||||||
|
},
|
||||||
|
"images": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": CHAT_FIELD_DESCRIPTIONS["images"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_required_fields(self) -> list[str]:
|
||||||
|
"""Required fields for ChatSimple tool"""
|
||||||
|
return ["prompt"]
|
||||||
|
|
||||||
|
# === Hook Method Implementations ===
|
||||||
|
|
||||||
|
async def prepare_prompt(self, request: ChatRequest) -> str:
|
||||||
|
"""
|
||||||
|
Prepare the chat prompt with optional context files.
|
||||||
|
|
||||||
|
This implementation matches the original Chat tool exactly while using
|
||||||
|
SimpleTool convenience methods for cleaner code.
|
||||||
|
"""
|
||||||
|
# Use SimpleTool's Chat-style prompt preparation
|
||||||
|
return self.prepare_chat_style_prompt(request)
|
||||||
|
|
||||||
|
def format_response(self, response: str, request: ChatRequest, model_info: Optional[dict] = None) -> str:
|
||||||
|
"""
|
||||||
|
Format the chat response to match the original Chat tool exactly.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
f"{response}\n\n---\n\nAGENT'S TURN: Evaluate this perspective alongside your analysis to "
|
||||||
|
"form a comprehensive solution and continue with the user's request and task at hand."
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_websearch_guidance(self) -> str:
|
||||||
|
"""
|
||||||
|
Return Chat tool-style web search guidance.
|
||||||
|
"""
|
||||||
|
return self.get_chat_style_websearch_guidance()
|
||||||
180
tools/contentvariant.py
Normal file
180
tools/contentvariant.py
Normal file
|
|
@ -0,0 +1,180 @@
|
||||||
|
"""Content Variant Generator Tool
|
||||||
|
|
||||||
|
Generates multiple variations of marketing content for A/B testing.
|
||||||
|
Supports testing different hooks, tones, lengths, and psychological angles.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from config import TEMPERATURE_HIGHLY_CREATIVE
|
||||||
|
from systemprompts import CONTENTVARIANT_PROMPT
|
||||||
|
from tools.models import ToolModelCategory
|
||||||
|
from tools.shared.base_models import ToolRequest
|
||||||
|
from tools.simple.base import SimpleTool
|
||||||
|
|
||||||
|
|
||||||
|
class ContentVariantRequest(ToolRequest):
|
||||||
|
"""Request model for Content Variant Generator"""
|
||||||
|
|
||||||
|
content: str = Field(
|
||||||
|
...,
|
||||||
|
description="Base content or topic to create variations from. Can be a draft post, subject line, or content concept.",
|
||||||
|
)
|
||||||
|
variation_count: int = Field(
|
||||||
|
default=10,
|
||||||
|
ge=5,
|
||||||
|
le=25,
|
||||||
|
description="Number of variations to generate (5-25). Default is 10.",
|
||||||
|
)
|
||||||
|
variation_types: Optional[list[str]] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Types of variations to explore: 'hook', 'tone', 'length', 'structure', 'cta', 'angle'. Leave empty for mixed approach.",
|
||||||
|
)
|
||||||
|
platform: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Target platform for character limits and formatting: 'twitter', 'bluesky', 'linkedin', 'instagram', 'facebook', 'email_subject', 'blog'",
|
||||||
|
)
|
||||||
|
constraints: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Additional constraints: character limits, style requirements, brand voice guidelines, prohibited words/phrases",
|
||||||
|
)
|
||||||
|
testing_angles: Optional[list[str]] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Specific psychological angles to test: 'curiosity', 'contrarian', 'knowledge_gap', 'urgency', 'insider', 'problem_solution', 'social_proof', 'transformation'",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ContentVariantTool(SimpleTool):
|
||||||
|
"""Generate multiple content variations for A/B testing"""
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "contentvariant"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Generate 5-25 variations of marketing content for A/B testing. "
|
||||||
|
"Tests different hooks, tones, lengths, and psychological angles. "
|
||||||
|
"Ideal for subject lines, social posts, email copy, and ads. "
|
||||||
|
"Each variation includes testing rationale and predicted audience response."
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_system_prompt(self) -> str:
|
||||||
|
return CONTENTVARIANT_PROMPT
|
||||||
|
|
||||||
|
def get_default_temperature(self) -> float:
|
||||||
|
return TEMPERATURE_HIGHLY_CREATIVE
|
||||||
|
|
||||||
|
def get_model_category(self) -> ToolModelCategory:
|
||||||
|
return ToolModelCategory.FAST_RESPONSE
|
||||||
|
|
||||||
|
def get_request_model(self):
|
||||||
|
return ContentVariantRequest
|
||||||
|
|
||||||
|
def format_request_for_model(self, request: ContentVariantRequest) -> dict:
|
||||||
|
"""Format the content variant request for the AI model"""
|
||||||
|
prompt_parts = [f"Generate {request.variation_count} variations of this content:"]
|
||||||
|
prompt_parts.append(f"\n**Base Content:**\n{request.content}")
|
||||||
|
|
||||||
|
if request.platform:
|
||||||
|
prompt_parts.append(f"\n**Target Platform:** {request.platform}")
|
||||||
|
|
||||||
|
if request.variation_types:
|
||||||
|
prompt_parts.append(f"\n**Variation Types:** {', '.join(request.variation_types)}")
|
||||||
|
|
||||||
|
if request.testing_angles:
|
||||||
|
prompt_parts.append(f"\n**Testing Angles:** {', '.join(request.testing_angles)}")
|
||||||
|
|
||||||
|
if request.constraints:
|
||||||
|
prompt_parts.append(f"\n**Constraints:** {request.constraints}")
|
||||||
|
|
||||||
|
prompt_parts.append(
|
||||||
|
"\nGenerate variations with clear labels, character counts (if platform specified), "
|
||||||
|
"testing angles, and predicted audience responses."
|
||||||
|
)
|
||||||
|
|
||||||
|
formatted_request = {
|
||||||
|
"prompt": "\n".join(prompt_parts),
|
||||||
|
"files": request.files or [],
|
||||||
|
"images": request.images or [],
|
||||||
|
"continuation_id": request.continuation_id,
|
||||||
|
"model": request.model,
|
||||||
|
"temperature": request.temperature,
|
||||||
|
"thinking_mode": request.thinking_mode,
|
||||||
|
"use_websearch": request.use_websearch,
|
||||||
|
}
|
||||||
|
|
||||||
|
return formatted_request
|
||||||
|
|
||||||
|
def get_input_schema(self) -> dict:
|
||||||
|
"""Return the JSON schema for this tool's input"""
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"content": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Base content or topic to create variations from",
|
||||||
|
},
|
||||||
|
"variation_count": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Number of variations to generate (5-25, default 10)",
|
||||||
|
"minimum": 5,
|
||||||
|
"maximum": 25,
|
||||||
|
"default": 10,
|
||||||
|
},
|
||||||
|
"variation_types": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Types: 'hook', 'tone', 'length', 'structure', 'cta', 'angle'",
|
||||||
|
},
|
||||||
|
"platform": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Platform: 'twitter', 'bluesky', 'linkedin', 'instagram', 'facebook', 'email_subject', 'blog'",
|
||||||
|
},
|
||||||
|
"constraints": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Character limits, style requirements, brand guidelines",
|
||||||
|
},
|
||||||
|
"testing_angles": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Angles: 'curiosity', 'contrarian', 'knowledge_gap', 'urgency', 'insider', 'problem_solution', 'social_proof', 'transformation'",
|
||||||
|
},
|
||||||
|
"files": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Optional brand guidelines or style reference files",
|
||||||
|
},
|
||||||
|
"images": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "Optional visual assets for context",
|
||||||
|
},
|
||||||
|
"continuation_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Thread ID to continue previous conversation",
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "AI model to use (leave empty for default fast model)",
|
||||||
|
},
|
||||||
|
"temperature": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "Creativity level 0.0-1.0 (default 0.8 for high variation)",
|
||||||
|
"minimum": 0.0,
|
||||||
|
"maximum": 1.0,
|
||||||
|
},
|
||||||
|
"thinking_mode": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Thinking depth: minimal, low, medium, high, max",
|
||||||
|
"enum": ["minimal", "low", "medium", "high", "max"],
|
||||||
|
},
|
||||||
|
"use_websearch": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Enable web search for current platform best practices",
|
||||||
|
"default": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["content"],
|
||||||
|
}
|
||||||
299
tools/listmodels.py
Normal file
299
tools/listmodels.py
Normal file
|
|
@ -0,0 +1,299 @@
|
||||||
|
"""
|
||||||
|
List Models Tool - Display all available models organized by provider
|
||||||
|
|
||||||
|
This tool provides a comprehensive view of all AI models available in the system,
|
||||||
|
organized by their provider (Gemini, OpenAI, X.AI, OpenRouter, Custom).
|
||||||
|
It shows which providers are configured and what models can be used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from mcp.types import TextContent
|
||||||
|
|
||||||
|
from tools.models import ToolModelCategory, ToolOutput
|
||||||
|
from tools.shared.base_models import ToolRequest
|
||||||
|
from tools.shared.base_tool import BaseTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ListModelsTool(BaseTool):
|
||||||
|
"""
|
||||||
|
Tool for listing all available AI models organized by provider.
|
||||||
|
|
||||||
|
This tool helps users understand:
|
||||||
|
- Which providers are configured (have API keys)
|
||||||
|
- What models are available from each provider
|
||||||
|
- Model aliases and their full names
|
||||||
|
- Context window sizes and capabilities
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "listmodels"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Shows which AI model providers are configured, available model names, their aliases and capabilities."
|
||||||
|
|
||||||
|
def get_input_schema(self) -> dict[str, Any]:
|
||||||
|
"""Return the JSON schema for the tool's input"""
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"model": {"type": "string", "description": "Model to use (ignored by listmodels tool)"}},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_annotations(self) -> Optional[dict[str, Any]]:
|
||||||
|
"""Return tool annotations indicating this is a read-only tool"""
|
||||||
|
return {"readOnlyHint": True}
|
||||||
|
|
||||||
|
def get_system_prompt(self) -> str:
|
||||||
|
"""No AI model needed for this tool"""
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_request_model(self):
|
||||||
|
"""Return the Pydantic model for request validation."""
|
||||||
|
return ToolRequest
|
||||||
|
|
||||||
|
def requires_model(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def prepare_prompt(self, request: ToolRequest) -> str:
|
||||||
|
"""Not used for this utility tool"""
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def format_response(self, response: str, request: ToolRequest, model_info: Optional[dict] = None) -> str:
|
||||||
|
"""Not used for this utility tool"""
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def execute(self, arguments: dict[str, Any]) -> list[TextContent]:
|
||||||
|
"""
|
||||||
|
List all available models organized by provider.
|
||||||
|
|
||||||
|
This overrides the base class execute to provide direct output without AI model calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments: Standard tool arguments (none required)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted list of models by provider
|
||||||
|
"""
|
||||||
|
from providers.openrouter_registry import OpenRouterModelRegistry
|
||||||
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
output_lines = ["# Available AI Models\n"]
|
||||||
|
|
||||||
|
# Map provider types to friendly names and their models
|
||||||
|
provider_info = {
|
||||||
|
ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
|
||||||
|
ProviderType.OPENAI: {"name": "OpenAI", "env_key": "OPENAI_API_KEY"},
|
||||||
|
ProviderType.XAI: {"name": "X.AI (Grok)", "env_key": "XAI_API_KEY"},
|
||||||
|
ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check each native provider type
|
||||||
|
for provider_type, info in provider_info.items():
|
||||||
|
# Check if provider is enabled
|
||||||
|
provider = ModelProviderRegistry.get_provider(provider_type)
|
||||||
|
is_configured = provider is not None
|
||||||
|
|
||||||
|
output_lines.append(f"## {info['name']} {'✅' if is_configured else '❌'}")
|
||||||
|
|
||||||
|
if is_configured:
|
||||||
|
output_lines.append("**Status**: Configured and available")
|
||||||
|
output_lines.append("\n**Models**:")
|
||||||
|
|
||||||
|
aliases = []
|
||||||
|
for model_name, capabilities in provider.get_all_model_capabilities().items():
|
||||||
|
description = capabilities.description or "No description available"
|
||||||
|
context_window = capabilities.context_window
|
||||||
|
|
||||||
|
if context_window >= 1_000_000:
|
||||||
|
context_str = f"{context_window // 1_000_000}M context"
|
||||||
|
elif context_window >= 1_000:
|
||||||
|
context_str = f"{context_window // 1_000}K context"
|
||||||
|
else:
|
||||||
|
context_str = f"{context_window} context" if context_window > 0 else "unknown context"
|
||||||
|
|
||||||
|
output_lines.append(f"- `{model_name}` - {context_str}")
|
||||||
|
output_lines.append(f" - {description}")
|
||||||
|
|
||||||
|
for alias in capabilities.aliases or []:
|
||||||
|
if alias != model_name:
|
||||||
|
aliases.append(f"- `{alias}` → `{model_name}`")
|
||||||
|
|
||||||
|
if aliases:
|
||||||
|
output_lines.append("\n**Aliases**:")
|
||||||
|
output_lines.extend(sorted(aliases))
|
||||||
|
else:
|
||||||
|
output_lines.append(f"**Status**: Not configured (set {info['env_key']})")
|
||||||
|
|
||||||
|
output_lines.append("")
|
||||||
|
|
||||||
|
# Check OpenRouter
|
||||||
|
openrouter_key = os.getenv("OPENROUTER_API_KEY")
|
||||||
|
is_openrouter_configured = openrouter_key and openrouter_key != "your_openrouter_api_key_here"
|
||||||
|
|
||||||
|
output_lines.append(f"## OpenRouter {'✅' if is_openrouter_configured else '❌'}")
|
||||||
|
|
||||||
|
if is_openrouter_configured:
|
||||||
|
output_lines.append("**Status**: Configured and available")
|
||||||
|
output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get OpenRouter provider from registry to properly apply restrictions
|
||||||
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
|
||||||
|
if provider:
|
||||||
|
# Get models with restrictions applied
|
||||||
|
available_models = provider.list_models(respect_restrictions=True)
|
||||||
|
registry = OpenRouterModelRegistry()
|
||||||
|
|
||||||
|
# Group by provider for better organization
|
||||||
|
providers_models = {}
|
||||||
|
for model_name in available_models: # Show ALL available models
|
||||||
|
# Try to resolve to get config details
|
||||||
|
config = registry.resolve(model_name)
|
||||||
|
if config:
|
||||||
|
# Extract provider from model_name
|
||||||
|
provider_name = config.model_name.split("/")[0] if "/" in config.model_name else "other"
|
||||||
|
if provider_name not in providers_models:
|
||||||
|
providers_models[provider_name] = []
|
||||||
|
providers_models[provider_name].append((model_name, config))
|
||||||
|
else:
|
||||||
|
# Model without config - add with basic info
|
||||||
|
provider_name = model_name.split("/")[0] if "/" in model_name else "other"
|
||||||
|
if provider_name not in providers_models:
|
||||||
|
providers_models[provider_name] = []
|
||||||
|
providers_models[provider_name].append((model_name, None))
|
||||||
|
|
||||||
|
output_lines.append("\n**Available Models**:")
|
||||||
|
for provider_name, models in sorted(providers_models.items()):
|
||||||
|
output_lines.append(f"\n*{provider_name.title()}:*")
|
||||||
|
for alias, config in models: # Show ALL models from each provider
|
||||||
|
if config:
|
||||||
|
context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
|
||||||
|
output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
|
||||||
|
else:
|
||||||
|
output_lines.append(f"- `{alias}`")
|
||||||
|
|
||||||
|
total_models = len(available_models)
|
||||||
|
# Show all models - no truncation message needed
|
||||||
|
|
||||||
|
# Check if restrictions are applied
|
||||||
|
restriction_service = None
|
||||||
|
try:
|
||||||
|
from utils.model_restrictions import get_restriction_service
|
||||||
|
|
||||||
|
restriction_service = get_restriction_service()
|
||||||
|
if restriction_service.has_restrictions(ProviderType.OPENROUTER):
|
||||||
|
allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER)
|
||||||
|
output_lines.append(
|
||||||
|
f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error checking OpenRouter restrictions: {e}")
|
||||||
|
else:
|
||||||
|
output_lines.append("**Error**: Could not load OpenRouter provider")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
output_lines.append(f"**Error loading models**: {str(e)}")
|
||||||
|
else:
|
||||||
|
output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)")
|
||||||
|
output_lines.append("**Note**: Provides access to GPT-5, O3, Mistral, and many more")
|
||||||
|
|
||||||
|
output_lines.append("")
|
||||||
|
|
||||||
|
# Check Custom API
|
||||||
|
custom_url = os.getenv("CUSTOM_API_URL")
|
||||||
|
|
||||||
|
output_lines.append(f"## Custom/Local API {'✅' if custom_url else '❌'}")
|
||||||
|
|
||||||
|
if custom_url:
|
||||||
|
output_lines.append("**Status**: Configured and available")
|
||||||
|
output_lines.append(f"**Endpoint**: {custom_url}")
|
||||||
|
output_lines.append("**Description**: Local models via Ollama, vLLM, LM Studio, etc.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
registry = OpenRouterModelRegistry()
|
||||||
|
custom_models = []
|
||||||
|
|
||||||
|
for alias in registry.list_aliases():
|
||||||
|
config = registry.resolve(alias)
|
||||||
|
if config and config.is_custom:
|
||||||
|
custom_models.append((alias, config))
|
||||||
|
|
||||||
|
if custom_models:
|
||||||
|
output_lines.append("\n**Custom Models**:")
|
||||||
|
for alias, config in custom_models:
|
||||||
|
context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
|
||||||
|
output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
|
||||||
|
if config.description:
|
||||||
|
output_lines.append(f" - {config.description}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
output_lines.append(f"**Error loading custom models**: {str(e)}")
|
||||||
|
else:
|
||||||
|
output_lines.append("**Status**: Not configured (set CUSTOM_API_URL)")
|
||||||
|
output_lines.append("**Example**: CUSTOM_API_URL=http://localhost:11434 (for Ollama)")
|
||||||
|
|
||||||
|
output_lines.append("")
|
||||||
|
|
||||||
|
# Add summary
|
||||||
|
output_lines.append("## Summary")
|
||||||
|
|
||||||
|
# Count configured providers
|
||||||
|
configured_count = sum(
|
||||||
|
[
|
||||||
|
1
|
||||||
|
for provider_type, info in provider_info.items()
|
||||||
|
if ModelProviderRegistry.get_provider(provider_type) is not None
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if is_openrouter_configured:
|
||||||
|
configured_count += 1
|
||||||
|
if custom_url:
|
||||||
|
configured_count += 1
|
||||||
|
|
||||||
|
output_lines.append(f"**Configured Providers**: {configured_count}")
|
||||||
|
|
||||||
|
# Get total available models
|
||||||
|
try:
|
||||||
|
from providers.registry import ModelProviderRegistry
|
||||||
|
|
||||||
|
# Get all available models respecting restrictions
|
||||||
|
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||||
|
total_models = len(available_models)
|
||||||
|
output_lines.append(f"**Total Available Models**: {total_models}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error getting total available models: {e}")
|
||||||
|
|
||||||
|
# Add usage tips
|
||||||
|
output_lines.append("\n**Usage Tips**:")
|
||||||
|
output_lines.append("- Use model aliases (e.g., 'flash', 'gpt5', 'opus') for convenience")
|
||||||
|
output_lines.append("- In auto mode, the CLI Agent will select the best model for each task")
|
||||||
|
output_lines.append("- Custom models are only available when CUSTOM_API_URL is set")
|
||||||
|
output_lines.append("- OpenRouter provides access to many cloud models with one API key")
|
||||||
|
|
||||||
|
# Format output
|
||||||
|
content = "\n".join(output_lines)
|
||||||
|
|
||||||
|
tool_output = ToolOutput(
|
||||||
|
status="success",
|
||||||
|
content=content,
|
||||||
|
content_type="text",
|
||||||
|
metadata={
|
||||||
|
"tool_name": self.name,
|
||||||
|
"configured_providers": configured_count,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return [TextContent(type="text", text=tool_output.model_dump_json())]
|
||||||
|
|
||||||
|
def get_model_category(self) -> ToolModelCategory:
|
||||||
|
"""Return the model category for this tool."""
|
||||||
|
return ToolModelCategory.FAST_RESPONSE # Simple listing, no AI needed
|
||||||
373
tools/models.py
Normal file
373
tools/models.py
Normal file
|
|
@ -0,0 +1,373 @@
|
||||||
|
"""
|
||||||
|
Data models for tool responses and interactions
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ToolModelCategory(Enum):
|
||||||
|
"""Categories for tool model selection based on requirements."""
|
||||||
|
|
||||||
|
EXTENDED_REASONING = "extended_reasoning" # Requires deep thinking capabilities
|
||||||
|
FAST_RESPONSE = "fast_response" # Speed and cost efficiency preferred
|
||||||
|
BALANCED = "balanced" # Balance of capability and performance
|
||||||
|
|
||||||
|
|
||||||
|
class ContinuationOffer(BaseModel):
|
||||||
|
"""Offer for CLI agent to continue conversation when Gemini doesn't ask follow-up"""
|
||||||
|
|
||||||
|
continuation_id: str = Field(
|
||||||
|
..., description="Thread continuation ID for multi-turn conversations across different tools"
|
||||||
|
)
|
||||||
|
note: str = Field(..., description="Message explaining continuation opportunity to CLI agent")
|
||||||
|
remaining_turns: int = Field(..., description="Number of conversation turns remaining")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolOutput(BaseModel):
|
||||||
|
"""Standardized output format for all tools"""
|
||||||
|
|
||||||
|
status: Literal[
|
||||||
|
"success",
|
||||||
|
"error",
|
||||||
|
"files_required_to_continue",
|
||||||
|
"full_codereview_required",
|
||||||
|
"focused_review_required",
|
||||||
|
"test_sample_needed",
|
||||||
|
"more_tests_required",
|
||||||
|
"refactor_analysis_complete",
|
||||||
|
"trace_complete",
|
||||||
|
"resend_prompt",
|
||||||
|
"code_too_large",
|
||||||
|
"continuation_available",
|
||||||
|
"no_bug_found",
|
||||||
|
] = "success"
|
||||||
|
content: Optional[str] = Field(None, description="The main content/response from the tool")
|
||||||
|
content_type: Literal["text", "markdown", "json"] = "text"
|
||||||
|
metadata: Optional[dict[str, Any]] = Field(default_factory=dict)
|
||||||
|
continuation_offer: Optional[ContinuationOffer] = Field(
|
||||||
|
None, description="Optional offer for Agent to continue conversation"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FilesNeededRequest(BaseModel):
|
||||||
|
"""Request for missing files / code to continue"""
|
||||||
|
|
||||||
|
status: Literal["files_required_to_continue"] = "files_required_to_continue"
|
||||||
|
mandatory_instructions: str = Field(..., description="Critical instructions for Agent regarding required context")
|
||||||
|
files_needed: Optional[list[str]] = Field(
|
||||||
|
default_factory=list, description="Specific files that are needed for analysis"
|
||||||
|
)
|
||||||
|
suggested_next_action: Optional[dict[str, Any]] = Field(
|
||||||
|
None,
|
||||||
|
description="Suggested tool call with parameters after getting clarification",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FullCodereviewRequired(BaseModel):
|
||||||
|
"""Request for full code review when scope is too large for quick review"""
|
||||||
|
|
||||||
|
status: Literal["full_codereview_required"] = "full_codereview_required"
|
||||||
|
important: Optional[str] = Field(None, description="Important message about escalation")
|
||||||
|
reason: Optional[str] = Field(None, description="Reason why full review is needed")
|
||||||
|
|
||||||
|
|
||||||
|
class FocusedReviewRequired(BaseModel):
|
||||||
|
"""Request for Agent to provide smaller, focused subsets of code for review"""
|
||||||
|
|
||||||
|
status: Literal["focused_review_required"] = "focused_review_required"
|
||||||
|
reason: str = Field(..., description="Why the current scope is too large for effective review")
|
||||||
|
suggestion: str = Field(
|
||||||
|
..., description="Suggested approach for breaking down the review into smaller, focused parts"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSampleNeeded(BaseModel):
|
||||||
|
"""Request for additional test samples to determine testing framework"""
|
||||||
|
|
||||||
|
status: Literal["test_sample_needed"] = "test_sample_needed"
|
||||||
|
reason: str = Field(..., description="Reason why additional test samples are required")
|
||||||
|
|
||||||
|
|
||||||
|
class MoreTestsRequired(BaseModel):
|
||||||
|
"""Request for continuation to generate additional tests"""
|
||||||
|
|
||||||
|
status: Literal["more_tests_required"] = "more_tests_required"
|
||||||
|
pending_tests: str = Field(..., description="List of pending tests to be generated")
|
||||||
|
|
||||||
|
|
||||||
|
class RefactorOpportunity(BaseModel):
|
||||||
|
"""A single refactoring opportunity with precise targeting information"""
|
||||||
|
|
||||||
|
id: str = Field(..., description="Unique identifier for this refactoring opportunity")
|
||||||
|
type: Literal["decompose", "codesmells", "modernize", "organization"] = Field(
|
||||||
|
..., description="Type of refactoring"
|
||||||
|
)
|
||||||
|
severity: Literal["critical", "high", "medium", "low"] = Field(..., description="Severity level")
|
||||||
|
file: str = Field(..., description="Absolute path to the file")
|
||||||
|
start_line: int = Field(..., description="Starting line number")
|
||||||
|
end_line: int = Field(..., description="Ending line number")
|
||||||
|
context_start_text: str = Field(..., description="Exact text from start line for verification")
|
||||||
|
context_end_text: str = Field(..., description="Exact text from end line for verification")
|
||||||
|
issue: str = Field(..., description="Clear description of what needs refactoring")
|
||||||
|
suggestion: str = Field(..., description="Specific refactoring action to take")
|
||||||
|
rationale: str = Field(..., description="Why this improves the code")
|
||||||
|
code_to_replace: str = Field(..., description="Original code that should be changed")
|
||||||
|
replacement_code_snippet: str = Field(..., description="Refactored version of the code")
|
||||||
|
new_code_snippets: Optional[list[dict]] = Field(
|
||||||
|
default_factory=list, description="Additional code snippets to be added"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RefactorAction(BaseModel):
|
||||||
|
"""Next action for Agent to implement refactoring"""
|
||||||
|
|
||||||
|
action_type: Literal["EXTRACT_METHOD", "SPLIT_CLASS", "MODERNIZE_SYNTAX", "REORGANIZE_CODE", "DECOMPOSE_FILE"] = (
|
||||||
|
Field(..., description="Type of action to perform")
|
||||||
|
)
|
||||||
|
target_file: str = Field(..., description="Absolute path to target file")
|
||||||
|
source_lines: str = Field(..., description="Line range (e.g., '45-67')")
|
||||||
|
description: str = Field(..., description="Step-by-step action description for CLI Agent")
|
||||||
|
|
||||||
|
|
||||||
|
class RefactorAnalysisComplete(BaseModel):
|
||||||
|
"""Complete refactor analysis with prioritized opportunities"""
|
||||||
|
|
||||||
|
status: Literal["refactor_analysis_complete"] = "refactor_analysis_complete"
|
||||||
|
refactor_opportunities: list[RefactorOpportunity] = Field(..., description="List of refactoring opportunities")
|
||||||
|
priority_sequence: list[str] = Field(..., description="Recommended order of refactoring IDs")
|
||||||
|
next_actions: list[RefactorAction] = Field(..., description="Specific actions for the agent to implement")
|
||||||
|
|
||||||
|
|
||||||
|
class CodeTooLargeRequest(BaseModel):
|
||||||
|
"""Request to reduce file selection due to size constraints"""
|
||||||
|
|
||||||
|
status: Literal["code_too_large"] = "code_too_large"
|
||||||
|
content: str = Field(..., description="Message explaining the size constraint")
|
||||||
|
content_type: Literal["text"] = "text"
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class ResendPromptRequest(BaseModel):
|
||||||
|
"""Request to resend prompt via file due to size limits"""
|
||||||
|
|
||||||
|
status: Literal["resend_prompt"] = "resend_prompt"
|
||||||
|
content: str = Field(..., description="Instructions for handling large prompt")
|
||||||
|
content_type: Literal["text"] = "text"
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class TraceEntryPoint(BaseModel):
|
||||||
|
"""Entry point information for trace analysis"""
|
||||||
|
|
||||||
|
file: str = Field(..., description="Absolute path to the file")
|
||||||
|
class_or_struct: str = Field(..., description="Class or module name")
|
||||||
|
method: str = Field(..., description="Method or function name")
|
||||||
|
signature: str = Field(..., description="Full method signature")
|
||||||
|
parameters: Optional[dict[str, Any]] = Field(default_factory=dict, description="Parameter values used in analysis")
|
||||||
|
|
||||||
|
|
||||||
|
class TraceTarget(BaseModel):
|
||||||
|
"""Target information for dependency analysis"""
|
||||||
|
|
||||||
|
file: str = Field(..., description="Absolute path to the file")
|
||||||
|
class_or_struct: str = Field(..., description="Class or module name")
|
||||||
|
method: str = Field(..., description="Method or function name")
|
||||||
|
signature: str = Field(..., description="Full method signature")
|
||||||
|
|
||||||
|
|
||||||
|
class CallPathStep(BaseModel):
|
||||||
|
"""A single step in the call path trace"""
|
||||||
|
|
||||||
|
from_info: dict[str, Any] = Field(..., description="Source location information", alias="from")
|
||||||
|
to: dict[str, Any] = Field(..., description="Target location information")
|
||||||
|
reason: str = Field(..., description="Reason for the call or dependency")
|
||||||
|
condition: Optional[str] = Field(None, description="Conditional logic if applicable")
|
||||||
|
ambiguous: bool = Field(False, description="Whether this call is ambiguous")
|
||||||
|
|
||||||
|
|
||||||
|
class BranchingPoint(BaseModel):
|
||||||
|
"""A branching point in the execution flow"""
|
||||||
|
|
||||||
|
file: str = Field(..., description="File containing the branching point")
|
||||||
|
method: str = Field(..., description="Method containing the branching point")
|
||||||
|
line: int = Field(..., description="Line number of the branching point")
|
||||||
|
condition: str = Field(..., description="Branching condition")
|
||||||
|
branches: list[str] = Field(..., description="Possible execution branches")
|
||||||
|
ambiguous: bool = Field(False, description="Whether the branching is ambiguous")
|
||||||
|
|
||||||
|
|
||||||
|
class SideEffect(BaseModel):
|
||||||
|
"""A side effect detected in the trace"""
|
||||||
|
|
||||||
|
type: str = Field(..., description="Type of side effect")
|
||||||
|
description: str = Field(..., description="Description of the side effect")
|
||||||
|
file: str = Field(..., description="File where the side effect occurs")
|
||||||
|
method: str = Field(..., description="Method where the side effect occurs")
|
||||||
|
line: int = Field(..., description="Line number of the side effect")
|
||||||
|
|
||||||
|
|
||||||
|
class UnresolvedDependency(BaseModel):
|
||||||
|
"""An unresolved dependency in the trace"""
|
||||||
|
|
||||||
|
reason: str = Field(..., description="Reason why the dependency is unresolved")
|
||||||
|
affected_file: str = Field(..., description="File affected by the unresolved dependency")
|
||||||
|
line: int = Field(..., description="Line number of the unresolved dependency")
|
||||||
|
|
||||||
|
|
||||||
|
class IncomingDependency(BaseModel):
|
||||||
|
"""An incoming dependency (what calls this target)"""
|
||||||
|
|
||||||
|
from_file: str = Field(..., description="Source file of the dependency")
|
||||||
|
from_class: str = Field(..., description="Source class of the dependency")
|
||||||
|
from_method: str = Field(..., description="Source method of the dependency")
|
||||||
|
line: int = Field(..., description="Line number of the dependency")
|
||||||
|
type: str = Field(..., description="Type of dependency")
|
||||||
|
|
||||||
|
|
||||||
|
class OutgoingDependency(BaseModel):
|
||||||
|
"""An outgoing dependency (what this target calls)"""
|
||||||
|
|
||||||
|
to_file: str = Field(..., description="Target file of the dependency")
|
||||||
|
to_class: str = Field(..., description="Target class of the dependency")
|
||||||
|
to_method: str = Field(..., description="Target method of the dependency")
|
||||||
|
line: int = Field(..., description="Line number of the dependency")
|
||||||
|
type: str = Field(..., description="Type of dependency")
|
||||||
|
|
||||||
|
|
||||||
|
class TypeDependency(BaseModel):
|
||||||
|
"""A type-level dependency (inheritance, imports, etc.)"""
|
||||||
|
|
||||||
|
dependency_type: str = Field(..., description="Type of dependency")
|
||||||
|
source_file: str = Field(..., description="Source file of the dependency")
|
||||||
|
source_entity: str = Field(..., description="Source entity (class, module)")
|
||||||
|
target: str = Field(..., description="Target entity")
|
||||||
|
|
||||||
|
|
||||||
|
class StateAccess(BaseModel):
|
||||||
|
"""State access information"""
|
||||||
|
|
||||||
|
file: str = Field(..., description="File where state is accessed")
|
||||||
|
method: str = Field(..., description="Method accessing the state")
|
||||||
|
access_type: str = Field(..., description="Type of access (reads, writes, etc.)")
|
||||||
|
state_entity: str = Field(..., description="State entity being accessed")
|
||||||
|
|
||||||
|
|
||||||
|
class TraceComplete(BaseModel):
|
||||||
|
"""Complete trace analysis response"""
|
||||||
|
|
||||||
|
status: Literal["trace_complete"] = "trace_complete"
|
||||||
|
trace_type: Literal["precision", "dependencies"] = Field(..., description="Type of trace performed")
|
||||||
|
|
||||||
|
# Precision mode fields
|
||||||
|
entry_point: Optional[TraceEntryPoint] = Field(None, description="Entry point for precision trace")
|
||||||
|
call_path: Optional[list[CallPathStep]] = Field(default_factory=list, description="Call path for precision trace")
|
||||||
|
branching_points: Optional[list[BranchingPoint]] = Field(default_factory=list, description="Branching points")
|
||||||
|
side_effects: Optional[list[SideEffect]] = Field(default_factory=list, description="Side effects detected")
|
||||||
|
unresolved: Optional[list[UnresolvedDependency]] = Field(
|
||||||
|
default_factory=list, description="Unresolved dependencies"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dependencies mode fields
|
||||||
|
target: Optional[TraceTarget] = Field(None, description="Target for dependency analysis")
|
||||||
|
incoming_dependencies: Optional[list[IncomingDependency]] = Field(
|
||||||
|
default_factory=list, description="Incoming dependencies"
|
||||||
|
)
|
||||||
|
outgoing_dependencies: Optional[list[OutgoingDependency]] = Field(
|
||||||
|
default_factory=list, description="Outgoing dependencies"
|
||||||
|
)
|
||||||
|
type_dependencies: Optional[list[TypeDependency]] = Field(default_factory=list, description="Type dependencies")
|
||||||
|
state_access: Optional[list[StateAccess]] = Field(default_factory=list, description="State access information")
|
||||||
|
|
||||||
|
|
||||||
|
class DiagnosticHypothesis(BaseModel):
|
||||||
|
"""A debugging hypothesis with context and next steps"""
|
||||||
|
|
||||||
|
rank: int = Field(..., description="Ranking of this hypothesis (1 = most likely)")
|
||||||
|
confidence: Literal["high", "medium", "low"] = Field(..., description="Confidence level")
|
||||||
|
hypothesis: str = Field(..., description="Description of the potential root cause")
|
||||||
|
reasoning: str = Field(..., description="Why this hypothesis is plausible")
|
||||||
|
next_step: str = Field(..., description="Suggested action to test/validate this hypothesis")
|
||||||
|
|
||||||
|
|
||||||
|
class StructuredDebugResponse(BaseModel):
|
||||||
|
"""Enhanced debug response with multiple hypotheses"""
|
||||||
|
|
||||||
|
summary: str = Field(..., description="Brief summary of the issue")
|
||||||
|
hypotheses: list[DiagnosticHypothesis] = Field(..., description="Ranked list of potential causes")
|
||||||
|
immediate_actions: list[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Immediate steps to take regardless of root cause",
|
||||||
|
)
|
||||||
|
additional_context_needed: Optional[list[str]] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Additional files or information that would help with analysis",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DebugHypothesis(BaseModel):
|
||||||
|
"""A debugging hypothesis with detailed analysis"""
|
||||||
|
|
||||||
|
name: str = Field(..., description="Name/title of the hypothesis")
|
||||||
|
confidence: Literal["High", "Medium", "Low"] = Field(..., description="Confidence level")
|
||||||
|
root_cause: str = Field(..., description="Technical explanation of the root cause")
|
||||||
|
evidence: str = Field(..., description="Logs or code clues supporting this hypothesis")
|
||||||
|
correlation: str = Field(..., description="How symptoms map to the cause")
|
||||||
|
validation: str = Field(..., description="Quick test to confirm the hypothesis")
|
||||||
|
minimal_fix: str = Field(..., description="Smallest change to resolve the issue")
|
||||||
|
regression_check: str = Field(..., description="Why this fix is safe")
|
||||||
|
file_references: list[str] = Field(default_factory=list, description="File:line format for exact locations")
|
||||||
|
|
||||||
|
|
||||||
|
class DebugAnalysisComplete(BaseModel):
|
||||||
|
"""Complete debugging analysis with systematic investigation tracking"""
|
||||||
|
|
||||||
|
status: Literal["analysis_complete"] = "analysis_complete"
|
||||||
|
investigation_id: str = Field(..., description="Auto-generated unique ID for this investigation")
|
||||||
|
summary: str = Field(..., description="Brief description of the problem and its impact")
|
||||||
|
investigation_steps: list[str] = Field(..., description="Steps taken during the investigation")
|
||||||
|
hypotheses: list[DebugHypothesis] = Field(..., description="Ranked hypotheses with detailed analysis")
|
||||||
|
key_findings: list[str] = Field(..., description="Important discoveries made during analysis")
|
||||||
|
immediate_actions: list[str] = Field(..., description="Steps to take regardless of which hypothesis is correct")
|
||||||
|
recommended_tools: list[str] = Field(default_factory=list, description="Additional tools recommended for analysis")
|
||||||
|
prevention_strategy: Optional[str] = Field(
|
||||||
|
None, description="Targeted measures to prevent this exact issue from recurring"
|
||||||
|
)
|
||||||
|
investigation_summary: str = Field(
|
||||||
|
..., description="Comprehensive summary of the complete investigation process and conclusions"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NoBugFound(BaseModel):
|
||||||
|
"""Response when thorough investigation finds no concrete evidence of a bug"""
|
||||||
|
|
||||||
|
status: Literal["no_bug_found"] = "no_bug_found"
|
||||||
|
summary: str = Field(..., description="Summary of what was thoroughly investigated")
|
||||||
|
investigation_steps: list[str] = Field(..., description="Steps taken during the investigation")
|
||||||
|
areas_examined: list[str] = Field(..., description="Code areas and potential failure points examined")
|
||||||
|
confidence_level: Literal["High", "Medium", "Low"] = Field(
|
||||||
|
..., description="Confidence level in the no-bug finding"
|
||||||
|
)
|
||||||
|
alternative_explanations: list[str] = Field(
|
||||||
|
..., description="Possible alternative explanations for reported symptoms"
|
||||||
|
)
|
||||||
|
recommended_questions: list[str] = Field(..., description="Questions to clarify the issue with the user")
|
||||||
|
next_steps: list[str] = Field(..., description="Suggested actions to better understand the reported issue")
|
||||||
|
|
||||||
|
|
||||||
|
# Registry mapping status strings to their corresponding Pydantic models
|
||||||
|
SPECIAL_STATUS_MODELS = {
|
||||||
|
"files_required_to_continue": FilesNeededRequest,
|
||||||
|
"full_codereview_required": FullCodereviewRequired,
|
||||||
|
"focused_review_required": FocusedReviewRequired,
|
||||||
|
"test_sample_needed": TestSampleNeeded,
|
||||||
|
"more_tests_required": MoreTestsRequired,
|
||||||
|
"refactor_analysis_complete": RefactorAnalysisComplete,
|
||||||
|
"trace_complete": TraceComplete,
|
||||||
|
"resend_prompt": ResendPromptRequest,
|
||||||
|
"code_too_large": CodeTooLargeRequest,
|
||||||
|
"analysis_complete": DebugAnalysisComplete,
|
||||||
|
"no_bug_found": NoBugFound,
|
||||||
|
}
|
||||||
19
tools/shared/__init__.py
Normal file
19
tools/shared/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
"""
|
||||||
|
Shared infrastructure for Zen MCP tools.
|
||||||
|
|
||||||
|
This module contains the core base classes and utilities that are shared
|
||||||
|
across all tool types. It provides the foundation for the tool architecture.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base_models import BaseWorkflowRequest, ConsolidatedFindings, ToolRequest, WorkflowRequest
|
||||||
|
from .base_tool import BaseTool
|
||||||
|
from .schema_builders import SchemaBuilder
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseTool",
|
||||||
|
"ToolRequest",
|
||||||
|
"BaseWorkflowRequest",
|
||||||
|
"WorkflowRequest",
|
||||||
|
"ConsolidatedFindings",
|
||||||
|
"SchemaBuilder",
|
||||||
|
]
|
||||||
165
tools/shared/base_models.py
Normal file
165
tools/shared/base_models.py
Normal file
|
|
@ -0,0 +1,165 @@
|
||||||
|
"""
|
||||||
|
Base models for Zen MCP tools.
|
||||||
|
|
||||||
|
This module contains the shared Pydantic models used across all tools,
|
||||||
|
extracted to avoid circular imports and promote code reuse.
|
||||||
|
|
||||||
|
Key Models:
|
||||||
|
- ToolRequest: Base request model for all tools
|
||||||
|
- WorkflowRequest: Extended request model for workflow-based tools
|
||||||
|
- ConsolidatedFindings: Model for tracking workflow progress
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Shared field descriptions to avoid duplication
|
||||||
|
COMMON_FIELD_DESCRIPTIONS = {
|
||||||
|
"model": "Model to run. Supply a name if requested by the user or stay in auto mode. When in auto mode, use `listmodels` tool for model discovery.",
|
||||||
|
"temperature": "0 = deterministic · 1 = creative.",
|
||||||
|
"thinking_mode": "Reasoning depth: minimal, low, medium, high, or max.",
|
||||||
|
"continuation_id": (
|
||||||
|
"Unique thread continuation ID for multi-turn conversations. Works across different tools. "
|
||||||
|
"ALWAYS reuse the last continuation_id you were given—this preserves full conversation context, "
|
||||||
|
"files, and findings so the agent can resume seamlessly."
|
||||||
|
),
|
||||||
|
"images": "Optional absolute image paths or base64 blobs for visual context.",
|
||||||
|
"files": "Optional absolute file or folder paths (do not shorten).",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Workflow-specific field descriptions
|
||||||
|
WORKFLOW_FIELD_DESCRIPTIONS = {
|
||||||
|
"step": "Current work step content and findings from your overall work",
|
||||||
|
"step_number": "Current step number in work sequence (starts at 1)",
|
||||||
|
"total_steps": "Estimated total steps needed to complete work",
|
||||||
|
"next_step_required": "Whether another work step is needed. When false, aim to reduce total_steps to match step_number to avoid mismatch.",
|
||||||
|
"findings": "Important findings, evidence and insights discovered in this step",
|
||||||
|
"files_checked": "List of files examined during this work step",
|
||||||
|
"relevant_files": "Files identified as relevant to issue/goal (FULL absolute paths to real files/folders - DO NOT SHORTEN)",
|
||||||
|
"relevant_context": "Methods/functions identified as involved in the issue",
|
||||||
|
"issues_found": "Issues identified with severity levels during work",
|
||||||
|
"confidence": (
|
||||||
|
"Confidence level: exploring (just starting), low (early investigation), "
|
||||||
|
"medium (some evidence), high (strong evidence), very_high (comprehensive understanding), "
|
||||||
|
"almost_certain (near complete confidence), certain (100% confidence locally - no external validation needed)"
|
||||||
|
),
|
||||||
|
"hypothesis": "Current theory about issue/goal based on work",
|
||||||
|
"backtrack_from_step": "Step number to backtrack from if work needs revision",
|
||||||
|
"use_assistant_model": (
|
||||||
|
"Use assistant model for expert analysis after workflow steps. "
|
||||||
|
"False skips expert analysis, relies solely on Claude's investigation. "
|
||||||
|
"Defaults to True for comprehensive validation."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRequest(BaseModel):
|
||||||
|
"""
|
||||||
|
Base request model for all Zen MCP tools.
|
||||||
|
|
||||||
|
This model defines common fields that all tools accept, including
|
||||||
|
model selection, temperature control, and conversation threading.
|
||||||
|
Tool-specific request models should inherit from this class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Model configuration
|
||||||
|
model: Optional[str] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["model"])
|
||||||
|
temperature: Optional[float] = Field(None, ge=0.0, le=1.0, description=COMMON_FIELD_DESCRIPTIONS["temperature"])
|
||||||
|
thinking_mode: Optional[str] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["thinking_mode"])
|
||||||
|
|
||||||
|
# Conversation support
|
||||||
|
continuation_id: Optional[str] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["continuation_id"])
|
||||||
|
|
||||||
|
# Visual context
|
||||||
|
images: Optional[list[str]] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["images"])
|
||||||
|
|
||||||
|
|
||||||
|
class BaseWorkflowRequest(ToolRequest):
|
||||||
|
"""
|
||||||
|
Minimal base request model for workflow tools.
|
||||||
|
|
||||||
|
This provides only the essential fields that ALL workflow tools need,
|
||||||
|
allowing for maximum flexibility in tool-specific implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Core workflow fields that ALL workflow tools need
|
||||||
|
step: str = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["step"])
|
||||||
|
step_number: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
|
||||||
|
total_steps: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
|
||||||
|
next_step_required: bool = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"])
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowRequest(BaseWorkflowRequest):
|
||||||
|
"""
|
||||||
|
Extended request model for workflow-based tools.
|
||||||
|
|
||||||
|
This model extends ToolRequest with fields specific to the workflow
|
||||||
|
pattern, where tools perform multi-step work with forced pauses between steps.
|
||||||
|
|
||||||
|
Used by: debug, precommit, codereview, refactor, thinkdeep, analyze
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Required workflow fields
|
||||||
|
step: str = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["step"])
|
||||||
|
step_number: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
|
||||||
|
total_steps: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
|
||||||
|
next_step_required: bool = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"])
|
||||||
|
|
||||||
|
# Work tracking fields
|
||||||
|
findings: str = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["findings"])
|
||||||
|
files_checked: list[str] = Field(default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["files_checked"])
|
||||||
|
relevant_files: list[str] = Field(default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"])
|
||||||
|
relevant_context: list[str] = Field(
|
||||||
|
default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"]
|
||||||
|
)
|
||||||
|
issues_found: list[dict] = Field(default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["issues_found"])
|
||||||
|
confidence: str = Field("low", description=WORKFLOW_FIELD_DESCRIPTIONS["confidence"])
|
||||||
|
|
||||||
|
# Optional workflow fields
|
||||||
|
hypothesis: Optional[str] = Field(None, description=WORKFLOW_FIELD_DESCRIPTIONS["hypothesis"])
|
||||||
|
backtrack_from_step: Optional[int] = Field(
|
||||||
|
None, ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["backtrack_from_step"]
|
||||||
|
)
|
||||||
|
use_assistant_model: Optional[bool] = Field(True, description=WORKFLOW_FIELD_DESCRIPTIONS["use_assistant_model"])
|
||||||
|
|
||||||
|
@field_validator("files_checked", "relevant_files", "relevant_context", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def convert_string_to_list(cls, v):
|
||||||
|
"""Convert string inputs to empty lists to handle malformed inputs gracefully."""
|
||||||
|
if isinstance(v, str):
|
||||||
|
logger.warning(f"Field received string '{v}' instead of list, converting to empty list")
|
||||||
|
return []
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class ConsolidatedFindings(BaseModel):
|
||||||
|
"""
|
||||||
|
Model for tracking consolidated findings across workflow steps.
|
||||||
|
|
||||||
|
This model accumulates findings, files, methods, and issues
|
||||||
|
discovered during multi-step work. It's used by
|
||||||
|
BaseWorkflowMixin to track progress across workflow steps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
files_checked: set[str] = Field(default_factory=set, description="All files examined across all steps")
|
||||||
|
relevant_files: set[str] = Field(
|
||||||
|
default_factory=set,
|
||||||
|
description="Subset of files_checked identified as relevant for work at hand",
|
||||||
|
)
|
||||||
|
relevant_context: set[str] = Field(
|
||||||
|
default_factory=set, description="All methods/functions identified during overall work"
|
||||||
|
)
|
||||||
|
findings: list[str] = Field(default_factory=list, description="Chronological findings from each work step")
|
||||||
|
hypotheses: list[dict] = Field(default_factory=list, description="Evolution of hypotheses across steps")
|
||||||
|
issues_found: list[dict] = Field(default_factory=list, description="All issues with severity levels")
|
||||||
|
images: list[str] = Field(default_factory=list, description="Images collected during work")
|
||||||
|
confidence: str = Field("low", description="Latest confidence level from steps")
|
||||||
|
|
||||||
|
|
||||||
|
# Tool-specific field descriptions are now declared in each tool file
|
||||||
|
# This keeps concerns separated and makes each tool self-contained
|
||||||
1399
tools/shared/base_tool.py
Normal file
1399
tools/shared/base_tool.py
Normal file
File diff suppressed because it is too large
Load diff
159
tools/shared/schema_builders.py
Normal file
159
tools/shared/schema_builders.py
Normal file
|
|
@ -0,0 +1,159 @@
|
||||||
|
"""
|
||||||
|
Core schema building functionality for Zen MCP tools.
|
||||||
|
|
||||||
|
This module provides base schema generation functionality for simple tools.
|
||||||
|
Workflow-specific schema building is located in workflow/schema_builders.py
|
||||||
|
to maintain proper separation of concerns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base_models import COMMON_FIELD_DESCRIPTIONS
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaBuilder:
|
||||||
|
"""
|
||||||
|
Base schema builder for simple MCP tools.
|
||||||
|
|
||||||
|
This class provides static methods to build consistent schemas for simple tools.
|
||||||
|
Workflow tools use WorkflowSchemaBuilder in workflow/schema_builders.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Common field schemas that can be reused across all tool types
|
||||||
|
COMMON_FIELD_SCHEMAS = {
|
||||||
|
"temperature": {
|
||||||
|
"type": "number",
|
||||||
|
"description": COMMON_FIELD_DESCRIPTIONS["temperature"],
|
||||||
|
"minimum": 0.0,
|
||||||
|
"maximum": 1.0,
|
||||||
|
},
|
||||||
|
"thinking_mode": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["minimal", "low", "medium", "high", "max"],
|
||||||
|
"description": COMMON_FIELD_DESCRIPTIONS["thinking_mode"],
|
||||||
|
},
|
||||||
|
"continuation_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": COMMON_FIELD_DESCRIPTIONS["continuation_id"],
|
||||||
|
},
|
||||||
|
"images": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": COMMON_FIELD_DESCRIPTIONS["images"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Simple tool-specific field schemas (workflow tools use relevant_files instead)
|
||||||
|
SIMPLE_FIELD_SCHEMAS = {
|
||||||
|
"files": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": COMMON_FIELD_DESCRIPTIONS["files"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_schema(
|
||||||
|
tool_specific_fields: dict[str, dict[str, Any]] = None,
|
||||||
|
required_fields: list[str] = None,
|
||||||
|
model_field_schema: dict[str, Any] = None,
|
||||||
|
auto_mode: bool = False,
|
||||||
|
require_model: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Build complete schema for simple tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_specific_fields: Additional fields specific to the tool
|
||||||
|
required_fields: List of required field names
|
||||||
|
model_field_schema: Schema for the model field
|
||||||
|
auto_mode: Whether the tool is in auto mode (affects model requirement)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Complete JSON schema for the tool
|
||||||
|
"""
|
||||||
|
properties = {}
|
||||||
|
|
||||||
|
# Add common fields (temperature, thinking_mode, etc.)
|
||||||
|
properties.update(SchemaBuilder.COMMON_FIELD_SCHEMAS)
|
||||||
|
|
||||||
|
# Add simple tool-specific fields (files field for simple tools)
|
||||||
|
properties.update(SchemaBuilder.SIMPLE_FIELD_SCHEMAS)
|
||||||
|
|
||||||
|
# Add model field if provided
|
||||||
|
if model_field_schema:
|
||||||
|
properties["model"] = model_field_schema
|
||||||
|
|
||||||
|
# Add tool-specific fields if provided
|
||||||
|
if tool_specific_fields:
|
||||||
|
properties.update(tool_specific_fields)
|
||||||
|
|
||||||
|
# Build required fields list
|
||||||
|
required = list(required_fields) if required_fields else []
|
||||||
|
if (auto_mode or require_model) and "model" not in required:
|
||||||
|
required.append("model")
|
||||||
|
|
||||||
|
# Build the complete schema
|
||||||
|
schema = {
|
||||||
|
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||||
|
"type": "object",
|
||||||
|
"properties": properties,
|
||||||
|
"additionalProperties": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
if required:
|
||||||
|
schema["required"] = required
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_common_fields() -> dict[str, dict[str, Any]]:
|
||||||
|
"""Get the standard field schemas for simple tools."""
|
||||||
|
return SchemaBuilder.COMMON_FIELD_SCHEMAS.copy()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_field_schema(
|
||||||
|
field_type: str,
|
||||||
|
description: str,
|
||||||
|
enum_values: list[str] = None,
|
||||||
|
minimum: float = None,
|
||||||
|
maximum: float = None,
|
||||||
|
items_type: str = None,
|
||||||
|
default: Any = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Helper method to create field schemas with common patterns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_type: JSON schema type ("string", "number", "array", etc.)
|
||||||
|
description: Human-readable description of the field
|
||||||
|
enum_values: For enum fields, list of allowed values
|
||||||
|
minimum: For numeric fields, minimum value
|
||||||
|
maximum: For numeric fields, maximum value
|
||||||
|
items_type: For array fields, type of array items
|
||||||
|
default: Default value for the field
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON schema object for the field
|
||||||
|
"""
|
||||||
|
schema = {
|
||||||
|
"type": field_type,
|
||||||
|
"description": description,
|
||||||
|
}
|
||||||
|
|
||||||
|
if enum_values:
|
||||||
|
schema["enum"] = enum_values
|
||||||
|
|
||||||
|
if minimum is not None:
|
||||||
|
schema["minimum"] = minimum
|
||||||
|
|
||||||
|
if maximum is not None:
|
||||||
|
schema["maximum"] = maximum
|
||||||
|
|
||||||
|
if items_type and field_type == "array":
|
||||||
|
schema["items"] = {"type": items_type}
|
||||||
|
|
||||||
|
if default is not None:
|
||||||
|
schema["default"] = default
|
||||||
|
|
||||||
|
return schema
|
||||||
18
tools/simple/__init__.py
Normal file
18
tools/simple/__init__.py
Normal file
|
|
@ -0,0 +1,18 @@
|
||||||
|
"""
|
||||||
|
Simple tools for Zen MCP.
|
||||||
|
|
||||||
|
Simple tools follow a basic request → AI model → response pattern.
|
||||||
|
They inherit from SimpleTool which provides streamlined functionality
|
||||||
|
for tools that don't need multi-step workflows.
|
||||||
|
|
||||||
|
Available simple tools:
|
||||||
|
- chat: General chat and collaborative thinking
|
||||||
|
- consensus: Multi-perspective analysis
|
||||||
|
- listmodels: Model listing and information
|
||||||
|
- testgen: Test generation
|
||||||
|
- tracer: Execution tracing
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import SimpleTool
|
||||||
|
|
||||||
|
__all__ = ["SimpleTool"]
|
||||||
985
tools/simple/base.py
Normal file
985
tools/simple/base.py
Normal file
|
|
@ -0,0 +1,985 @@
|
||||||
|
"""
|
||||||
|
Base class for simple MCP tools.
|
||||||
|
|
||||||
|
Simple tools follow a straightforward pattern:
|
||||||
|
1. Receive request
|
||||||
|
2. Prepare prompt (with files, context, etc.)
|
||||||
|
3. Call AI model
|
||||||
|
4. Format and return response
|
||||||
|
|
||||||
|
They use the shared SchemaBuilder for consistent schema generation
|
||||||
|
and inherit all the conversation, file processing, and model handling
|
||||||
|
capabilities from BaseTool.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from tools.shared.base_models import ToolRequest
|
||||||
|
from tools.shared.base_tool import BaseTool
|
||||||
|
from tools.shared.schema_builders import SchemaBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleTool(BaseTool):
|
||||||
|
"""
|
||||||
|
Base class for simple (non-workflow) tools.
|
||||||
|
|
||||||
|
Simple tools are request/response tools that don't require multi-step workflows.
|
||||||
|
They benefit from:
|
||||||
|
- Automatic schema generation using SchemaBuilder
|
||||||
|
- Inherited conversation handling and file processing
|
||||||
|
- Standardized model integration
|
||||||
|
- Consistent error handling and response formatting
|
||||||
|
|
||||||
|
To create a simple tool:
|
||||||
|
1. Inherit from SimpleTool
|
||||||
|
2. Implement get_tool_fields() to define tool-specific fields
|
||||||
|
3. Implement prepare_prompt() for prompt preparation
|
||||||
|
4. Optionally override format_response() for custom formatting
|
||||||
|
5. Optionally override get_required_fields() for custom requirements
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class ChatTool(SimpleTool):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "chat"
|
||||||
|
|
||||||
|
def get_tool_fields(self) -> Dict[str, Dict[str, Any]]:
|
||||||
|
return {
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Your question or idea...",
|
||||||
|
},
|
||||||
|
"files": SimpleTool.FILES_FIELD,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_required_fields(self) -> List[str]:
|
||||||
|
return ["prompt"]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Common field definitions that simple tools can reuse
|
||||||
|
FILES_FIELD = SchemaBuilder.SIMPLE_FIELD_SCHEMAS["files"]
|
||||||
|
IMAGES_FIELD = SchemaBuilder.COMMON_FIELD_SCHEMAS["images"]
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_tool_fields(self) -> dict[str, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Return tool-specific field definitions.
|
||||||
|
|
||||||
|
This method should return a dictionary mapping field names to their
|
||||||
|
JSON schema definitions. Common fields (model, temperature, etc.)
|
||||||
|
are added automatically by the base class.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping field names to JSON schema objects
|
||||||
|
|
||||||
|
Example:
|
||||||
|
return {
|
||||||
|
"prompt": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The user's question or request",
|
||||||
|
},
|
||||||
|
"files": SimpleTool.FILES_FIELD, # Reuse common field
|
||||||
|
"max_tokens": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 1,
|
||||||
|
"description": "Maximum tokens for response",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_required_fields(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Return list of required field names.
|
||||||
|
|
||||||
|
Override this to specify which fields are required for your tool.
|
||||||
|
The model field is automatically added if in auto mode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of required field names
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_annotations(self) -> Optional[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Return tool annotations. Simple tools are read-only by default.
|
||||||
|
|
||||||
|
All simple tools perform operations without modifying the environment.
|
||||||
|
They may call external AI models for analysis or conversation, but they
|
||||||
|
don't write files or make system changes.
|
||||||
|
|
||||||
|
Override this method if your simple tool needs different annotations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with readOnlyHint set to True
|
||||||
|
"""
|
||||||
|
return {"readOnlyHint": True}
|
||||||
|
|
||||||
|
def format_response(self, response: str, request, model_info: Optional[dict] = None) -> str:
|
||||||
|
"""
|
||||||
|
Format the AI response before returning to the client.
|
||||||
|
|
||||||
|
This is a hook method that subclasses can override to customize
|
||||||
|
response formatting. The default implementation returns the response as-is.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: The raw response from the AI model
|
||||||
|
request: The validated request object
|
||||||
|
model_info: Optional model information dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted response string
|
||||||
|
"""
|
||||||
|
return response
|
||||||
|
|
||||||
|
def get_input_schema(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate the complete input schema using SchemaBuilder.
|
||||||
|
|
||||||
|
This method automatically combines:
|
||||||
|
- Tool-specific fields from get_tool_fields()
|
||||||
|
- Common fields (temperature, thinking_mode, etc.)
|
||||||
|
- Model field with proper auto-mode handling
|
||||||
|
- Required fields from get_required_fields()
|
||||||
|
|
||||||
|
Tools can override this method for custom schema generation while
|
||||||
|
still benefiting from SimpleTool's convenience methods.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Complete JSON schema for the tool
|
||||||
|
"""
|
||||||
|
required_fields = list(self.get_required_fields())
|
||||||
|
return SchemaBuilder.build_schema(
|
||||||
|
tool_specific_fields=self.get_tool_fields(),
|
||||||
|
required_fields=required_fields,
|
||||||
|
model_field_schema=self.get_model_field_schema(),
|
||||||
|
auto_mode=self.is_effective_auto_mode(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_request_model(self):
|
||||||
|
"""
|
||||||
|
Return the request model class.
|
||||||
|
|
||||||
|
Simple tools use the base ToolRequest by default.
|
||||||
|
Override this if your tool needs a custom request model.
|
||||||
|
"""
|
||||||
|
return ToolRequest
|
||||||
|
|
||||||
|
# Hook methods for safe attribute access without hasattr/getattr
|
||||||
|
|
||||||
|
def get_request_model_name(self, request) -> Optional[str]:
|
||||||
|
"""Get model name from request. Override for custom model name handling."""
|
||||||
|
try:
|
||||||
|
return request.model
|
||||||
|
except AttributeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_request_images(self, request) -> list:
|
||||||
|
"""Get images from request. Override for custom image handling."""
|
||||||
|
try:
|
||||||
|
return request.images if request.images is not None else []
|
||||||
|
except AttributeError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_request_continuation_id(self, request) -> Optional[str]:
|
||||||
|
"""Get continuation_id from request. Override for custom continuation handling."""
|
||||||
|
try:
|
||||||
|
return request.continuation_id
|
||||||
|
except AttributeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_request_prompt(self, request) -> str:
|
||||||
|
"""Get prompt from request. Override for custom prompt handling."""
|
||||||
|
try:
|
||||||
|
return request.prompt
|
||||||
|
except AttributeError:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_request_temperature(self, request) -> Optional[float]:
|
||||||
|
"""Get temperature from request. Override for custom temperature handling."""
|
||||||
|
try:
|
||||||
|
return request.temperature
|
||||||
|
except AttributeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_validated_temperature(self, request, model_context: Any) -> tuple[float, list[str]]:
|
||||||
|
"""
|
||||||
|
Get temperature from request and validate it against model constraints.
|
||||||
|
|
||||||
|
This is a convenience method that combines temperature extraction and validation
|
||||||
|
for simple tools. It ensures temperature is within valid range for the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The request object containing temperature
|
||||||
|
model_context: Model context object containing model info
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (validated_temperature, warning_messages)
|
||||||
|
"""
|
||||||
|
temperature = self.get_request_temperature(request)
|
||||||
|
if temperature is None:
|
||||||
|
temperature = self.get_default_temperature()
|
||||||
|
return self.validate_and_correct_temperature(temperature, model_context)
|
||||||
|
|
||||||
|
def get_request_thinking_mode(self, request) -> Optional[str]:
|
||||||
|
"""Get thinking_mode from request. Override for custom thinking mode handling."""
|
||||||
|
try:
|
||||||
|
return request.thinking_mode
|
||||||
|
except AttributeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_request_files(self, request) -> list:
|
||||||
|
"""Get files from request. Override for custom file handling."""
|
||||||
|
try:
|
||||||
|
return request.files if request.files is not None else []
|
||||||
|
except AttributeError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_request_as_dict(self, request) -> dict:
|
||||||
|
"""Convert request to dictionary. Override for custom serialization."""
|
||||||
|
try:
|
||||||
|
# Try Pydantic v2 method first
|
||||||
|
return request.model_dump()
|
||||||
|
except AttributeError:
|
||||||
|
try:
|
||||||
|
# Fall back to Pydantic v1 method
|
||||||
|
return request.dict()
|
||||||
|
except AttributeError:
|
||||||
|
# Last resort - convert to dict manually
|
||||||
|
return {"prompt": self.get_request_prompt(request)}
|
||||||
|
|
||||||
|
def set_request_files(self, request, files: list) -> None:
|
||||||
|
"""Set files on request. Override for custom file setting."""
|
||||||
|
try:
|
||||||
|
request.files = files
|
||||||
|
except AttributeError:
|
||||||
|
# If request doesn't support file setting, ignore silently
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_actually_processed_files(self) -> list:
|
||||||
|
"""Get actually processed files. Override for custom file tracking."""
|
||||||
|
try:
|
||||||
|
return self._actually_processed_files
|
||||||
|
except AttributeError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def execute(self, arguments: dict[str, Any]) -> list:
|
||||||
|
"""
|
||||||
|
Execute the simple tool using the comprehensive flow from old base.py.
|
||||||
|
|
||||||
|
This method replicates the proven execution pattern while using SimpleTool hooks.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from mcp.types import TextContent
|
||||||
|
|
||||||
|
from tools.models import ToolOutput
|
||||||
|
|
||||||
|
logger = logging.getLogger(f"tools.{self.get_name()}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Store arguments for access by helper methods
|
||||||
|
self._current_arguments = arguments
|
||||||
|
|
||||||
|
logger.info(f"🔧 {self.get_name()} tool called with arguments: {list(arguments.keys())}")
|
||||||
|
|
||||||
|
# Validate request using the tool's Pydantic model
|
||||||
|
request_model = self.get_request_model()
|
||||||
|
request = request_model(**arguments)
|
||||||
|
logger.debug(f"Request validation successful for {self.get_name()}")
|
||||||
|
|
||||||
|
# Validate file paths for security
|
||||||
|
# This prevents path traversal attacks and ensures proper access control
|
||||||
|
path_error = self._validate_file_paths(request)
|
||||||
|
if path_error:
|
||||||
|
error_output = ToolOutput(
|
||||||
|
status="error",
|
||||||
|
content=path_error,
|
||||||
|
content_type="text",
|
||||||
|
)
|
||||||
|
return [TextContent(type="text", text=error_output.model_dump_json())]
|
||||||
|
|
||||||
|
# Handle model resolution like old base.py
|
||||||
|
model_name = self.get_request_model_name(request)
|
||||||
|
if not model_name:
|
||||||
|
from config import DEFAULT_MODEL
|
||||||
|
|
||||||
|
model_name = DEFAULT_MODEL
|
||||||
|
|
||||||
|
# Store the current model name for later use
|
||||||
|
self._current_model_name = model_name
|
||||||
|
|
||||||
|
# Handle model context from arguments (for in-process testing)
|
||||||
|
if "_model_context" in arguments:
|
||||||
|
self._model_context = arguments["_model_context"]
|
||||||
|
logger.debug(f"{self.get_name()}: Using model context from arguments")
|
||||||
|
else:
|
||||||
|
# Create model context if not provided
|
||||||
|
from utils.model_context import ModelContext
|
||||||
|
|
||||||
|
self._model_context = ModelContext(model_name)
|
||||||
|
logger.debug(f"{self.get_name()}: Created model context for {model_name}")
|
||||||
|
|
||||||
|
# Get images if present
|
||||||
|
images = self.get_request_images(request)
|
||||||
|
continuation_id = self.get_request_continuation_id(request)
|
||||||
|
|
||||||
|
# Handle conversation history and prompt preparation
|
||||||
|
if continuation_id:
|
||||||
|
# Check if conversation history is already embedded
|
||||||
|
field_value = self.get_request_prompt(request)
|
||||||
|
if "=== CONVERSATION HISTORY ===" in field_value:
|
||||||
|
# Use pre-embedded history
|
||||||
|
prompt = field_value
|
||||||
|
logger.debug(f"{self.get_name()}: Using pre-embedded conversation history")
|
||||||
|
else:
|
||||||
|
# No embedded history - reconstruct it (for in-process calls)
|
||||||
|
logger.debug(f"{self.get_name()}: No embedded history found, reconstructing conversation")
|
||||||
|
|
||||||
|
# Get thread context
|
||||||
|
from utils.conversation_memory import add_turn, build_conversation_history, get_thread
|
||||||
|
|
||||||
|
thread_context = get_thread(continuation_id)
|
||||||
|
|
||||||
|
if thread_context:
|
||||||
|
# Add user's new input to conversation
|
||||||
|
user_prompt = self.get_request_prompt(request)
|
||||||
|
user_files = self.get_request_files(request)
|
||||||
|
if user_prompt:
|
||||||
|
add_turn(continuation_id, "user", user_prompt, files=user_files)
|
||||||
|
|
||||||
|
# Get updated thread context after adding the turn
|
||||||
|
thread_context = get_thread(continuation_id)
|
||||||
|
logger.debug(
|
||||||
|
f"{self.get_name()}: Retrieved updated thread with {len(thread_context.turns)} turns"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build conversation history with updated thread context
|
||||||
|
conversation_history, conversation_tokens = build_conversation_history(
|
||||||
|
thread_context, self._model_context
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the base prompt from the tool
|
||||||
|
base_prompt = await self.prepare_prompt(request)
|
||||||
|
|
||||||
|
# Combine with conversation history
|
||||||
|
if conversation_history:
|
||||||
|
prompt = f"{conversation_history}\n\n=== NEW USER INPUT ===\n{base_prompt}"
|
||||||
|
else:
|
||||||
|
prompt = base_prompt
|
||||||
|
else:
|
||||||
|
# Thread not found, prepare normally
|
||||||
|
logger.warning(f"Thread {continuation_id} not found, preparing prompt normally")
|
||||||
|
prompt = await self.prepare_prompt(request)
|
||||||
|
else:
|
||||||
|
# New conversation, prepare prompt normally
|
||||||
|
prompt = await self.prepare_prompt(request)
|
||||||
|
|
||||||
|
# Add follow-up instructions for new conversations
|
||||||
|
from server import get_follow_up_instructions
|
||||||
|
|
||||||
|
follow_up_instructions = get_follow_up_instructions(0)
|
||||||
|
prompt = f"{prompt}\n\n{follow_up_instructions}"
|
||||||
|
logger.debug(
|
||||||
|
f"Added follow-up instructions for new {self.get_name()} conversation"
|
||||||
|
) # Validate images if any were provided
|
||||||
|
if images:
|
||||||
|
image_validation_error = self._validate_image_limits(
|
||||||
|
images, model_context=self._model_context, continuation_id=continuation_id
|
||||||
|
)
|
||||||
|
if image_validation_error:
|
||||||
|
return [TextContent(type="text", text=json.dumps(image_validation_error, ensure_ascii=False))]
|
||||||
|
|
||||||
|
# Get and validate temperature against model constraints
|
||||||
|
temperature, temp_warnings = self.get_validated_temperature(request, self._model_context)
|
||||||
|
|
||||||
|
# Log any temperature corrections
|
||||||
|
for warning in temp_warnings:
|
||||||
|
# Get thinking mode with defaults
|
||||||
|
logger.warning(warning)
|
||||||
|
thinking_mode = self.get_request_thinking_mode(request)
|
||||||
|
if thinking_mode is None:
|
||||||
|
thinking_mode = self.get_default_thinking_mode()
|
||||||
|
|
||||||
|
# Get the provider from model context (clean OOP - no re-fetching)
|
||||||
|
provider = self._model_context.provider
|
||||||
|
|
||||||
|
# Get system prompt for this tool
|
||||||
|
base_system_prompt = self.get_system_prompt()
|
||||||
|
language_instruction = self.get_language_instruction()
|
||||||
|
system_prompt = language_instruction + base_system_prompt
|
||||||
|
|
||||||
|
# Generate AI response using the provider
|
||||||
|
logger.info(f"Sending request to {provider.get_provider_type().value} API for {self.get_name()}")
|
||||||
|
logger.info(
|
||||||
|
f"Using model: {self._model_context.model_name} via {provider.get_provider_type().value} provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Estimate tokens for logging
|
||||||
|
from utils.token_utils import estimate_tokens
|
||||||
|
|
||||||
|
estimated_tokens = estimate_tokens(prompt)
|
||||||
|
logger.debug(f"Prompt length: {len(prompt)} characters (~{estimated_tokens:,} tokens)")
|
||||||
|
|
||||||
|
# Resolve model capabilities for feature gating
|
||||||
|
capabilities = self._model_context.capabilities
|
||||||
|
supports_thinking = capabilities.supports_extended_thinking
|
||||||
|
|
||||||
|
# Generate content with provider abstraction
|
||||||
|
model_response = provider.generate_content(
|
||||||
|
prompt=prompt,
|
||||||
|
model_name=self._current_model_name,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
thinking_mode=thinking_mode if supports_thinking else None,
|
||||||
|
images=images if images else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Received response from {provider.get_provider_type().value} API for {self.get_name()}")
|
||||||
|
|
||||||
|
# Process the model's response
|
||||||
|
if model_response.content:
|
||||||
|
raw_text = model_response.content
|
||||||
|
|
||||||
|
# Create model info for conversation tracking
|
||||||
|
model_info = {
|
||||||
|
"provider": provider,
|
||||||
|
"model_name": self._current_model_name,
|
||||||
|
"model_response": model_response,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse response using the same logic as old base.py
|
||||||
|
tool_output = self._parse_response(raw_text, request, model_info)
|
||||||
|
logger.info(f"✅ {self.get_name()} tool completed successfully")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Handle cases where the model couldn't generate a response
|
||||||
|
metadata = model_response.metadata or {}
|
||||||
|
finish_reason = metadata.get("finish_reason", "Unknown")
|
||||||
|
|
||||||
|
if metadata.get("is_blocked_by_safety"):
|
||||||
|
# Specific handling for content safety blocks
|
||||||
|
safety_details = metadata.get("safety_feedback") or "details not provided"
|
||||||
|
logger.warning(
|
||||||
|
f"Response blocked by content safety policy for {self.get_name()}. "
|
||||||
|
f"Reason: {finish_reason}, Details: {safety_details}"
|
||||||
|
)
|
||||||
|
tool_output = ToolOutput(
|
||||||
|
status="error",
|
||||||
|
content="Your request was blocked by the content safety policy. "
|
||||||
|
"Please try modifying your prompt.",
|
||||||
|
content_type="text",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Handle other empty responses - could be legitimate completion or unclear blocking
|
||||||
|
if finish_reason == "STOP":
|
||||||
|
# Model completed normally but returned empty content - retry with clarification
|
||||||
|
logger.info(
|
||||||
|
f"Model completed with empty response for {self.get_name()}, retrying with clarification"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retry the same request with modified prompt asking for explicit response
|
||||||
|
original_prompt = prompt
|
||||||
|
retry_prompt = f"{original_prompt}\n\nIMPORTANT: Please provide a substantive response. If you cannot respond to the above request, please explain why and suggest alternatives."
|
||||||
|
|
||||||
|
try:
|
||||||
|
retry_response = provider.generate_content(
|
||||||
|
prompt=retry_prompt,
|
||||||
|
model_name=self._current_model_name,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
thinking_mode=thinking_mode if supports_thinking else None,
|
||||||
|
images=images if images else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if retry_response.content:
|
||||||
|
# Successful retry - use the retry response
|
||||||
|
logger.info(f"Retry successful for {self.get_name()}")
|
||||||
|
raw_text = retry_response.content
|
||||||
|
|
||||||
|
# Update model info for the successful retry
|
||||||
|
model_info = {
|
||||||
|
"provider": provider,
|
||||||
|
"model_name": self._current_model_name,
|
||||||
|
"model_response": retry_response,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse the retry response
|
||||||
|
tool_output = self._parse_response(raw_text, request, model_info)
|
||||||
|
logger.info(f"✅ {self.get_name()} tool completed successfully after retry")
|
||||||
|
else:
|
||||||
|
# Retry also failed - inspect metadata to find out why
|
||||||
|
retry_metadata = retry_response.metadata or {}
|
||||||
|
if retry_metadata.get("is_blocked_by_safety"):
|
||||||
|
# The retry was blocked by safety filters
|
||||||
|
safety_details = retry_metadata.get("safety_feedback") or "details not provided"
|
||||||
|
logger.warning(
|
||||||
|
f"Retry for {self.get_name()} was blocked by content safety policy. "
|
||||||
|
f"Details: {safety_details}"
|
||||||
|
)
|
||||||
|
tool_output = ToolOutput(
|
||||||
|
status="error",
|
||||||
|
content="Your request was also blocked by the content safety policy after a retry. "
|
||||||
|
"Please try rephrasing your prompt significantly.",
|
||||||
|
content_type="text",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Retry failed for other reasons (e.g., another STOP)
|
||||||
|
tool_output = ToolOutput(
|
||||||
|
status="error",
|
||||||
|
content="The model repeatedly returned empty responses. This may indicate content filtering or a model issue.",
|
||||||
|
content_type="text",
|
||||||
|
)
|
||||||
|
except Exception as retry_error:
|
||||||
|
logger.warning(f"Retry failed for {self.get_name()}: {retry_error}")
|
||||||
|
tool_output = ToolOutput(
|
||||||
|
status="error",
|
||||||
|
content=f"Model returned empty response and retry failed: {str(retry_error)}",
|
||||||
|
content_type="text",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Non-STOP finish reasons are likely actual errors
|
||||||
|
logger.warning(
|
||||||
|
f"Response blocked or incomplete for {self.get_name()}. Finish reason: {finish_reason}"
|
||||||
|
)
|
||||||
|
tool_output = ToolOutput(
|
||||||
|
status="error",
|
||||||
|
content=f"Response blocked or incomplete. Finish reason: {finish_reason}",
|
||||||
|
content_type="text",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return the tool output as TextContent
|
||||||
|
return [TextContent(type="text", text=tool_output.model_dump_json())]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Special handling for MCP size check errors
|
||||||
|
if str(e).startswith("MCP_SIZE_CHECK:"):
|
||||||
|
# Extract the JSON content after the prefix
|
||||||
|
json_content = str(e)[len("MCP_SIZE_CHECK:") :]
|
||||||
|
return [TextContent(type="text", text=json_content)]
|
||||||
|
|
||||||
|
logger.error(f"Error in {self.get_name()}: {str(e)}")
|
||||||
|
error_output = ToolOutput(
|
||||||
|
status="error",
|
||||||
|
content=f"Error in {self.get_name()}: {str(e)}",
|
||||||
|
content_type="text",
|
||||||
|
)
|
||||||
|
return [TextContent(type="text", text=error_output.model_dump_json())]
|
||||||
|
|
||||||
|
def _parse_response(self, raw_text: str, request, model_info: Optional[dict] = None):
|
||||||
|
"""
|
||||||
|
Parse the raw response and format it using the hook method.
|
||||||
|
|
||||||
|
This simplified version focuses on the SimpleTool pattern: format the response
|
||||||
|
using the format_response hook, then handle conversation continuation.
|
||||||
|
"""
|
||||||
|
from tools.models import ToolOutput
|
||||||
|
|
||||||
|
# Format the response using the hook method
|
||||||
|
formatted_response = self.format_response(raw_text, request, model_info)
|
||||||
|
|
||||||
|
# Handle conversation continuation like old base.py
|
||||||
|
continuation_id = self.get_request_continuation_id(request)
|
||||||
|
if continuation_id:
|
||||||
|
self._record_assistant_turn(continuation_id, raw_text, request, model_info)
|
||||||
|
|
||||||
|
# Create continuation offer like old base.py
|
||||||
|
continuation_data = self._create_continuation_offer(request, model_info)
|
||||||
|
if continuation_data:
|
||||||
|
return self._create_continuation_offer_response(formatted_response, continuation_data, request, model_info)
|
||||||
|
else:
|
||||||
|
# Build metadata with model and provider info for success response
|
||||||
|
metadata = {}
|
||||||
|
if model_info:
|
||||||
|
model_name = model_info.get("model_name")
|
||||||
|
if model_name:
|
||||||
|
metadata["model_used"] = model_name
|
||||||
|
provider = model_info.get("provider")
|
||||||
|
if provider:
|
||||||
|
# Handle both provider objects and string values
|
||||||
|
if isinstance(provider, str):
|
||||||
|
metadata["provider_used"] = provider
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
metadata["provider_used"] = provider.get_provider_type().value
|
||||||
|
except AttributeError:
|
||||||
|
# Fallback if provider doesn't have get_provider_type method
|
||||||
|
metadata["provider_used"] = str(provider)
|
||||||
|
|
||||||
|
return ToolOutput(
|
||||||
|
status="success",
|
||||||
|
content=formatted_response,
|
||||||
|
content_type="text",
|
||||||
|
metadata=metadata if metadata else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_continuation_offer(self, request, model_info: Optional[dict] = None):
|
||||||
|
"""Create continuation offer following old base.py pattern"""
|
||||||
|
continuation_id = self.get_request_continuation_id(request)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from utils.conversation_memory import create_thread, get_thread
|
||||||
|
|
||||||
|
if continuation_id:
|
||||||
|
# Existing conversation
|
||||||
|
thread_context = get_thread(continuation_id)
|
||||||
|
if thread_context and thread_context.turns:
|
||||||
|
turn_count = len(thread_context.turns)
|
||||||
|
from utils.conversation_memory import MAX_CONVERSATION_TURNS
|
||||||
|
|
||||||
|
if turn_count >= MAX_CONVERSATION_TURNS - 1:
|
||||||
|
return None # No more turns allowed
|
||||||
|
|
||||||
|
remaining_turns = MAX_CONVERSATION_TURNS - turn_count - 1
|
||||||
|
return {
|
||||||
|
"continuation_id": continuation_id,
|
||||||
|
"remaining_turns": remaining_turns,
|
||||||
|
"note": f"Claude can continue this conversation for {remaining_turns} more exchanges.",
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# New conversation - create thread and offer continuation
|
||||||
|
# Convert request to dict for initial_context
|
||||||
|
initial_request_dict = self.get_request_as_dict(request)
|
||||||
|
|
||||||
|
new_thread_id = create_thread(tool_name=self.get_name(), initial_request=initial_request_dict)
|
||||||
|
|
||||||
|
# Add the initial user turn to the new thread
|
||||||
|
from utils.conversation_memory import MAX_CONVERSATION_TURNS, add_turn
|
||||||
|
|
||||||
|
user_prompt = self.get_request_prompt(request)
|
||||||
|
user_files = self.get_request_files(request)
|
||||||
|
user_images = self.get_request_images(request)
|
||||||
|
|
||||||
|
# Add user's initial turn
|
||||||
|
add_turn(
|
||||||
|
new_thread_id, "user", user_prompt, files=user_files, images=user_images, tool_name=self.get_name()
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"continuation_id": new_thread_id,
|
||||||
|
"remaining_turns": MAX_CONVERSATION_TURNS - 1,
|
||||||
|
"note": f"Claude can continue this conversation for {MAX_CONVERSATION_TURNS - 1} more exchanges.",
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _create_continuation_offer_response(
|
||||||
|
self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None
|
||||||
|
):
|
||||||
|
"""Create response with continuation offer following old base.py pattern"""
|
||||||
|
from tools.models import ContinuationOffer, ToolOutput
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self.get_request_continuation_id(request):
|
||||||
|
self._record_assistant_turn(
|
||||||
|
continuation_data["continuation_id"],
|
||||||
|
content,
|
||||||
|
request,
|
||||||
|
model_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
continuation_offer = ContinuationOffer(
|
||||||
|
continuation_id=continuation_data["continuation_id"],
|
||||||
|
note=continuation_data["note"],
|
||||||
|
remaining_turns=continuation_data["remaining_turns"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build metadata with model and provider info
|
||||||
|
metadata = {"tool_name": self.get_name(), "conversation_ready": True}
|
||||||
|
if model_info:
|
||||||
|
model_name = model_info.get("model_name")
|
||||||
|
if model_name:
|
||||||
|
metadata["model_used"] = model_name
|
||||||
|
provider = model_info.get("provider")
|
||||||
|
if provider:
|
||||||
|
# Handle both provider objects and string values
|
||||||
|
if isinstance(provider, str):
|
||||||
|
metadata["provider_used"] = provider
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
metadata["provider_used"] = provider.get_provider_type().value
|
||||||
|
except AttributeError:
|
||||||
|
# Fallback if provider doesn't have get_provider_type method
|
||||||
|
metadata["provider_used"] = str(provider)
|
||||||
|
|
||||||
|
return ToolOutput(
|
||||||
|
status="continuation_available",
|
||||||
|
content=content,
|
||||||
|
content_type="text",
|
||||||
|
continuation_offer=continuation_offer,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Fallback to simple success if continuation offer fails
|
||||||
|
return ToolOutput(status="success", content=content, content_type="text")
|
||||||
|
|
||||||
|
def _record_assistant_turn(
|
||||||
|
self, continuation_id: str, response_text: str, request, model_info: Optional[dict]
|
||||||
|
) -> None:
|
||||||
|
"""Persist an assistant response in conversation memory."""
|
||||||
|
|
||||||
|
if not continuation_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
from utils.conversation_memory import add_turn
|
||||||
|
|
||||||
|
model_provider = None
|
||||||
|
model_name = None
|
||||||
|
model_metadata = None
|
||||||
|
|
||||||
|
if model_info:
|
||||||
|
provider = model_info.get("provider")
|
||||||
|
if provider:
|
||||||
|
if isinstance(provider, str):
|
||||||
|
model_provider = provider
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
model_provider = provider.get_provider_type().value
|
||||||
|
except AttributeError:
|
||||||
|
model_provider = str(provider)
|
||||||
|
model_name = model_info.get("model_name")
|
||||||
|
model_response = model_info.get("model_response")
|
||||||
|
if model_response:
|
||||||
|
model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata}
|
||||||
|
|
||||||
|
add_turn(
|
||||||
|
continuation_id,
|
||||||
|
"assistant",
|
||||||
|
response_text,
|
||||||
|
files=self.get_request_files(request),
|
||||||
|
images=self.get_request_images(request),
|
||||||
|
tool_name=self.get_name(),
|
||||||
|
model_provider=model_provider,
|
||||||
|
model_name=model_name,
|
||||||
|
model_metadata=model_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convenience methods for common tool patterns
|
||||||
|
|
||||||
|
def build_standard_prompt(
|
||||||
|
self, system_prompt: str, user_content: str, request, file_context_title: str = "CONTEXT FILES"
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Build a standard prompt with system prompt, user content, and optional files.
|
||||||
|
|
||||||
|
This is a convenience method that handles the common pattern of:
|
||||||
|
1. Adding file content if present
|
||||||
|
2. Checking token limits
|
||||||
|
3. Adding web search instructions
|
||||||
|
4. Combining everything into a well-formatted prompt
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_prompt: The system prompt for the tool
|
||||||
|
user_content: The main user request/content
|
||||||
|
request: The validated request object
|
||||||
|
file_context_title: Title for the file context section
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Complete formatted prompt ready for the AI model
|
||||||
|
"""
|
||||||
|
# Add context files if provided
|
||||||
|
files = self.get_request_files(request)
|
||||||
|
if files:
|
||||||
|
file_content, processed_files = self._prepare_file_content_for_prompt(
|
||||||
|
files,
|
||||||
|
self.get_request_continuation_id(request),
|
||||||
|
"Context files",
|
||||||
|
model_context=getattr(self, "_model_context", None),
|
||||||
|
)
|
||||||
|
self._actually_processed_files = processed_files
|
||||||
|
if file_content:
|
||||||
|
user_content = f"{user_content}\n\n=== {file_context_title} ===\n{file_content}\n=== END CONTEXT ===="
|
||||||
|
|
||||||
|
# Check token limits - only validate original user prompt, not conversation history
|
||||||
|
content_to_validate = self.get_prompt_content_for_size_validation(user_content)
|
||||||
|
self._validate_token_limit(content_to_validate, "Content")
|
||||||
|
|
||||||
|
# Add standardized web search guidance
|
||||||
|
websearch_instruction = self.get_websearch_instruction(True, self.get_websearch_guidance())
|
||||||
|
|
||||||
|
# Combine system prompt with user content
|
||||||
|
full_prompt = f"""{system_prompt}{websearch_instruction}
|
||||||
|
|
||||||
|
=== USER REQUEST ===
|
||||||
|
{user_content}
|
||||||
|
=== END REQUEST ===
|
||||||
|
|
||||||
|
Please provide a thoughtful, comprehensive response:"""
|
||||||
|
|
||||||
|
return full_prompt
|
||||||
|
|
||||||
|
def get_prompt_content_for_size_validation(self, user_content: str) -> str:
|
||||||
|
"""
|
||||||
|
Override to use original user prompt for size validation when conversation history is embedded.
|
||||||
|
|
||||||
|
When server.py embeds conversation history into the prompt field, it also stores
|
||||||
|
the original user prompt in _original_user_prompt. We use that for size validation
|
||||||
|
to avoid incorrectly triggering size limits due to conversation history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_content: The user content (may include conversation history)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The original user prompt if available, otherwise the full user content
|
||||||
|
"""
|
||||||
|
# Check if we have the current arguments from execute() method
|
||||||
|
current_args = getattr(self, "_current_arguments", None)
|
||||||
|
if current_args:
|
||||||
|
# If server.py embedded conversation history, it stores original prompt separately
|
||||||
|
original_user_prompt = current_args.get("_original_user_prompt")
|
||||||
|
if original_user_prompt is not None:
|
||||||
|
# Use original user prompt for size validation (excludes conversation history)
|
||||||
|
return original_user_prompt
|
||||||
|
|
||||||
|
# Fallback to default behavior (validate full user content)
|
||||||
|
return user_content
|
||||||
|
|
||||||
|
def get_websearch_guidance(self) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Return tool-specific web search guidance.
|
||||||
|
|
||||||
|
Override this to provide tool-specific guidance for when web searches
|
||||||
|
would be helpful. Return None to use the default guidance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool-specific web search guidance or None for default
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def handle_prompt_file_with_fallback(self, request) -> str:
|
||||||
|
"""
|
||||||
|
Handle prompt.txt files with fallback to request field.
|
||||||
|
|
||||||
|
This is a convenience method for tools that accept prompts either
|
||||||
|
as a field or as a prompt.txt file. It handles the extraction
|
||||||
|
and validation automatically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The validated request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The effective prompt content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If prompt is too large for MCP transport
|
||||||
|
"""
|
||||||
|
# Check for prompt.txt in files
|
||||||
|
files = self.get_request_files(request)
|
||||||
|
if files:
|
||||||
|
prompt_content, updated_files = self.handle_prompt_file(files)
|
||||||
|
|
||||||
|
# Update request files list if needed
|
||||||
|
if updated_files is not None:
|
||||||
|
self.set_request_files(request, updated_files)
|
||||||
|
else:
|
||||||
|
prompt_content = None
|
||||||
|
|
||||||
|
# Use prompt.txt content if available, otherwise use the prompt field
|
||||||
|
user_content = prompt_content if prompt_content else self.get_request_prompt(request)
|
||||||
|
|
||||||
|
# Check user input size at MCP transport boundary (excluding conversation history)
|
||||||
|
validation_content = self.get_prompt_content_for_size_validation(user_content)
|
||||||
|
size_check = self.check_prompt_size(validation_content)
|
||||||
|
if size_check:
|
||||||
|
from tools.models import ToolOutput
|
||||||
|
|
||||||
|
raise ValueError(f"MCP_SIZE_CHECK:{ToolOutput(**size_check).model_dump_json()}")
|
||||||
|
|
||||||
|
return user_content
|
||||||
|
|
||||||
|
def get_chat_style_websearch_guidance(self) -> str:
|
||||||
|
"""
|
||||||
|
Get Chat tool-style web search guidance.
|
||||||
|
|
||||||
|
Returns web search guidance that matches the original Chat tool pattern.
|
||||||
|
This is useful for tools that want to maintain the same search behavior.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Web search guidance text
|
||||||
|
"""
|
||||||
|
return """When discussing topics, consider if searches for these would help:
|
||||||
|
- Documentation for any technologies or concepts mentioned
|
||||||
|
- Current best practices and patterns
|
||||||
|
- Recent developments or updates
|
||||||
|
- Community discussions and solutions"""
|
||||||
|
|
||||||
|
def supports_custom_request_model(self) -> bool:
|
||||||
|
"""
|
||||||
|
Indicate whether this tool supports custom request models.
|
||||||
|
|
||||||
|
Simple tools support custom request models by default. Tools that override
|
||||||
|
get_request_model() to return something other than ToolRequest should
|
||||||
|
return True here.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the tool uses a custom request model
|
||||||
|
"""
|
||||||
|
return self.get_request_model() != ToolRequest
|
||||||
|
|
||||||
|
def _validate_file_paths(self, request) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Validate that all file paths in the request are absolute paths.
|
||||||
|
|
||||||
|
This is a security measure to prevent path traversal attacks and ensure
|
||||||
|
proper access control. All file paths must be absolute (starting with '/').
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The validated request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: Error message if validation fails, None if all paths are valid
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Check if request has 'files' attribute (used by most tools)
|
||||||
|
files = self.get_request_files(request)
|
||||||
|
if files:
|
||||||
|
for file_path in files:
|
||||||
|
if not os.path.isabs(file_path):
|
||||||
|
return (
|
||||||
|
f"Error: All file paths must be FULL absolute paths to real files / folders - DO NOT SHORTEN. "
|
||||||
|
f"Received relative path: {file_path}\n"
|
||||||
|
f"Please provide the full absolute path starting with '/' (must be FULL absolute paths to real files / folders - DO NOT SHORTEN)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def prepare_chat_style_prompt(self, request, system_prompt: str = None) -> str:
|
||||||
|
"""
|
||||||
|
Prepare a prompt using Chat tool-style patterns.
|
||||||
|
|
||||||
|
This convenience method replicates the Chat tool's prompt preparation logic:
|
||||||
|
1. Handle prompt.txt file if present
|
||||||
|
2. Add file context with specific formatting
|
||||||
|
3. Add web search guidance
|
||||||
|
4. Format with system prompt
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The validated request object
|
||||||
|
system_prompt: System prompt to use (uses get_system_prompt() if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Complete formatted prompt
|
||||||
|
"""
|
||||||
|
# Use provided system prompt or get from tool
|
||||||
|
if system_prompt is None:
|
||||||
|
system_prompt = self.get_system_prompt()
|
||||||
|
|
||||||
|
# Get user content (handles prompt.txt files)
|
||||||
|
user_content = self.handle_prompt_file_with_fallback(request)
|
||||||
|
|
||||||
|
# Build standard prompt with Chat-style web search guidance
|
||||||
|
websearch_guidance = self.get_chat_style_websearch_guidance()
|
||||||
|
|
||||||
|
# Override the websearch guidance temporarily
|
||||||
|
original_guidance = self.get_websearch_guidance
|
||||||
|
self.get_websearch_guidance = lambda: websearch_guidance
|
||||||
|
|
||||||
|
try:
|
||||||
|
full_prompt = self.build_standard_prompt(system_prompt, user_content, request, "CONTEXT FILES")
|
||||||
|
finally:
|
||||||
|
# Restore original guidance method
|
||||||
|
self.get_websearch_guidance = original_guidance
|
||||||
|
|
||||||
|
return full_prompt
|
||||||
368
tools/version.py
Normal file
368
tools/version.py
Normal file
|
|
@ -0,0 +1,368 @@
|
||||||
|
"""
|
||||||
|
Version Tool - Display Zen MCP Server version and system information
|
||||||
|
|
||||||
|
This tool provides version information about the Zen MCP Server including
|
||||||
|
version number, last update date, author, and basic system information.
|
||||||
|
It also checks for updates from the GitHub repository.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import platform
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
from urllib.error import HTTPError, URLError
|
||||||
|
from urllib.request import urlopen
|
||||||
|
|
||||||
|
HAS_URLLIB = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_URLLIB = False
|
||||||
|
|
||||||
|
from mcp.types import TextContent
|
||||||
|
|
||||||
|
from config import __author__, __updated__, __version__
|
||||||
|
from tools.models import ToolModelCategory, ToolOutput
|
||||||
|
from tools.shared.base_models import ToolRequest
|
||||||
|
from tools.shared.base_tool import BaseTool
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||||
|
"""
|
||||||
|
Parse version string to tuple of integers for comparison.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version_str: Version string like "5.5.5"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (major, minor, patch) as integers
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parts = version_str.strip().split(".")
|
||||||
|
if len(parts) >= 3:
|
||||||
|
return (int(parts[0]), int(parts[1]), int(parts[2]))
|
||||||
|
elif len(parts) == 2:
|
||||||
|
return (int(parts[0]), int(parts[1]), 0)
|
||||||
|
elif len(parts) == 1:
|
||||||
|
return (int(parts[0]), 0, 0)
|
||||||
|
else:
|
||||||
|
return (0, 0, 0)
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
return (0, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def compare_versions(current: str, remote: str) -> int:
|
||||||
|
"""
|
||||||
|
Compare two version strings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current: Current version string
|
||||||
|
remote: Remote version string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
-1 if current < remote (update available)
|
||||||
|
0 if current == remote (up to date)
|
||||||
|
1 if current > remote (ahead of remote)
|
||||||
|
"""
|
||||||
|
current_tuple = parse_version(current)
|
||||||
|
remote_tuple = parse_version(remote)
|
||||||
|
|
||||||
|
if current_tuple < remote_tuple:
|
||||||
|
return -1
|
||||||
|
elif current_tuple > remote_tuple:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_github_version() -> Optional[tuple[str, str]]:
|
||||||
|
"""
|
||||||
|
Fetch the latest version information from GitHub repository.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (version, last_updated) if successful, None if failed
|
||||||
|
"""
|
||||||
|
if not HAS_URLLIB:
|
||||||
|
logger.warning("urllib not available, cannot check for updates")
|
||||||
|
return None
|
||||||
|
|
||||||
|
github_url = "https://raw.githubusercontent.com/BeehiveInnovations/zen-mcp-server/main/config.py"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Set a 10-second timeout
|
||||||
|
with urlopen(github_url, timeout=10) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
logger.warning(f"HTTP error while checking GitHub: {response.status}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = response.read().decode("utf-8")
|
||||||
|
|
||||||
|
# Extract version using regex
|
||||||
|
version_match = re.search(r'__version__\s*=\s*["\']([^"\']+)["\']', content)
|
||||||
|
updated_match = re.search(r'__updated__\s*=\s*["\']([^"\']+)["\']', content)
|
||||||
|
|
||||||
|
if version_match:
|
||||||
|
remote_version = version_match.group(1)
|
||||||
|
remote_updated = updated_match.group(1) if updated_match else "Unknown"
|
||||||
|
return (remote_version, remote_updated)
|
||||||
|
else:
|
||||||
|
logger.warning("Could not parse version from GitHub config.py")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except HTTPError as e:
|
||||||
|
logger.warning(f"HTTP error while checking GitHub: {e.code}")
|
||||||
|
return None
|
||||||
|
except URLError as e:
|
||||||
|
logger.warning(f"URL error while checking GitHub: {e.reason}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error checking GitHub for updates: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class VersionTool(BaseTool):
|
||||||
|
"""
|
||||||
|
Tool for displaying Zen MCP Server version and system information.
|
||||||
|
|
||||||
|
This tool provides:
|
||||||
|
- Current server version
|
||||||
|
- Last update date
|
||||||
|
- Author information
|
||||||
|
- Python version
|
||||||
|
- Platform information
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "version"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Get server version, configuration details, and list of available tools."
|
||||||
|
|
||||||
|
def get_input_schema(self) -> dict[str, Any]:
|
||||||
|
"""Return the JSON schema for the tool's input"""
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"model": {"type": "string", "description": "Model to use (ignored by version tool)"}},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_annotations(self) -> Optional[dict[str, Any]]:
|
||||||
|
"""Return tool annotations indicating this is a read-only tool"""
|
||||||
|
return {"readOnlyHint": True}
|
||||||
|
|
||||||
|
def get_system_prompt(self) -> str:
|
||||||
|
"""No AI model needed for this tool"""
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_request_model(self):
|
||||||
|
"""Return the Pydantic model for request validation."""
|
||||||
|
return ToolRequest
|
||||||
|
|
||||||
|
def requires_model(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def prepare_prompt(self, request: ToolRequest) -> str:
|
||||||
|
"""Not used for this utility tool"""
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def format_response(self, response: str, request: ToolRequest, model_info: dict = None) -> str:
|
||||||
|
"""Not used for this utility tool"""
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def execute(self, arguments: dict[str, Any]) -> list[TextContent]:
|
||||||
|
"""
|
||||||
|
Display Zen MCP Server version and system information.
|
||||||
|
|
||||||
|
This overrides the base class execute to provide direct output without AI model calls.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
arguments: Standard tool arguments (none required)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted version and system information
|
||||||
|
"""
|
||||||
|
output_lines = ["# Zen MCP Server Version\n"]
|
||||||
|
|
||||||
|
# Server version information
|
||||||
|
output_lines.append("## Server Information")
|
||||||
|
output_lines.append(f"**Current Version**: {__version__}")
|
||||||
|
output_lines.append(f"**Last Updated**: {__updated__}")
|
||||||
|
output_lines.append(f"**Author**: {__author__}")
|
||||||
|
|
||||||
|
model_selection_metadata = {"mode": "unknown", "default_model": None}
|
||||||
|
model_selection_display = "Model selection status unavailable"
|
||||||
|
|
||||||
|
# Model selection configuration
|
||||||
|
try:
|
||||||
|
from config import DEFAULT_MODEL
|
||||||
|
from tools.shared.base_tool import BaseTool
|
||||||
|
|
||||||
|
auto_mode = BaseTool.is_effective_auto_mode(self)
|
||||||
|
if auto_mode:
|
||||||
|
output_lines.append(
|
||||||
|
"**Model Selection**: Auto model selection mode (call `listmodels` to inspect options)"
|
||||||
|
)
|
||||||
|
model_selection_metadata = {"mode": "auto", "default_model": DEFAULT_MODEL}
|
||||||
|
model_selection_display = "Auto model selection (use `listmodels` for options)"
|
||||||
|
else:
|
||||||
|
output_lines.append(f"**Model Selection**: Default model set to `{DEFAULT_MODEL}`")
|
||||||
|
model_selection_metadata = {"mode": "default", "default_model": DEFAULT_MODEL}
|
||||||
|
model_selection_display = f"Default model: `{DEFAULT_MODEL}`"
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(f"Could not determine model selection mode: {exc}")
|
||||||
|
|
||||||
|
output_lines.append("")
|
||||||
|
output_lines.append("## Quick Summary — relay everything below")
|
||||||
|
output_lines.append(f"- Version `{__version__}` (updated {__updated__})")
|
||||||
|
output_lines.append(f"- {model_selection_display}")
|
||||||
|
output_lines.append("- Run `listmodels` for the complete model catalog and capabilities")
|
||||||
|
output_lines.append("")
|
||||||
|
|
||||||
|
# Try to get client information
|
||||||
|
try:
|
||||||
|
# We need access to the server instance
|
||||||
|
# This is a bit hacky but works for now
|
||||||
|
import server as server_module
|
||||||
|
from utils.client_info import format_client_info, get_client_info_from_context
|
||||||
|
|
||||||
|
client_info = get_client_info_from_context(server_module.server)
|
||||||
|
if client_info:
|
||||||
|
formatted = format_client_info(client_info)
|
||||||
|
output_lines.append(f"**Connected Client**: {formatted}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not get client info: {e}")
|
||||||
|
|
||||||
|
# Get the current working directory (MCP server location)
|
||||||
|
current_path = Path.cwd()
|
||||||
|
output_lines.append(f"**Installation Path**: `{current_path}`")
|
||||||
|
output_lines.append("")
|
||||||
|
output_lines.append("## Agent Reporting Guidance")
|
||||||
|
output_lines.append(
|
||||||
|
"Agents MUST report: version, model-selection status, configured providers, and available-model count."
|
||||||
|
)
|
||||||
|
output_lines.append("Repeat the quick-summary bullets verbatim in your reply.")
|
||||||
|
output_lines.append("Reference `listmodels` when users ask about model availability or capabilities.")
|
||||||
|
output_lines.append("")
|
||||||
|
|
||||||
|
# Check for updates from GitHub
|
||||||
|
output_lines.append("## Update Status")
|
||||||
|
|
||||||
|
try:
|
||||||
|
github_info = fetch_github_version()
|
||||||
|
|
||||||
|
if github_info:
|
||||||
|
remote_version, remote_updated = github_info
|
||||||
|
comparison = compare_versions(__version__, remote_version)
|
||||||
|
|
||||||
|
output_lines.append(f"**Latest Version (GitHub)**: {remote_version}")
|
||||||
|
output_lines.append(f"**Latest Updated**: {remote_updated}")
|
||||||
|
|
||||||
|
if comparison < 0:
|
||||||
|
# Update available
|
||||||
|
output_lines.append("")
|
||||||
|
output_lines.append("🚀 **UPDATE AVAILABLE!**")
|
||||||
|
output_lines.append(
|
||||||
|
f"Your version `{__version__}` is older than the latest version `{remote_version}`"
|
||||||
|
)
|
||||||
|
output_lines.append("")
|
||||||
|
output_lines.append("**To update:**")
|
||||||
|
output_lines.append("```bash")
|
||||||
|
output_lines.append(f"cd {current_path}")
|
||||||
|
output_lines.append("git pull")
|
||||||
|
output_lines.append("```")
|
||||||
|
output_lines.append("")
|
||||||
|
output_lines.append("*Note: Restart your session after updating to use the new version.*")
|
||||||
|
elif comparison == 0:
|
||||||
|
# Up to date
|
||||||
|
output_lines.append("")
|
||||||
|
output_lines.append("✅ **UP TO DATE**")
|
||||||
|
output_lines.append("You are running the latest version.")
|
||||||
|
else:
|
||||||
|
# Ahead of remote (development version)
|
||||||
|
output_lines.append("")
|
||||||
|
output_lines.append("🔬 **DEVELOPMENT VERSION**")
|
||||||
|
output_lines.append(
|
||||||
|
f"Your version `{__version__}` is ahead of the published version `{remote_version}`"
|
||||||
|
)
|
||||||
|
output_lines.append("You may be running a development or custom build.")
|
||||||
|
else:
|
||||||
|
output_lines.append("❌ **Could not check for updates**")
|
||||||
|
output_lines.append("Unable to connect to GitHub or parse version information.")
|
||||||
|
output_lines.append("Check your internet connection or try again later.")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during version check: {e}")
|
||||||
|
output_lines.append("❌ **Error checking for updates**")
|
||||||
|
output_lines.append(f"Error: {str(e)}")
|
||||||
|
|
||||||
|
output_lines.append("")
|
||||||
|
|
||||||
|
# Configuration information
|
||||||
|
output_lines.append("## Configuration")
|
||||||
|
|
||||||
|
# Check for configured providers
|
||||||
|
try:
|
||||||
|
from providers.registry import ModelProviderRegistry
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
provider_status = []
|
||||||
|
|
||||||
|
# Check each provider type
|
||||||
|
provider_types = [
|
||||||
|
ProviderType.GOOGLE,
|
||||||
|
ProviderType.OPENAI,
|
||||||
|
ProviderType.XAI,
|
||||||
|
ProviderType.DIAL,
|
||||||
|
ProviderType.OPENROUTER,
|
||||||
|
ProviderType.CUSTOM,
|
||||||
|
]
|
||||||
|
provider_names = ["Google Gemini", "OpenAI", "X.AI", "DIAL", "OpenRouter", "Custom/Local"]
|
||||||
|
|
||||||
|
for provider_type, provider_name in zip(provider_types, provider_names):
|
||||||
|
provider = ModelProviderRegistry.get_provider(provider_type)
|
||||||
|
status = "✅ Configured" if provider is not None else "❌ Not configured"
|
||||||
|
provider_status.append(f"- **{provider_name}**: {status}")
|
||||||
|
|
||||||
|
output_lines.append("**Providers**:")
|
||||||
|
output_lines.extend(provider_status)
|
||||||
|
|
||||||
|
# Get total available models
|
||||||
|
try:
|
||||||
|
available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
|
||||||
|
output_lines.append(f"\n\n**Available Models**: {len(available_models)}")
|
||||||
|
except Exception:
|
||||||
|
output_lines.append("\n\n**Available Models**: Unknown")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error checking provider configuration: {e}")
|
||||||
|
output_lines.append("\n\n**Providers**: Error checking configuration")
|
||||||
|
|
||||||
|
output_lines.append("")
|
||||||
|
|
||||||
|
# Format output
|
||||||
|
content = "\n".join(output_lines)
|
||||||
|
|
||||||
|
tool_output = ToolOutput(
|
||||||
|
status="success",
|
||||||
|
content=content,
|
||||||
|
content_type="text",
|
||||||
|
metadata={
|
||||||
|
"tool_name": self.name,
|
||||||
|
"server_version": __version__,
|
||||||
|
"last_updated": __updated__,
|
||||||
|
"python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
|
||||||
|
"platform": f"{platform.system()} {platform.release()}",
|
||||||
|
"model_selection_mode": model_selection_metadata["mode"],
|
||||||
|
"default_model": model_selection_metadata["default_model"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return [TextContent(type="text", text=tool_output.model_dump_json())]
|
||||||
|
|
||||||
|
def get_model_category(self) -> ToolModelCategory:
|
||||||
|
"""Return the model category for this tool."""
|
||||||
|
return ToolModelCategory.FAST_RESPONSE # Simple version info, no AI needed
|
||||||
22
tools/workflow/__init__.py
Normal file
22
tools/workflow/__init__.py
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
"""
|
||||||
|
Workflow tools for Zen MCP.
|
||||||
|
|
||||||
|
Workflow tools follow a multi-step pattern with forced pauses between steps
|
||||||
|
to encourage thorough investigation and analysis. They inherit from WorkflowTool
|
||||||
|
which combines BaseTool with BaseWorkflowMixin.
|
||||||
|
|
||||||
|
Available workflow tools:
|
||||||
|
- debug: Systematic investigation and root cause analysis
|
||||||
|
- planner: Sequential planning (special case - no AI calls)
|
||||||
|
- analyze: Code analysis workflow
|
||||||
|
- codereview: Code review workflow
|
||||||
|
- precommit: Pre-commit validation workflow
|
||||||
|
- refactor: Refactoring analysis workflow
|
||||||
|
- thinkdeep: Deep thinking workflow
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import WorkflowTool
|
||||||
|
from .schema_builders import WorkflowSchemaBuilder
|
||||||
|
from .workflow_mixin import BaseWorkflowMixin
|
||||||
|
|
||||||
|
__all__ = ["WorkflowTool", "WorkflowSchemaBuilder", "BaseWorkflowMixin"]
|
||||||
444
tools/workflow/base.py
Normal file
444
tools/workflow/base.py
Normal file
|
|
@ -0,0 +1,444 @@
|
||||||
|
"""
|
||||||
|
Base class for workflow MCP tools.
|
||||||
|
|
||||||
|
Workflow tools follow a multi-step pattern:
|
||||||
|
1. Claude calls tool with work step data
|
||||||
|
2. Tool tracks findings and progress
|
||||||
|
3. Tool forces Claude to pause and investigate between steps
|
||||||
|
4. Once work is complete, tool calls external AI model for expert analysis
|
||||||
|
5. Tool returns structured response combining investigation + expert analysis
|
||||||
|
|
||||||
|
They combine BaseTool's capabilities with BaseWorkflowMixin's workflow functionality
|
||||||
|
and use SchemaBuilder for consistent schema generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from tools.shared.base_models import WorkflowRequest
|
||||||
|
from tools.shared.base_tool import BaseTool
|
||||||
|
|
||||||
|
from .schema_builders import WorkflowSchemaBuilder
|
||||||
|
from .workflow_mixin import BaseWorkflowMixin
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowTool(BaseTool, BaseWorkflowMixin):
|
||||||
|
"""
|
||||||
|
Base class for workflow (multi-step) tools.
|
||||||
|
|
||||||
|
Workflow tools perform systematic multi-step work with expert analysis.
|
||||||
|
They benefit from:
|
||||||
|
- Automatic workflow orchestration from BaseWorkflowMixin
|
||||||
|
- Automatic schema generation using SchemaBuilder
|
||||||
|
- Inherited conversation handling and file processing from BaseTool
|
||||||
|
- Progress tracking with ConsolidatedFindings
|
||||||
|
- Expert analysis integration
|
||||||
|
|
||||||
|
To create a workflow tool:
|
||||||
|
1. Inherit from WorkflowTool
|
||||||
|
2. Tool name is automatically provided by get_name() method
|
||||||
|
3. Implement get_required_actions() for step guidance
|
||||||
|
4. Implement should_call_expert_analysis() for completion criteria
|
||||||
|
5. Implement prepare_expert_analysis_context() for expert prompts
|
||||||
|
6. Optionally implement get_tool_fields() for additional fields
|
||||||
|
7. Optionally override workflow behavior methods
|
||||||
|
|
||||||
|
Example:
|
||||||
|
class DebugTool(WorkflowTool):
|
||||||
|
# get_name() is inherited from BaseTool
|
||||||
|
|
||||||
|
def get_tool_fields(self) -> Dict[str, Dict[str, Any]]:
|
||||||
|
return {
|
||||||
|
"hypothesis": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Current theory about the issue",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_required_actions(
|
||||||
|
self, step_number: int, confidence: str, findings: str, total_steps: int
|
||||||
|
) -> List[str]:
|
||||||
|
return ["Examine relevant code files", "Trace execution flow", "Check error logs"]
|
||||||
|
|
||||||
|
def should_call_expert_analysis(self, consolidated_findings) -> bool:
|
||||||
|
return len(consolidated_findings.relevant_files) > 0
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize WorkflowTool with proper multiple inheritance."""
|
||||||
|
BaseTool.__init__(self)
|
||||||
|
BaseWorkflowMixin.__init__(self)
|
||||||
|
|
||||||
|
def get_tool_fields(self) -> dict[str, dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Return tool-specific field definitions beyond the standard workflow fields.
|
||||||
|
|
||||||
|
Workflow tools automatically get all standard workflow fields:
|
||||||
|
- step, step_number, total_steps, next_step_required
|
||||||
|
- findings, files_checked, relevant_files, relevant_context
|
||||||
|
- issues_found, confidence, hypothesis, backtrack_from_step
|
||||||
|
- plus common fields (model, temperature, etc.)
|
||||||
|
|
||||||
|
Override this method to add additional tool-specific fields.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping field names to JSON schema objects
|
||||||
|
|
||||||
|
Example:
|
||||||
|
return {
|
||||||
|
"severity_filter": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["low", "medium", "high"],
|
||||||
|
"description": "Minimum severity level to report",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def get_required_fields(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Return additional required fields beyond the standard workflow requirements.
|
||||||
|
|
||||||
|
Workflow tools automatically require:
|
||||||
|
- step, step_number, total_steps, next_step_required, findings
|
||||||
|
- model (if in auto mode)
|
||||||
|
|
||||||
|
Override this to add additional required fields.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of additional required field names
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_annotations(self) -> Optional[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Return tool annotations. Workflow tools are read-only by default.
|
||||||
|
|
||||||
|
All workflow tools perform analysis and investigation without modifying
|
||||||
|
the environment. They may call external AI models for expert analysis,
|
||||||
|
but they don't write files or make system changes.
|
||||||
|
|
||||||
|
Override this method if your workflow tool needs different annotations.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with readOnlyHint set to True
|
||||||
|
"""
|
||||||
|
return {"readOnlyHint": True}
|
||||||
|
|
||||||
|
def get_input_schema(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate the complete input schema using SchemaBuilder.
|
||||||
|
|
||||||
|
This method automatically combines:
|
||||||
|
- Standard workflow fields (step, findings, etc.)
|
||||||
|
- Common fields (temperature, thinking_mode, etc.)
|
||||||
|
- Model field with proper auto-mode handling
|
||||||
|
- Tool-specific fields from get_tool_fields()
|
||||||
|
- Required fields from get_required_fields()
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Complete JSON schema for the workflow tool
|
||||||
|
"""
|
||||||
|
return WorkflowSchemaBuilder.build_schema(
|
||||||
|
tool_specific_fields=self.get_tool_fields(),
|
||||||
|
required_fields=self.get_required_fields(),
|
||||||
|
model_field_schema=self.get_model_field_schema(),
|
||||||
|
auto_mode=self.is_effective_auto_mode(),
|
||||||
|
tool_name=self.get_name(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_workflow_request_model(self):
|
||||||
|
"""
|
||||||
|
Return the workflow request model class.
|
||||||
|
|
||||||
|
Workflow tools use WorkflowRequest by default, which includes
|
||||||
|
all the standard workflow fields. Override this if your tool
|
||||||
|
needs a custom request model.
|
||||||
|
"""
|
||||||
|
return WorkflowRequest
|
||||||
|
|
||||||
|
# Implement the abstract method from BaseWorkflowMixin
|
||||||
|
def get_work_steps(self, request) -> list[str]:
|
||||||
|
"""
|
||||||
|
Default implementation - workflow tools typically don't need predefined steps.
|
||||||
|
|
||||||
|
The workflow is driven by Claude's investigation process rather than
|
||||||
|
predefined steps. Override this if your tool needs specific step guidance.
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Default implementations for common workflow patterns
|
||||||
|
|
||||||
|
def get_standard_required_actions(self, step_number: int, confidence: str, base_actions: list[str]) -> list[str]:
|
||||||
|
"""
|
||||||
|
Helper method to generate standard required actions based on confidence and step.
|
||||||
|
|
||||||
|
This provides common patterns that most workflow tools can use:
|
||||||
|
- Early steps: broad exploration
|
||||||
|
- Low confidence: deeper investigation
|
||||||
|
- Medium/high confidence: verification and confirmation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step_number: Current step number
|
||||||
|
confidence: Current confidence level
|
||||||
|
base_actions: Tool-specific base actions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of required actions appropriate for the current state
|
||||||
|
"""
|
||||||
|
if step_number == 1:
|
||||||
|
# Initial investigation
|
||||||
|
return [
|
||||||
|
"Search for code related to the reported issue or symptoms",
|
||||||
|
"Examine relevant files and understand the current implementation",
|
||||||
|
"Understand the project structure and locate relevant modules",
|
||||||
|
"Identify how the affected functionality is supposed to work",
|
||||||
|
]
|
||||||
|
elif confidence in ["exploring", "low"]:
|
||||||
|
# Need deeper investigation
|
||||||
|
return base_actions + [
|
||||||
|
"Trace method calls and data flow through the system",
|
||||||
|
"Check for edge cases, boundary conditions, and assumptions in the code",
|
||||||
|
"Look for related configuration, dependencies, or external factors",
|
||||||
|
]
|
||||||
|
elif confidence in ["medium", "high"]:
|
||||||
|
# Close to solution - need confirmation
|
||||||
|
return base_actions + [
|
||||||
|
"Examine the exact code sections where you believe the issue occurs",
|
||||||
|
"Trace the execution path that leads to the failure",
|
||||||
|
"Verify your hypothesis with concrete code evidence",
|
||||||
|
"Check for any similar patterns elsewhere in the codebase",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# General continued investigation
|
||||||
|
return base_actions + [
|
||||||
|
"Continue examining the code paths identified in your hypothesis",
|
||||||
|
"Gather more evidence using appropriate investigation tools",
|
||||||
|
"Test edge cases and boundary conditions",
|
||||||
|
"Look for patterns that confirm or refute your theory",
|
||||||
|
]
|
||||||
|
|
||||||
|
def should_call_expert_analysis_default(self, consolidated_findings) -> bool:
|
||||||
|
"""
|
||||||
|
Default implementation for expert analysis decision.
|
||||||
|
|
||||||
|
This provides a reasonable default that most workflow tools can use:
|
||||||
|
- Call expert analysis if we have relevant files or significant findings
|
||||||
|
- Skip if confidence is "certain" (handled by the workflow mixin)
|
||||||
|
|
||||||
|
Override this for tool-specific logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
consolidated_findings: The consolidated findings from all work steps
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if expert analysis should be called
|
||||||
|
"""
|
||||||
|
# Call expert analysis if we have relevant files or substantial findings
|
||||||
|
return (
|
||||||
|
len(consolidated_findings.relevant_files) > 0
|
||||||
|
or len(consolidated_findings.findings) >= 2
|
||||||
|
or len(consolidated_findings.issues_found) > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_standard_expert_context(
|
||||||
|
self, consolidated_findings, initial_description: str, context_sections: dict[str, str] = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Helper method to prepare standard expert analysis context.
|
||||||
|
|
||||||
|
This provides a common structure that most workflow tools can use,
|
||||||
|
with the ability to add tool-specific sections.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
consolidated_findings: The consolidated findings from all work steps
|
||||||
|
initial_description: Description of the initial request/issue
|
||||||
|
context_sections: Optional additional sections to include
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string for expert analysis
|
||||||
|
"""
|
||||||
|
context_parts = [f"=== ISSUE DESCRIPTION ===\n{initial_description}\n=== END DESCRIPTION ==="]
|
||||||
|
|
||||||
|
# Add work progression
|
||||||
|
if consolidated_findings.findings:
|
||||||
|
findings_text = "\n".join(consolidated_findings.findings)
|
||||||
|
context_parts.append(f"\n=== INVESTIGATION FINDINGS ===\n{findings_text}\n=== END FINDINGS ===")
|
||||||
|
|
||||||
|
# Add relevant methods if available
|
||||||
|
if consolidated_findings.relevant_context:
|
||||||
|
methods_text = "\n".join(f"- {method}" for method in consolidated_findings.relevant_context)
|
||||||
|
context_parts.append(f"\n=== RELEVANT METHODS/FUNCTIONS ===\n{methods_text}\n=== END METHODS ===")
|
||||||
|
|
||||||
|
# Add hypothesis evolution if available
|
||||||
|
if consolidated_findings.hypotheses:
|
||||||
|
hypotheses_text = "\n".join(
|
||||||
|
f"Step {h['step']} ({h['confidence']} confidence): {h['hypothesis']}"
|
||||||
|
for h in consolidated_findings.hypotheses
|
||||||
|
)
|
||||||
|
context_parts.append(f"\n=== HYPOTHESIS EVOLUTION ===\n{hypotheses_text}\n=== END HYPOTHESES ===")
|
||||||
|
|
||||||
|
# Add issues found if available
|
||||||
|
if consolidated_findings.issues_found:
|
||||||
|
issues_text = "\n".join(
|
||||||
|
f"[{issue.get('severity', 'unknown').upper()}] {issue.get('description', 'No description')}"
|
||||||
|
for issue in consolidated_findings.issues_found
|
||||||
|
)
|
||||||
|
context_parts.append(f"\n=== ISSUES IDENTIFIED ===\n{issues_text}\n=== END ISSUES ===")
|
||||||
|
|
||||||
|
# Add tool-specific sections
|
||||||
|
if context_sections:
|
||||||
|
for section_title, section_content in context_sections.items():
|
||||||
|
context_parts.append(
|
||||||
|
f"\n=== {section_title.upper()} ===\n{section_content}\n=== END {section_title.upper()} ==="
|
||||||
|
)
|
||||||
|
|
||||||
|
return "\n".join(context_parts)
|
||||||
|
|
||||||
|
def handle_completion_without_expert_analysis(
|
||||||
|
self, request, consolidated_findings, initial_description: str = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generic handler for completion when expert analysis is not needed.
|
||||||
|
|
||||||
|
This provides a standard response format for when the tool determines
|
||||||
|
that external expert analysis is not required. All workflow tools
|
||||||
|
can use this generic implementation or override for custom behavior.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The workflow request object
|
||||||
|
consolidated_findings: The consolidated findings from all work steps
|
||||||
|
initial_description: Optional initial description (defaults to request.step)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with completion response data
|
||||||
|
"""
|
||||||
|
# Prepare work summary using inheritance hook
|
||||||
|
work_summary = self.prepare_work_summary()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": self.get_completion_status(),
|
||||||
|
self.get_completion_data_key(): {
|
||||||
|
"initial_request": initial_description or request.step,
|
||||||
|
"steps_taken": len(consolidated_findings.findings),
|
||||||
|
"files_examined": list(consolidated_findings.files_checked),
|
||||||
|
"relevant_files": list(consolidated_findings.relevant_files),
|
||||||
|
"relevant_context": list(consolidated_findings.relevant_context),
|
||||||
|
"work_summary": work_summary,
|
||||||
|
"final_analysis": self.get_final_analysis_from_request(request),
|
||||||
|
"confidence_level": self.get_confidence_level(request),
|
||||||
|
},
|
||||||
|
"next_steps": self.get_completion_message(),
|
||||||
|
"skip_expert_analysis": True,
|
||||||
|
"expert_analysis": {
|
||||||
|
"status": self.get_skip_expert_analysis_status(),
|
||||||
|
"reason": self.get_skip_reason(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Inheritance hooks for customization
|
||||||
|
|
||||||
|
def prepare_work_summary(self) -> str:
|
||||||
|
"""
|
||||||
|
Prepare a summary of the work performed. Override for custom summaries.
|
||||||
|
Default implementation provides a basic summary.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return self._prepare_work_summary()
|
||||||
|
except AttributeError:
|
||||||
|
try:
|
||||||
|
return f"Completed {len(self.work_history)} work steps"
|
||||||
|
except AttributeError:
|
||||||
|
return "Completed 0 work steps"
|
||||||
|
|
||||||
|
def get_completion_status(self) -> str:
|
||||||
|
"""Get the status to use when completing without expert analysis."""
|
||||||
|
return "high_confidence_completion"
|
||||||
|
|
||||||
|
def get_completion_data_key(self) -> str:
|
||||||
|
"""Get the key name for completion data in the response."""
|
||||||
|
return f"complete_{self.get_name()}"
|
||||||
|
|
||||||
|
def get_final_analysis_from_request(self, request) -> Optional[str]:
|
||||||
|
"""Extract final analysis from request. Override for tool-specific extraction."""
|
||||||
|
try:
|
||||||
|
return request.hypothesis
|
||||||
|
except AttributeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_confidence_level(self, request) -> str:
|
||||||
|
"""Get confidence level from request. Override for tool-specific logic."""
|
||||||
|
try:
|
||||||
|
return request.confidence or "high"
|
||||||
|
except AttributeError:
|
||||||
|
return "high"
|
||||||
|
|
||||||
|
def get_completion_message(self) -> str:
|
||||||
|
"""Get completion message. Override for tool-specific messaging."""
|
||||||
|
return (
|
||||||
|
f"{self.get_name().capitalize()} complete with high confidence. You have identified the exact "
|
||||||
|
"analysis and solution. MANDATORY: Present the user with the results "
|
||||||
|
"and proceed with implementing the solution without requiring further "
|
||||||
|
"consultation. Focus on the precise, actionable steps needed."
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_skip_reason(self) -> str:
|
||||||
|
"""Get reason for skipping expert analysis. Override for tool-specific reasons."""
|
||||||
|
return f"{self.get_name()} completed with sufficient confidence"
|
||||||
|
|
||||||
|
def get_skip_expert_analysis_status(self) -> str:
|
||||||
|
"""Get status for skipped expert analysis. Override for tool-specific status."""
|
||||||
|
return "skipped_by_tool_design"
|
||||||
|
|
||||||
|
def is_continuation_workflow(self, request) -> bool:
|
||||||
|
"""
|
||||||
|
Check if this is a continuation workflow that should skip multi-step investigation.
|
||||||
|
|
||||||
|
When continuation_id is provided, the workflow typically continues from a previous
|
||||||
|
conversation and should go directly to expert analysis rather than starting a new
|
||||||
|
multi-step investigation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The workflow request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if this is a continuation that should skip multi-step workflow
|
||||||
|
"""
|
||||||
|
continuation_id = self.get_request_continuation_id(request)
|
||||||
|
return bool(continuation_id)
|
||||||
|
|
||||||
|
# Abstract methods that must be implemented by specific workflow tools
|
||||||
|
# (These are inherited from BaseWorkflowMixin and must be implemented)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_required_actions(
|
||||||
|
self, step_number: int, confidence: str, findings: str, total_steps: int, request=None
|
||||||
|
) -> list[str]:
|
||||||
|
"""Define required actions for each work phase.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step_number: Current step number
|
||||||
|
confidence: Current confidence level
|
||||||
|
findings: Current findings text
|
||||||
|
total_steps: Total estimated steps
|
||||||
|
request: Optional request object for continuation-aware decisions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of required actions for the current step
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def should_call_expert_analysis(self, consolidated_findings) -> bool:
|
||||||
|
"""Decide when to call external model based on tool-specific criteria"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def prepare_expert_analysis_context(self, consolidated_findings) -> str:
|
||||||
|
"""Prepare context for external model call"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Default execute method - delegates to workflow
|
||||||
|
async def execute(self, arguments: dict[str, Any]) -> list:
|
||||||
|
"""Execute the workflow tool - delegates to BaseWorkflowMixin."""
|
||||||
|
return await self.execute_workflow(arguments)
|
||||||
174
tools/workflow/schema_builders.py
Normal file
174
tools/workflow/schema_builders.py
Normal file
|
|
@ -0,0 +1,174 @@
|
||||||
|
"""
|
||||||
|
Schema builders for workflow MCP tools.
|
||||||
|
|
||||||
|
This module provides workflow-specific schema generation functionality,
|
||||||
|
keeping workflow concerns separated from simple tool concerns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..shared.base_models import WORKFLOW_FIELD_DESCRIPTIONS
|
||||||
|
from ..shared.schema_builders import SchemaBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowSchemaBuilder:
|
||||||
|
"""
|
||||||
|
Schema builder for workflow MCP tools.
|
||||||
|
|
||||||
|
This class extends the base SchemaBuilder with workflow-specific fields
|
||||||
|
and schema generation logic, maintaining separation of concerns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Workflow-specific field schemas
|
||||||
|
WORKFLOW_FIELD_SCHEMAS = {
|
||||||
|
"step": {
|
||||||
|
"type": "string",
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["step"],
|
||||||
|
},
|
||||||
|
"step_number": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 1,
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["step_number"],
|
||||||
|
},
|
||||||
|
"total_steps": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 1,
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["total_steps"],
|
||||||
|
},
|
||||||
|
"next_step_required": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"],
|
||||||
|
},
|
||||||
|
"findings": {
|
||||||
|
"type": "string",
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["findings"],
|
||||||
|
},
|
||||||
|
"files_checked": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["files_checked"],
|
||||||
|
},
|
||||||
|
"relevant_files": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"],
|
||||||
|
},
|
||||||
|
"relevant_context": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"],
|
||||||
|
},
|
||||||
|
"issues_found": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "object"},
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["issues_found"],
|
||||||
|
},
|
||||||
|
"confidence": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["exploring", "low", "medium", "high", "very_high", "almost_certain", "certain"],
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["confidence"],
|
||||||
|
},
|
||||||
|
"hypothesis": {
|
||||||
|
"type": "string",
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["hypothesis"],
|
||||||
|
},
|
||||||
|
"backtrack_from_step": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 1,
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["backtrack_from_step"],
|
||||||
|
},
|
||||||
|
"use_assistant_model": {
|
||||||
|
"type": "boolean",
|
||||||
|
"default": True,
|
||||||
|
"description": WORKFLOW_FIELD_DESCRIPTIONS["use_assistant_model"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_schema(
|
||||||
|
tool_specific_fields: dict[str, dict[str, Any]] = None,
|
||||||
|
required_fields: list[str] = None,
|
||||||
|
model_field_schema: dict[str, Any] = None,
|
||||||
|
auto_mode: bool = False,
|
||||||
|
tool_name: str = None,
|
||||||
|
excluded_workflow_fields: list[str] = None,
|
||||||
|
excluded_common_fields: list[str] = None,
|
||||||
|
require_model: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Build complete schema for workflow tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_specific_fields: Additional fields specific to the tool
|
||||||
|
required_fields: List of required field names (beyond workflow defaults)
|
||||||
|
model_field_schema: Schema for the model field
|
||||||
|
auto_mode: Whether the tool is in auto mode (affects model requirement)
|
||||||
|
tool_name: Name of the tool (for schema title)
|
||||||
|
excluded_workflow_fields: Workflow fields to exclude from schema (e.g., for planning tools)
|
||||||
|
excluded_common_fields: Common fields to exclude from schema
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Complete JSON schema for the workflow tool
|
||||||
|
"""
|
||||||
|
properties = {}
|
||||||
|
|
||||||
|
# Add workflow fields first, excluding any specified fields
|
||||||
|
workflow_fields = WorkflowSchemaBuilder.WORKFLOW_FIELD_SCHEMAS.copy()
|
||||||
|
if excluded_workflow_fields:
|
||||||
|
for field in excluded_workflow_fields:
|
||||||
|
workflow_fields.pop(field, None)
|
||||||
|
properties.update(workflow_fields)
|
||||||
|
|
||||||
|
# Add common fields (temperature, thinking_mode, etc.) from base builder, excluding any specified fields
|
||||||
|
common_fields = SchemaBuilder.COMMON_FIELD_SCHEMAS.copy()
|
||||||
|
if excluded_common_fields:
|
||||||
|
for field in excluded_common_fields:
|
||||||
|
common_fields.pop(field, None)
|
||||||
|
properties.update(common_fields)
|
||||||
|
|
||||||
|
# Add model field if provided
|
||||||
|
if model_field_schema:
|
||||||
|
properties["model"] = model_field_schema
|
||||||
|
|
||||||
|
# Add tool-specific fields if provided
|
||||||
|
if tool_specific_fields:
|
||||||
|
properties.update(tool_specific_fields)
|
||||||
|
|
||||||
|
# Build required fields list - workflow tools have standard required fields
|
||||||
|
standard_required = ["step", "step_number", "total_steps", "next_step_required", "findings"]
|
||||||
|
|
||||||
|
# Filter out excluded fields from required fields
|
||||||
|
if excluded_workflow_fields:
|
||||||
|
standard_required = [field for field in standard_required if field not in excluded_workflow_fields]
|
||||||
|
|
||||||
|
required = standard_required + (required_fields or [])
|
||||||
|
|
||||||
|
if (auto_mode or require_model) and "model" not in required:
|
||||||
|
required.append("model")
|
||||||
|
|
||||||
|
# Build the complete schema
|
||||||
|
schema = {
|
||||||
|
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||||
|
"type": "object",
|
||||||
|
"properties": properties,
|
||||||
|
"required": required,
|
||||||
|
"additionalProperties": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tool_name:
|
||||||
|
schema["title"] = f"{tool_name.capitalize()}Request"
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_workflow_fields() -> dict[str, dict[str, Any]]:
|
||||||
|
"""Get the standard field schemas for workflow tools."""
|
||||||
|
combined = {}
|
||||||
|
combined.update(WorkflowSchemaBuilder.WORKFLOW_FIELD_SCHEMAS)
|
||||||
|
combined.update(SchemaBuilder.COMMON_FIELD_SCHEMAS)
|
||||||
|
return combined
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_workflow_only_fields() -> dict[str, dict[str, Any]]:
|
||||||
|
"""Get only the workflow-specific field schemas."""
|
||||||
|
return WorkflowSchemaBuilder.WORKFLOW_FIELD_SCHEMAS.copy()
|
||||||
1619
tools/workflow/workflow_mixin.py
Normal file
1619
tools/workflow/workflow_mixin.py
Normal file
File diff suppressed because it is too large
Load diff
21
utils/__init__.py
Normal file
21
utils/__init__.py
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
"""
|
||||||
|
Utility functions for Zen MCP Server
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .file_types import CODE_EXTENSIONS, FILE_CATEGORIES, PROGRAMMING_EXTENSIONS, TEXT_EXTENSIONS
|
||||||
|
from .file_utils import expand_paths, read_file_content, read_files
|
||||||
|
from .security_config import EXCLUDED_DIRS
|
||||||
|
from .token_utils import check_token_limit, estimate_tokens
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"read_files",
|
||||||
|
"read_file_content",
|
||||||
|
"expand_paths",
|
||||||
|
"CODE_EXTENSIONS",
|
||||||
|
"PROGRAMMING_EXTENSIONS",
|
||||||
|
"TEXT_EXTENSIONS",
|
||||||
|
"FILE_CATEGORIES",
|
||||||
|
"EXCLUDED_DIRS",
|
||||||
|
"estimate_tokens",
|
||||||
|
"check_token_limit",
|
||||||
|
]
|
||||||
293
utils/client_info.py
Normal file
293
utils/client_info.py
Normal file
|
|
@ -0,0 +1,293 @@
|
||||||
|
"""
|
||||||
|
Client Information Utility for MCP Server
|
||||||
|
|
||||||
|
This module provides utilities to extract and format client information
|
||||||
|
from the MCP protocol's clientInfo sent during initialization.
|
||||||
|
|
||||||
|
It also provides friendly name mapping and caching for consistent client
|
||||||
|
identification across the application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Global cache for client information
|
||||||
|
_client_info_cache: Optional[dict[str, Any]] = None
|
||||||
|
|
||||||
|
# Mapping of known client names to friendly names
|
||||||
|
# This is case-insensitive and checks if the key is contained in the client name
|
||||||
|
CLIENT_NAME_MAPPINGS = {
|
||||||
|
# Claude variants
|
||||||
|
"claude-ai": "Claude",
|
||||||
|
"claude": "Claude",
|
||||||
|
"claude-desktop": "Claude",
|
||||||
|
"claude-code": "Claude",
|
||||||
|
"anthropic": "Claude",
|
||||||
|
# Gemini variants
|
||||||
|
"gemini-cli-mcp-client": "Gemini",
|
||||||
|
"gemini-cli": "Gemini",
|
||||||
|
"gemini": "Gemini",
|
||||||
|
"google": "Gemini",
|
||||||
|
# Other known clients
|
||||||
|
"cursor": "Cursor",
|
||||||
|
"vscode": "VS Code",
|
||||||
|
"codeium": "Codeium",
|
||||||
|
"copilot": "GitHub Copilot",
|
||||||
|
# Generic MCP clients
|
||||||
|
"mcp-client": "MCP Client",
|
||||||
|
"test-client": "Test Client",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Default friendly name when no match is found
|
||||||
|
DEFAULT_FRIENDLY_NAME = "Claude"
|
||||||
|
|
||||||
|
|
||||||
|
def get_friendly_name(client_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Map a client name to a friendly name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_name: The raw client name from clientInfo
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A friendly name for display (e.g., "Claude", "Gemini")
|
||||||
|
"""
|
||||||
|
if not client_name:
|
||||||
|
return DEFAULT_FRIENDLY_NAME
|
||||||
|
|
||||||
|
# Convert to lowercase for case-insensitive matching
|
||||||
|
client_name_lower = client_name.lower()
|
||||||
|
|
||||||
|
# Check each mapping - using 'in' to handle partial matches
|
||||||
|
for key, friendly_name in CLIENT_NAME_MAPPINGS.items():
|
||||||
|
if key.lower() in client_name_lower:
|
||||||
|
return friendly_name
|
||||||
|
|
||||||
|
# If no match found, return the default
|
||||||
|
return DEFAULT_FRIENDLY_NAME
|
||||||
|
|
||||||
|
|
||||||
|
def get_cached_client_info() -> Optional[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get cached client information if available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached client info dictionary or None
|
||||||
|
"""
|
||||||
|
global _client_info_cache
|
||||||
|
return _client_info_cache
|
||||||
|
|
||||||
|
|
||||||
|
def get_client_info_from_context(server: Any) -> Optional[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Extract client information from the MCP server's request context.
|
||||||
|
|
||||||
|
The MCP protocol sends clientInfo during initialization containing:
|
||||||
|
- name: The client application name (e.g., "Claude Code", "Claude Desktop")
|
||||||
|
- version: The client version string
|
||||||
|
|
||||||
|
This function also adds a friendly_name field and caches the result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: The MCP server instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with client info or None if not available:
|
||||||
|
{
|
||||||
|
"name": "claude-ai",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"friendly_name": "Claude"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
global _client_info_cache
|
||||||
|
|
||||||
|
# Return cached info if available
|
||||||
|
if _client_info_cache is not None:
|
||||||
|
return _client_info_cache
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try to access the request context and session
|
||||||
|
if not server:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if server has request_context property
|
||||||
|
request_context = None
|
||||||
|
try:
|
||||||
|
request_context = server.request_context
|
||||||
|
except AttributeError:
|
||||||
|
logger.debug("Server does not have request_context property")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not request_context:
|
||||||
|
logger.debug("Request context is None")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Try to access session from request context
|
||||||
|
session = None
|
||||||
|
try:
|
||||||
|
session = request_context.session
|
||||||
|
except AttributeError:
|
||||||
|
logger.debug("Request context does not have session property")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not session:
|
||||||
|
logger.debug("Session is None")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Try to access client params from session
|
||||||
|
client_params = None
|
||||||
|
try:
|
||||||
|
# The clientInfo is stored in _client_params.clientInfo
|
||||||
|
client_params = session._client_params
|
||||||
|
except AttributeError:
|
||||||
|
logger.debug("Session does not have _client_params property")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not client_params:
|
||||||
|
logger.debug("Client params is None")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Try to extract clientInfo
|
||||||
|
client_info = None
|
||||||
|
try:
|
||||||
|
client_info = client_params.clientInfo
|
||||||
|
except AttributeError:
|
||||||
|
logger.debug("Client params does not have clientInfo property")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not client_info:
|
||||||
|
logger.debug("Client info is None")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extract name and version
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result["name"] = client_info.name
|
||||||
|
except AttributeError:
|
||||||
|
logger.debug("Client info does not have name property")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result["version"] = client_info.version
|
||||||
|
except AttributeError:
|
||||||
|
logger.debug("Client info does not have version property")
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Add friendly name
|
||||||
|
raw_name = result.get("name", "")
|
||||||
|
result["friendly_name"] = get_friendly_name(raw_name)
|
||||||
|
|
||||||
|
# Cache the result
|
||||||
|
_client_info_cache = result
|
||||||
|
logger.debug(f"Cached client info: {result}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error extracting client info: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def format_client_info(client_info: Optional[dict[str, Any]], use_friendly_name: bool = True) -> str:
|
||||||
|
"""
|
||||||
|
Format client information for display.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_info: Dictionary with client info or None
|
||||||
|
use_friendly_name: If True, use the friendly name instead of raw name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string like "Claude v1.0.0" or "Claude"
|
||||||
|
"""
|
||||||
|
if not client_info:
|
||||||
|
return DEFAULT_FRIENDLY_NAME
|
||||||
|
|
||||||
|
if use_friendly_name:
|
||||||
|
name = client_info.get("friendly_name", client_info.get("name", DEFAULT_FRIENDLY_NAME))
|
||||||
|
else:
|
||||||
|
name = client_info.get("name", "Unknown")
|
||||||
|
|
||||||
|
version = client_info.get("version", "")
|
||||||
|
|
||||||
|
if version and not use_friendly_name:
|
||||||
|
return f"{name} v{version}"
|
||||||
|
else:
|
||||||
|
# For friendly names, we just return the name without version
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def get_client_friendly_name() -> str:
|
||||||
|
"""
|
||||||
|
Get the cached client's friendly name.
|
||||||
|
|
||||||
|
This is a convenience function that returns just the friendly name
|
||||||
|
from the cached client info, defaulting to "Claude" if not available.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The friendly name (e.g., "Claude", "Gemini")
|
||||||
|
"""
|
||||||
|
cached_info = get_cached_client_info()
|
||||||
|
if cached_info:
|
||||||
|
return cached_info.get("friendly_name", DEFAULT_FRIENDLY_NAME)
|
||||||
|
return DEFAULT_FRIENDLY_NAME
|
||||||
|
|
||||||
|
|
||||||
|
def log_client_info(server: Any, logger_instance: Optional[logging.Logger] = None) -> None:
|
||||||
|
"""
|
||||||
|
Log client information extracted from the server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server: The MCP server instance
|
||||||
|
logger_instance: Optional logger to use (defaults to module logger)
|
||||||
|
"""
|
||||||
|
log = logger_instance or logger
|
||||||
|
|
||||||
|
client_info = get_client_info_from_context(server)
|
||||||
|
if client_info:
|
||||||
|
# Log with both raw and friendly names for debugging
|
||||||
|
raw_name = client_info.get("name", "Unknown")
|
||||||
|
friendly_name = client_info.get("friendly_name", DEFAULT_FRIENDLY_NAME)
|
||||||
|
version = client_info.get("version", "")
|
||||||
|
|
||||||
|
if raw_name != friendly_name:
|
||||||
|
log.info(f"MCP Client Connected: {friendly_name} (raw: {raw_name} v{version})")
|
||||||
|
else:
|
||||||
|
log.info(f"MCP Client Connected: {friendly_name} v{version}")
|
||||||
|
|
||||||
|
# Log to activity logger as well
|
||||||
|
try:
|
||||||
|
activity_logger = logging.getLogger("mcp_activity")
|
||||||
|
activity_logger.info(f"CLIENT_IDENTIFIED: {friendly_name} (name={raw_name}, version={version})")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
log.debug("Could not extract client info from MCP protocol")
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage in tools:
|
||||||
|
#
|
||||||
|
# from utils.client_info import get_client_friendly_name, get_cached_client_info
|
||||||
|
#
|
||||||
|
# # In a tool's execute method:
|
||||||
|
# def execute(self, arguments: dict[str, Any]) -> list[TextContent]:
|
||||||
|
# # Get the friendly name of the connected client
|
||||||
|
# client_name = get_client_friendly_name() # Returns "Claude" or "Gemini" etc.
|
||||||
|
#
|
||||||
|
# # Or get full cached info if needed
|
||||||
|
# client_info = get_cached_client_info()
|
||||||
|
# if client_info:
|
||||||
|
# raw_name = client_info['name'] # e.g., "claude-ai"
|
||||||
|
# version = client_info['version'] # e.g., "1.0.0"
|
||||||
|
# friendly = client_info['friendly_name'] # e.g., "Claude"
|
||||||
|
#
|
||||||
|
# # Customize response based on client
|
||||||
|
# if client_name == "Claude":
|
||||||
|
# response = f"Hello from Zen MCP Server to {client_name}!"
|
||||||
|
# elif client_name == "Gemini":
|
||||||
|
# response = f"Greetings {client_name}, welcome to Zen MCP Server!"
|
||||||
|
# else:
|
||||||
|
# response = f"Welcome {client_name}!"
|
||||||
1095
utils/conversation_memory.py
Normal file
1095
utils/conversation_memory.py
Normal file
File diff suppressed because it is too large
Load diff
271
utils/file_types.py
Normal file
271
utils/file_types.py
Normal file
|
|
@ -0,0 +1,271 @@
|
||||||
|
"""
|
||||||
|
File type definitions and constants for file processing
|
||||||
|
|
||||||
|
This module centralizes all file type and extension definitions used
|
||||||
|
throughout the MCP server for consistent file handling.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Programming language file extensions - core code files
|
||||||
|
PROGRAMMING_LANGUAGES = {
|
||||||
|
".py", # Python
|
||||||
|
".js", # JavaScript
|
||||||
|
".ts", # TypeScript
|
||||||
|
".jsx", # React JavaScript
|
||||||
|
".tsx", # React TypeScript
|
||||||
|
".java", # Java
|
||||||
|
".cpp", # C++
|
||||||
|
".c", # C
|
||||||
|
".h", # C/C++ Header
|
||||||
|
".hpp", # C++ Header
|
||||||
|
".cs", # C#
|
||||||
|
".go", # Go
|
||||||
|
".rs", # Rust
|
||||||
|
".rb", # Ruby
|
||||||
|
".php", # PHP
|
||||||
|
".swift", # Swift
|
||||||
|
".kt", # Kotlin
|
||||||
|
".scala", # Scala
|
||||||
|
".r", # R
|
||||||
|
".m", # Objective-C
|
||||||
|
".mm", # Objective-C++
|
||||||
|
}
|
||||||
|
|
||||||
|
# Script and shell file extensions
|
||||||
|
SCRIPTS = {
|
||||||
|
".sql", # SQL
|
||||||
|
".sh", # Shell
|
||||||
|
".bash", # Bash
|
||||||
|
".zsh", # Zsh
|
||||||
|
".fish", # Fish shell
|
||||||
|
".ps1", # PowerShell
|
||||||
|
".bat", # Batch
|
||||||
|
".cmd", # Command
|
||||||
|
}
|
||||||
|
|
||||||
|
# Configuration and data file extensions
|
||||||
|
CONFIGS = {
|
||||||
|
".yml", # YAML
|
||||||
|
".yaml", # YAML
|
||||||
|
".json", # JSON
|
||||||
|
".xml", # XML
|
||||||
|
".toml", # TOML
|
||||||
|
".ini", # INI
|
||||||
|
".cfg", # Config
|
||||||
|
".conf", # Config
|
||||||
|
".properties", # Properties
|
||||||
|
".env", # Environment
|
||||||
|
}
|
||||||
|
|
||||||
|
# Documentation and markup file extensions
|
||||||
|
DOCS = {
|
||||||
|
".txt", # Text
|
||||||
|
".md", # Markdown
|
||||||
|
".rst", # reStructuredText
|
||||||
|
".tex", # LaTeX
|
||||||
|
}
|
||||||
|
|
||||||
|
# Web development file extensions
|
||||||
|
WEB = {
|
||||||
|
".html", # HTML
|
||||||
|
".css", # CSS
|
||||||
|
".scss", # Sass
|
||||||
|
".sass", # Sass
|
||||||
|
".less", # Less
|
||||||
|
}
|
||||||
|
|
||||||
|
# Additional text file extensions for logs and data
|
||||||
|
TEXT_DATA = {
|
||||||
|
".log", # Log files
|
||||||
|
".csv", # CSV
|
||||||
|
".tsv", # TSV
|
||||||
|
".gitignore", # Git ignore
|
||||||
|
".dockerfile", # Dockerfile
|
||||||
|
".makefile", # Make
|
||||||
|
".cmake", # CMake
|
||||||
|
".gradle", # Gradle
|
||||||
|
".sbt", # SBT
|
||||||
|
".pom", # Maven POM
|
||||||
|
".lock", # Lock files
|
||||||
|
".changeset", # Precommit changeset
|
||||||
|
}
|
||||||
|
|
||||||
|
# Image file extensions - limited to what AI models actually support
|
||||||
|
# Based on OpenAI and Gemini supported formats: PNG, JPEG, GIF, WebP
|
||||||
|
IMAGES = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||||
|
|
||||||
|
# Binary executable and library extensions
|
||||||
|
BINARIES = {
|
||||||
|
".exe", # Windows executable
|
||||||
|
".dll", # Windows library
|
||||||
|
".so", # Linux shared object
|
||||||
|
".dylib", # macOS dynamic library
|
||||||
|
".bin", # Binary
|
||||||
|
".class", # Java class
|
||||||
|
}
|
||||||
|
|
||||||
|
# Archive and package file extensions
|
||||||
|
ARCHIVES = {
|
||||||
|
".jar",
|
||||||
|
".war",
|
||||||
|
".ear", # Java archives
|
||||||
|
".zip",
|
||||||
|
".tar",
|
||||||
|
".gz", # General archives
|
||||||
|
".7z",
|
||||||
|
".rar", # Compression
|
||||||
|
".deb",
|
||||||
|
".rpm", # Linux packages
|
||||||
|
".dmg",
|
||||||
|
".pkg", # macOS packages
|
||||||
|
}
|
||||||
|
|
||||||
|
# Derived sets for different use cases
|
||||||
|
CODE_EXTENSIONS = PROGRAMMING_LANGUAGES | SCRIPTS | CONFIGS | DOCS | WEB
|
||||||
|
PROGRAMMING_EXTENSIONS = PROGRAMMING_LANGUAGES # For line numbering
|
||||||
|
TEXT_EXTENSIONS = CODE_EXTENSIONS | TEXT_DATA
|
||||||
|
IMAGE_EXTENSIONS = IMAGES
|
||||||
|
BINARY_EXTENSIONS = BINARIES | ARCHIVES
|
||||||
|
|
||||||
|
# All extensions by category for easy access
|
||||||
|
FILE_CATEGORIES = {
|
||||||
|
"programming": PROGRAMMING_LANGUAGES,
|
||||||
|
"scripts": SCRIPTS,
|
||||||
|
"configs": CONFIGS,
|
||||||
|
"docs": DOCS,
|
||||||
|
"web": WEB,
|
||||||
|
"text_data": TEXT_DATA,
|
||||||
|
"images": IMAGES,
|
||||||
|
"binaries": BINARIES,
|
||||||
|
"archives": ARCHIVES,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_category(file_path: str) -> str:
|
||||||
|
"""
|
||||||
|
Determine the category of a file based on its extension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Category name or "unknown" if not recognized
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
extension = Path(file_path).suffix.lower()
|
||||||
|
|
||||||
|
for category, extensions in FILE_CATEGORIES.items():
|
||||||
|
if extension in extensions:
|
||||||
|
return category
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def is_code_file(file_path: str) -> bool:
|
||||||
|
"""Check if a file is a code file (programming language)."""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
return Path(file_path).suffix.lower() in PROGRAMMING_LANGUAGES
|
||||||
|
|
||||||
|
|
||||||
|
def is_text_file(file_path: str) -> bool:
|
||||||
|
"""Check if a file is a text file."""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
return Path(file_path).suffix.lower() in TEXT_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
def is_binary_file(file_path: str) -> bool:
|
||||||
|
"""Check if a file is a binary file."""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
return Path(file_path).suffix.lower() in BINARY_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
# File-type specific token-to-byte ratios for accurate token estimation
|
||||||
|
# Based on empirical analysis of file compression characteristics and tokenization patterns
|
||||||
|
TOKEN_ESTIMATION_RATIOS = {
|
||||||
|
# Programming languages
|
||||||
|
".py": 3.5, # Python - moderate verbosity
|
||||||
|
".js": 3.2, # JavaScript - compact syntax
|
||||||
|
".ts": 3.3, # TypeScript - type annotations add tokens
|
||||||
|
".jsx": 3.1, # React JSX - JSX tags are tokenized efficiently
|
||||||
|
".tsx": 3.0, # React TSX - combination of TypeScript + JSX
|
||||||
|
".java": 3.6, # Java - verbose syntax, long identifiers
|
||||||
|
".cpp": 3.7, # C++ - preprocessor directives, templates
|
||||||
|
".c": 3.8, # C - function definitions, struct declarations
|
||||||
|
".go": 3.9, # Go - explicit error handling, package names
|
||||||
|
".rs": 3.5, # Rust - similar to Python in verbosity
|
||||||
|
".php": 3.3, # PHP - mixed HTML/code, variable prefixes
|
||||||
|
".rb": 3.6, # Ruby - descriptive method names
|
||||||
|
".swift": 3.4, # Swift - modern syntax, type inference
|
||||||
|
".kt": 3.5, # Kotlin - similar to modern languages
|
||||||
|
".scala": 3.2, # Scala - functional programming, concise
|
||||||
|
# Scripts and configuration
|
||||||
|
".sh": 4.1, # Shell scripts - commands and paths
|
||||||
|
".bat": 4.0, # Batch files - similar to shell
|
||||||
|
".ps1": 3.8, # PowerShell - more structured than bash
|
||||||
|
".sql": 3.8, # SQL - keywords and table/column names
|
||||||
|
# Data and configuration formats
|
||||||
|
".json": 2.5, # JSON - lots of punctuation and quotes
|
||||||
|
".yaml": 3.0, # YAML - structured but readable
|
||||||
|
".yml": 3.0, # YAML (alternative extension)
|
||||||
|
".xml": 2.8, # XML - tags and attributes
|
||||||
|
".toml": 3.2, # TOML - similar to config files
|
||||||
|
# Documentation and text
|
||||||
|
".md": 4.2, # Markdown - natural language with formatting
|
||||||
|
".txt": 4.0, # Plain text - mostly natural language
|
||||||
|
".rst": 4.1, # reStructuredText - documentation format
|
||||||
|
# Web technologies
|
||||||
|
".html": 2.9, # HTML - tags and attributes
|
||||||
|
".css": 3.4, # CSS - properties and selectors
|
||||||
|
# Logs and data
|
||||||
|
".log": 4.5, # Log files - timestamps, messages, stack traces
|
||||||
|
".csv": 3.1, # CSV - data with delimiters
|
||||||
|
# Infrastructure files
|
||||||
|
".dockerfile": 3.7, # Dockerfile - commands and paths
|
||||||
|
".tf": 3.5, # Terraform - infrastructure as code
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_estimation_ratio(file_path: str) -> float:
|
||||||
|
"""
|
||||||
|
Get the token estimation ratio for a file based on its extension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token-to-byte ratio for the file type (default: 3.5 for unknown types)
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
extension = Path(file_path).suffix.lower()
|
||||||
|
return TOKEN_ESTIMATION_RATIOS.get(extension, 3.5) # Conservative default
|
||||||
|
|
||||||
|
|
||||||
|
# MIME type mappings for image files - limited to what AI models actually support
|
||||||
|
# Based on OpenAI and Gemini supported formats: PNG, JPEG, GIF, WebP
|
||||||
|
IMAGE_MIME_TYPES = {
|
||||||
|
".jpg": "image/jpeg",
|
||||||
|
".jpeg": "image/jpeg",
|
||||||
|
".png": "image/png",
|
||||||
|
".gif": "image/gif",
|
||||||
|
".webp": "image/webp",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_mime_type(extension: str) -> str:
|
||||||
|
"""
|
||||||
|
Get the MIME type for an image file extension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
extension: File extension (with or without leading dot)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MIME type string (default: image/jpeg for unknown extensions)
|
||||||
|
"""
|
||||||
|
if not extension.startswith("."):
|
||||||
|
extension = "." + extension
|
||||||
|
extension = extension.lower()
|
||||||
|
return IMAGE_MIME_TYPES.get(extension, "image/jpeg")
|
||||||
864
utils/file_utils.py
Normal file
864
utils/file_utils.py
Normal file
|
|
@ -0,0 +1,864 @@
|
||||||
|
"""
|
||||||
|
File reading utilities with directory support and token management
|
||||||
|
|
||||||
|
This module provides secure file access functionality for the MCP server.
|
||||||
|
It implements critical security measures to prevent unauthorized file access
|
||||||
|
and manages token limits to ensure efficient API usage.
|
||||||
|
|
||||||
|
Key Features:
|
||||||
|
- Path validation and sandboxing to prevent directory traversal attacks
|
||||||
|
- Support for both individual files and recursive directory reading
|
||||||
|
- Token counting and management to stay within API limits
|
||||||
|
- Automatic file type detection and filtering
|
||||||
|
- Comprehensive error handling with informative messages
|
||||||
|
|
||||||
|
Security Model:
|
||||||
|
- All file access is restricted to PROJECT_ROOT and its subdirectories
|
||||||
|
- Absolute paths are required to prevent ambiguity
|
||||||
|
- Symbolic links are resolved to ensure they stay within bounds
|
||||||
|
|
||||||
|
CONVERSATION MEMORY INTEGRATION:
|
||||||
|
This module works with the conversation memory system to support efficient
|
||||||
|
multi-turn file handling:
|
||||||
|
|
||||||
|
1. DEDUPLICATION SUPPORT:
|
||||||
|
- File reading functions are called by conversation-aware tools
|
||||||
|
- Supports newest-first file prioritization by providing accurate token estimation
|
||||||
|
- Enables efficient file content caching and token budget management
|
||||||
|
|
||||||
|
2. TOKEN BUDGET OPTIMIZATION:
|
||||||
|
- Provides accurate token estimation for file content before reading
|
||||||
|
- Supports the dual prioritization strategy by enabling precise budget calculations
|
||||||
|
- Enables tools to make informed decisions about which files to include
|
||||||
|
|
||||||
|
3. CROSS-TOOL FILE PERSISTENCE:
|
||||||
|
- File reading results are used across different tools in conversation chains
|
||||||
|
- Consistent file access patterns support conversation continuation scenarios
|
||||||
|
- Error handling preserves conversation flow when files become unavailable
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .file_types import BINARY_EXTENSIONS, CODE_EXTENSIONS, IMAGE_EXTENSIONS, TEXT_EXTENSIONS
|
||||||
|
from .security_config import EXCLUDED_DIRS, is_dangerous_path
|
||||||
|
from .token_utils import DEFAULT_CONTEXT_WINDOW, estimate_tokens
|
||||||
|
|
||||||
|
|
||||||
|
def _is_builtin_custom_models_config(path_str: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if path points to the server's built-in custom_models.json config file.
|
||||||
|
|
||||||
|
This only matches the server's internal config, not user-specified CUSTOM_MODELS_CONFIG_PATH.
|
||||||
|
We identify the built-in config by checking if it resolves to the server's conf directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path_str: Path to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if this is the server's built-in custom_models.json config file
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
path = Path(path_str)
|
||||||
|
|
||||||
|
# Get the server root by going up from this file: utils/file_utils.py -> server_root
|
||||||
|
server_root = Path(__file__).parent.parent
|
||||||
|
builtin_config = server_root / "conf" / "custom_models.json"
|
||||||
|
|
||||||
|
# Check if the path resolves to the same file as our built-in config
|
||||||
|
# This handles both relative and absolute paths to the same file
|
||||||
|
return path.resolve() == builtin_config.resolve()
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# If path resolution fails, it's not our built-in config
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def is_mcp_directory(path: Path) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a directory is the MCP server's own directory.
|
||||||
|
|
||||||
|
This prevents the MCP from including its own code when scanning projects
|
||||||
|
where the MCP has been cloned as a subdirectory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Directory path to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if this is the MCP server directory or a subdirectory
|
||||||
|
"""
|
||||||
|
if not path.is_dir():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Get the directory where the MCP server is running from
|
||||||
|
# __file__ is utils/file_utils.py, so parent.parent is the MCP root
|
||||||
|
mcp_server_dir = Path(__file__).parent.parent.resolve()
|
||||||
|
|
||||||
|
# Check if the given path is the MCP server directory or a subdirectory
|
||||||
|
try:
|
||||||
|
path.resolve().relative_to(mcp_server_dir)
|
||||||
|
logger.info(f"Detected MCP server directory at {path}, will exclude from scanning")
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
# Not a subdirectory of MCP server
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_home_directory() -> Optional[Path]:
|
||||||
|
"""
|
||||||
|
Get the user's home directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User's home directory path
|
||||||
|
"""
|
||||||
|
return Path.home()
|
||||||
|
|
||||||
|
|
||||||
|
def is_home_directory_root(path: Path) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the given path is the user's home directory root.
|
||||||
|
|
||||||
|
This prevents scanning the entire home directory which could include
|
||||||
|
sensitive data and non-project files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Directory path to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if this is the home directory root
|
||||||
|
"""
|
||||||
|
user_home = get_user_home_directory()
|
||||||
|
if not user_home:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
resolved_path = path.resolve()
|
||||||
|
resolved_home = user_home.resolve()
|
||||||
|
|
||||||
|
# Check if this is exactly the home directory
|
||||||
|
if resolved_path == resolved_home:
|
||||||
|
logger.warning(
|
||||||
|
f"Attempted to scan user home directory root: {path}. Please specify a subdirectory instead."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Also check common home directory patterns
|
||||||
|
path_str = str(resolved_path).lower()
|
||||||
|
home_patterns = [
|
||||||
|
"/users/", # macOS
|
||||||
|
"/home/", # Linux
|
||||||
|
"c:\\users\\", # Windows
|
||||||
|
"c:/users/", # Windows with forward slashes
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern in home_patterns:
|
||||||
|
if pattern in path_str:
|
||||||
|
# Extract the user directory path
|
||||||
|
# e.g., /Users/fahad or /home/username
|
||||||
|
parts = path_str.split(pattern)
|
||||||
|
if len(parts) > 1:
|
||||||
|
# Get the part after the pattern
|
||||||
|
after_pattern = parts[1]
|
||||||
|
# Check if we're at the user's root (no subdirectories)
|
||||||
|
if "/" not in after_pattern and "\\" not in after_pattern:
|
||||||
|
logger.warning(
|
||||||
|
f"Attempted to scan user home directory root: {path}. "
|
||||||
|
f"Please specify a subdirectory instead."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error checking if path is home directory: {e}")
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def detect_file_type(file_path: str) -> str:
|
||||||
|
"""
|
||||||
|
Detect file type for appropriate processing strategy.
|
||||||
|
|
||||||
|
This function is intended for specific file type handling (e.g., image processing,
|
||||||
|
binary file analysis, or enhanced file filtering).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the file to analyze
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: "text", "binary", or "image"
|
||||||
|
"""
|
||||||
|
path = Path(file_path)
|
||||||
|
|
||||||
|
# Check extension first (fast)
|
||||||
|
extension = path.suffix.lower()
|
||||||
|
if extension in TEXT_EXTENSIONS:
|
||||||
|
return "text"
|
||||||
|
elif extension in IMAGE_EXTENSIONS:
|
||||||
|
return "image"
|
||||||
|
elif extension in BINARY_EXTENSIONS:
|
||||||
|
return "binary"
|
||||||
|
|
||||||
|
# Fallback: check magic bytes for text vs binary
|
||||||
|
# This is helpful for files without extensions or unknown extensions
|
||||||
|
try:
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
chunk = f.read(1024)
|
||||||
|
# Simple heuristic: if we can decode as UTF-8, likely text
|
||||||
|
chunk.decode("utf-8")
|
||||||
|
return "text"
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
return "binary"
|
||||||
|
except (FileNotFoundError, PermissionError) as e:
|
||||||
|
logger.warning(f"Could not access file {file_path} for type detection: {e}")
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def should_add_line_numbers(file_path: str, include_line_numbers: Optional[bool] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Determine if line numbers should be added to a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the file
|
||||||
|
include_line_numbers: Explicit preference, or None for auto-detection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if line numbers should be added
|
||||||
|
"""
|
||||||
|
if include_line_numbers is not None:
|
||||||
|
return include_line_numbers
|
||||||
|
|
||||||
|
# Default: DO NOT add line numbers
|
||||||
|
# Tools that want line numbers must explicitly request them
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_line_endings(content: str) -> str:
|
||||||
|
"""
|
||||||
|
Normalize line endings for consistent line numbering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: File content with potentially mixed line endings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Content with normalized LF line endings
|
||||||
|
"""
|
||||||
|
# Normalize all line endings to LF for consistent counting
|
||||||
|
return content.replace("\r\n", "\n").replace("\r", "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _add_line_numbers(content: str) -> str:
|
||||||
|
"""
|
||||||
|
Add line numbers to text content for precise referencing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Text content to number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Content with line numbers in format " 45│ actual code line"
|
||||||
|
Supports files up to 99,999 lines with dynamic width allocation
|
||||||
|
"""
|
||||||
|
# Normalize line endings first
|
||||||
|
normalized_content = _normalize_line_endings(content)
|
||||||
|
lines = normalized_content.split("\n")
|
||||||
|
|
||||||
|
# Dynamic width allocation based on total line count
|
||||||
|
# This supports files of any size by computing required width
|
||||||
|
total_lines = len(lines)
|
||||||
|
width = len(str(total_lines))
|
||||||
|
width = max(width, 4) # Minimum padding for readability
|
||||||
|
|
||||||
|
# Format with dynamic width and clear separator
|
||||||
|
numbered_lines = [f"{i + 1:{width}d}│ {line}" for i, line in enumerate(lines)]
|
||||||
|
|
||||||
|
return "\n".join(numbered_lines)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_and_validate_path(path_str: str) -> Path:
|
||||||
|
"""
|
||||||
|
Resolves and validates a path against security policies.
|
||||||
|
|
||||||
|
This function ensures safe file access by:
|
||||||
|
1. Requiring absolute paths (no ambiguity)
|
||||||
|
2. Resolving symlinks to prevent deception
|
||||||
|
3. Blocking access to dangerous system directories
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path_str: Path string (must be absolute)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Resolved Path object that is safe to access
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If path is not absolute or otherwise invalid
|
||||||
|
PermissionError: If path is in a dangerous location
|
||||||
|
"""
|
||||||
|
# Step 1: Create a Path object
|
||||||
|
user_path = Path(path_str)
|
||||||
|
|
||||||
|
# Step 2: Security Policy - Require absolute paths
|
||||||
|
# Relative paths could be interpreted differently depending on working directory
|
||||||
|
if not user_path.is_absolute():
|
||||||
|
raise ValueError(f"Relative paths are not supported. Please provide an absolute path.\nReceived: {path_str}")
|
||||||
|
|
||||||
|
# Step 3: Resolve the absolute path (follows symlinks, removes .. and .)
|
||||||
|
# This is critical for security as it reveals the true destination of symlinks
|
||||||
|
resolved_path = user_path.resolve()
|
||||||
|
|
||||||
|
# Step 4: Check against dangerous paths
|
||||||
|
if is_dangerous_path(resolved_path):
|
||||||
|
logger.warning(f"Access denied - dangerous path: {resolved_path}")
|
||||||
|
raise PermissionError(f"Access to system directory denied: {path_str}")
|
||||||
|
|
||||||
|
# Step 5: Check if it's the home directory root
|
||||||
|
if is_home_directory_root(resolved_path):
|
||||||
|
raise PermissionError(
|
||||||
|
f"Cannot scan entire home directory: {path_str}\n" f"Please specify a subdirectory within your home folder."
|
||||||
|
)
|
||||||
|
|
||||||
|
return resolved_path
|
||||||
|
|
||||||
|
|
||||||
|
def expand_paths(paths: list[str], extensions: Optional[set[str]] = None) -> list[str]:
|
||||||
|
"""
|
||||||
|
Expand paths to individual files, handling both files and directories.
|
||||||
|
|
||||||
|
This function recursively walks directories to find all matching files.
|
||||||
|
It automatically filters out hidden files and common non-code directories
|
||||||
|
like __pycache__ to avoid including generated or system files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paths: List of file or directory paths (must be absolute)
|
||||||
|
extensions: Optional set of file extensions to include (defaults to CODE_EXTENSIONS)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of individual file paths, sorted for consistent ordering
|
||||||
|
"""
|
||||||
|
if extensions is None:
|
||||||
|
extensions = CODE_EXTENSIONS
|
||||||
|
|
||||||
|
expanded_files = []
|
||||||
|
seen = set()
|
||||||
|
|
||||||
|
for path in paths:
|
||||||
|
try:
|
||||||
|
# Validate each path for security before processing
|
||||||
|
path_obj = resolve_and_validate_path(path)
|
||||||
|
except (ValueError, PermissionError):
|
||||||
|
# Skip invalid paths silently to allow partial success
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not path_obj.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Safety checks for directory scanning
|
||||||
|
if path_obj.is_dir():
|
||||||
|
# Check 1: Prevent scanning user's home directory root
|
||||||
|
if is_home_directory_root(path_obj):
|
||||||
|
logger.warning(f"Skipping home directory root: {path}. Please specify a project subdirectory instead.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check 2: Skip if this is the MCP's own directory
|
||||||
|
if is_mcp_directory(path_obj):
|
||||||
|
logger.info(
|
||||||
|
f"Skipping MCP server directory: {path}. The MCP server code is excluded from project scans."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if path_obj.is_file():
|
||||||
|
# Add file directly
|
||||||
|
if str(path_obj) not in seen:
|
||||||
|
expanded_files.append(str(path_obj))
|
||||||
|
seen.add(str(path_obj))
|
||||||
|
|
||||||
|
elif path_obj.is_dir():
|
||||||
|
# Walk directory recursively to find all files
|
||||||
|
for root, dirs, files in os.walk(path_obj):
|
||||||
|
# Filter directories in-place to skip hidden and excluded directories
|
||||||
|
# This prevents descending into .git, .venv, __pycache__, node_modules, etc.
|
||||||
|
original_dirs = dirs[:]
|
||||||
|
dirs[:] = []
|
||||||
|
for d in original_dirs:
|
||||||
|
# Skip hidden directories
|
||||||
|
if d.startswith("."):
|
||||||
|
continue
|
||||||
|
# Skip excluded directories
|
||||||
|
if d in EXCLUDED_DIRS:
|
||||||
|
continue
|
||||||
|
# Skip MCP directories found during traversal
|
||||||
|
dir_path = Path(root) / d
|
||||||
|
if is_mcp_directory(dir_path):
|
||||||
|
logger.debug(f"Skipping MCP directory during traversal: {dir_path}")
|
||||||
|
continue
|
||||||
|
dirs.append(d)
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
# Skip hidden files (e.g., .DS_Store, .gitignore)
|
||||||
|
if file.startswith("."):
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_path = Path(root) / file
|
||||||
|
|
||||||
|
# Filter by extension if specified
|
||||||
|
if not extensions or file_path.suffix.lower() in extensions:
|
||||||
|
full_path = str(file_path)
|
||||||
|
# Use set to prevent duplicates
|
||||||
|
if full_path not in seen:
|
||||||
|
expanded_files.append(full_path)
|
||||||
|
seen.add(full_path)
|
||||||
|
|
||||||
|
# Sort for consistent ordering across different runs
|
||||||
|
# This makes output predictable and easier to debug
|
||||||
|
expanded_files.sort()
|
||||||
|
return expanded_files
|
||||||
|
|
||||||
|
|
||||||
|
def read_file_content(
|
||||||
|
file_path: str, max_size: int = 1_000_000, *, include_line_numbers: Optional[bool] = None
|
||||||
|
) -> tuple[str, int]:
|
||||||
|
"""
|
||||||
|
Read a single file and format it for inclusion in AI prompts.
|
||||||
|
|
||||||
|
This function handles various error conditions gracefully and always
|
||||||
|
returns formatted content, even for errors. This ensures the AI model
|
||||||
|
gets context about what files were attempted but couldn't be read.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to file (must be absolute)
|
||||||
|
max_size: Maximum file size to read (default 1MB to prevent memory issues)
|
||||||
|
include_line_numbers: Whether to add line numbers. If None, auto-detects based on file type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (formatted_content, estimated_tokens)
|
||||||
|
Content is wrapped with clear delimiters for AI parsing
|
||||||
|
"""
|
||||||
|
logger.debug(f"[FILES] read_file_content called for: {file_path}")
|
||||||
|
try:
|
||||||
|
# Validate path security before any file operations
|
||||||
|
path = resolve_and_validate_path(file_path)
|
||||||
|
logger.debug(f"[FILES] Path validated and resolved: {path}")
|
||||||
|
except (ValueError, PermissionError) as e:
|
||||||
|
# Return error in a format that provides context to the AI
|
||||||
|
logger.debug(f"[FILES] Path validation failed for {file_path}: {type(e).__name__}: {e}")
|
||||||
|
error_msg = str(e)
|
||||||
|
content = f"\n--- ERROR ACCESSING FILE: {file_path} ---\nError: {error_msg}\n--- END FILE ---\n"
|
||||||
|
tokens = estimate_tokens(content)
|
||||||
|
logger.debug(f"[FILES] Returning error content for {file_path}: {tokens} tokens")
|
||||||
|
return content, tokens
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Validate file existence and type
|
||||||
|
if not path.exists():
|
||||||
|
logger.debug(f"[FILES] File does not exist: {file_path}")
|
||||||
|
content = f"\n--- FILE NOT FOUND: {file_path} ---\nError: File does not exist\n--- END FILE ---\n"
|
||||||
|
return content, estimate_tokens(content)
|
||||||
|
|
||||||
|
if not path.is_file():
|
||||||
|
logger.debug(f"[FILES] Path is not a file: {file_path}")
|
||||||
|
content = f"\n--- NOT A FILE: {file_path} ---\nError: Path is not a file\n--- END FILE ---\n"
|
||||||
|
return content, estimate_tokens(content)
|
||||||
|
|
||||||
|
# Check file size to prevent memory exhaustion
|
||||||
|
file_size = path.stat().st_size
|
||||||
|
logger.debug(f"[FILES] File size for {file_path}: {file_size:,} bytes")
|
||||||
|
if file_size > max_size:
|
||||||
|
logger.debug(f"[FILES] File too large: {file_path} ({file_size:,} > {max_size:,} bytes)")
|
||||||
|
content = f"\n--- FILE TOO LARGE: {file_path} ---\nFile size: {file_size:,} bytes (max: {max_size:,})\n--- END FILE ---\n"
|
||||||
|
return content, estimate_tokens(content)
|
||||||
|
|
||||||
|
# Determine if we should add line numbers
|
||||||
|
add_line_numbers = should_add_line_numbers(file_path, include_line_numbers)
|
||||||
|
logger.debug(f"[FILES] Line numbers for {file_path}: {'enabled' if add_line_numbers else 'disabled'}")
|
||||||
|
|
||||||
|
# Read the file with UTF-8 encoding, replacing invalid characters
|
||||||
|
# This ensures we can handle files with mixed encodings
|
||||||
|
logger.debug(f"[FILES] Reading file content for {file_path}")
|
||||||
|
with open(path, encoding="utf-8", errors="replace") as f:
|
||||||
|
file_content = f.read()
|
||||||
|
|
||||||
|
logger.debug(f"[FILES] Successfully read {len(file_content)} characters from {file_path}")
|
||||||
|
|
||||||
|
# Add line numbers if requested or auto-detected
|
||||||
|
if add_line_numbers:
|
||||||
|
file_content = _add_line_numbers(file_content)
|
||||||
|
logger.debug(f"[FILES] Added line numbers to {file_path}")
|
||||||
|
else:
|
||||||
|
# Still normalize line endings for consistency
|
||||||
|
file_content = _normalize_line_endings(file_content)
|
||||||
|
|
||||||
|
# Format with clear delimiters that help the AI understand file boundaries
|
||||||
|
# Using consistent markers makes it easier for the model to parse
|
||||||
|
# NOTE: These markers ("--- BEGIN FILE: ... ---") are distinct from git diff markers
|
||||||
|
# ("--- BEGIN DIFF: ... ---") to allow AI to distinguish between complete file content
|
||||||
|
# vs. partial diff content when files appear in both sections
|
||||||
|
formatted = f"\n--- BEGIN FILE: {file_path} ---\n{file_content}\n--- END FILE: {file_path} ---\n"
|
||||||
|
tokens = estimate_tokens(formatted)
|
||||||
|
logger.debug(f"[FILES] Formatted content for {file_path}: {len(formatted)} chars, {tokens} tokens")
|
||||||
|
return formatted, tokens
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"[FILES] Exception reading file {file_path}: {type(e).__name__}: {e}")
|
||||||
|
content = f"\n--- ERROR READING FILE: {file_path} ---\nError: {str(e)}\n--- END FILE ---\n"
|
||||||
|
tokens = estimate_tokens(content)
|
||||||
|
logger.debug(f"[FILES] Returning error content for {file_path}: {tokens} tokens")
|
||||||
|
return content, tokens
|
||||||
|
|
||||||
|
|
||||||
|
def read_files(
|
||||||
|
file_paths: list[str],
|
||||||
|
code: Optional[str] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
reserve_tokens: int = 50_000,
|
||||||
|
*,
|
||||||
|
include_line_numbers: bool = False,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Read multiple files and optional direct code with smart token management.
|
||||||
|
|
||||||
|
This function implements intelligent token budgeting to maximize the amount
|
||||||
|
of relevant content that can be included in an AI prompt while staying
|
||||||
|
within token limits. It prioritizes direct code and reads files until
|
||||||
|
the token budget is exhausted.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_paths: List of file or directory paths (absolute paths required)
|
||||||
|
code: Optional direct code to include (prioritized over files)
|
||||||
|
max_tokens: Maximum tokens to use (defaults to DEFAULT_CONTEXT_WINDOW)
|
||||||
|
reserve_tokens: Tokens to reserve for prompt and response (default 50K)
|
||||||
|
include_line_numbers: Whether to add line numbers to file content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: All file contents formatted for AI consumption
|
||||||
|
"""
|
||||||
|
if max_tokens is None:
|
||||||
|
max_tokens = DEFAULT_CONTEXT_WINDOW
|
||||||
|
|
||||||
|
logger.debug(f"[FILES] read_files called with {len(file_paths)} paths")
|
||||||
|
logger.debug(
|
||||||
|
f"[FILES] Token budget: max={max_tokens:,}, reserve={reserve_tokens:,}, available={max_tokens - reserve_tokens:,}"
|
||||||
|
)
|
||||||
|
|
||||||
|
content_parts = []
|
||||||
|
total_tokens = 0
|
||||||
|
available_tokens = max_tokens - reserve_tokens
|
||||||
|
|
||||||
|
files_skipped = []
|
||||||
|
|
||||||
|
# Priority 1: Handle direct code if provided
|
||||||
|
# Direct code is prioritized because it's explicitly provided by the user
|
||||||
|
if code:
|
||||||
|
formatted_code = f"\n--- BEGIN DIRECT CODE ---\n{code}\n--- END DIRECT CODE ---\n"
|
||||||
|
code_tokens = estimate_tokens(formatted_code)
|
||||||
|
|
||||||
|
if code_tokens <= available_tokens:
|
||||||
|
content_parts.append(formatted_code)
|
||||||
|
total_tokens += code_tokens
|
||||||
|
available_tokens -= code_tokens
|
||||||
|
|
||||||
|
# Priority 2: Process file paths
|
||||||
|
if file_paths:
|
||||||
|
# Expand directories to get all individual files
|
||||||
|
logger.debug(f"[FILES] Expanding {len(file_paths)} file paths")
|
||||||
|
all_files = expand_paths(file_paths)
|
||||||
|
logger.debug(f"[FILES] After expansion: {len(all_files)} individual files")
|
||||||
|
|
||||||
|
if not all_files and file_paths:
|
||||||
|
# No files found but paths were provided
|
||||||
|
logger.debug("[FILES] No files found from provided paths")
|
||||||
|
content_parts.append(f"\n--- NO FILES FOUND ---\nProvided paths: {', '.join(file_paths)}\n--- END ---\n")
|
||||||
|
else:
|
||||||
|
# Read files sequentially until token limit is reached
|
||||||
|
logger.debug(f"[FILES] Reading {len(all_files)} files with token budget {available_tokens:,}")
|
||||||
|
for i, file_path in enumerate(all_files):
|
||||||
|
if total_tokens >= available_tokens:
|
||||||
|
logger.debug(f"[FILES] Token budget exhausted, skipping remaining {len(all_files) - i} files")
|
||||||
|
files_skipped.extend(all_files[i:])
|
||||||
|
break
|
||||||
|
|
||||||
|
file_content, file_tokens = read_file_content(file_path, include_line_numbers=include_line_numbers)
|
||||||
|
logger.debug(f"[FILES] File {file_path}: {file_tokens:,} tokens")
|
||||||
|
|
||||||
|
# Check if adding this file would exceed limit
|
||||||
|
if total_tokens + file_tokens <= available_tokens:
|
||||||
|
content_parts.append(file_content)
|
||||||
|
total_tokens += file_tokens
|
||||||
|
logger.debug(f"[FILES] Added file {file_path}, total tokens: {total_tokens:,}")
|
||||||
|
else:
|
||||||
|
# File too large for remaining budget
|
||||||
|
logger.debug(
|
||||||
|
f"[FILES] File {file_path} too large for remaining budget ({file_tokens:,} tokens, {available_tokens - total_tokens:,} remaining)"
|
||||||
|
)
|
||||||
|
files_skipped.append(file_path)
|
||||||
|
|
||||||
|
# Add informative note about skipped files to help users understand
|
||||||
|
# what was omitted and why
|
||||||
|
if files_skipped:
|
||||||
|
logger.debug(f"[FILES] {len(files_skipped)} files skipped due to token limits")
|
||||||
|
skip_note = "\n\n--- SKIPPED FILES (TOKEN LIMIT) ---\n"
|
||||||
|
skip_note += f"Total skipped: {len(files_skipped)}\n"
|
||||||
|
# Show first 10 skipped files as examples
|
||||||
|
for _i, file_path in enumerate(files_skipped[:10]):
|
||||||
|
skip_note += f" - {file_path}\n"
|
||||||
|
if len(files_skipped) > 10:
|
||||||
|
skip_note += f" ... and {len(files_skipped) - 10} more\n"
|
||||||
|
skip_note += "--- END SKIPPED FILES ---\n"
|
||||||
|
content_parts.append(skip_note)
|
||||||
|
|
||||||
|
result = "\n\n".join(content_parts) if content_parts else ""
|
||||||
|
logger.debug(f"[FILES] read_files complete: {len(result)} chars, {total_tokens:,} tokens used")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_file_tokens(file_path: str) -> int:
|
||||||
|
"""
|
||||||
|
Estimate tokens for a file using file-type aware ratios.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count for the file
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not os.path.exists(file_path) or not os.path.isfile(file_path):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
file_size = os.path.getsize(file_path)
|
||||||
|
|
||||||
|
# Get the appropriate ratio for this file type
|
||||||
|
from .file_types import get_token_estimation_ratio
|
||||||
|
|
||||||
|
ratio = get_token_estimation_ratio(file_path)
|
||||||
|
|
||||||
|
return int(file_size / ratio)
|
||||||
|
except Exception:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def check_files_size_limit(files: list[str], max_tokens: int, threshold_percent: float = 1.0) -> tuple[bool, int, int]:
|
||||||
|
"""
|
||||||
|
Check if a list of files would exceed token limits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: List of file paths to check
|
||||||
|
max_tokens: Maximum allowed tokens
|
||||||
|
threshold_percent: Percentage of max_tokens to use as threshold (0.0-1.0)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (within_limit, total_estimated_tokens, file_count)
|
||||||
|
"""
|
||||||
|
if not files:
|
||||||
|
return True, 0, 0
|
||||||
|
|
||||||
|
total_estimated_tokens = 0
|
||||||
|
file_count = 0
|
||||||
|
threshold = int(max_tokens * threshold_percent)
|
||||||
|
|
||||||
|
for file_path in files:
|
||||||
|
try:
|
||||||
|
estimated_tokens = estimate_file_tokens(file_path)
|
||||||
|
total_estimated_tokens += estimated_tokens
|
||||||
|
if estimated_tokens > 0: # Only count accessible files
|
||||||
|
file_count += 1
|
||||||
|
except Exception:
|
||||||
|
# Skip files that can't be accessed for size check
|
||||||
|
continue
|
||||||
|
|
||||||
|
within_limit = total_estimated_tokens <= threshold
|
||||||
|
return within_limit, total_estimated_tokens, file_count
|
||||||
|
|
||||||
|
|
||||||
|
def read_json_file(file_path: str) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Read and parse a JSON file with proper error handling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the JSON file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed JSON data as dict, or None if file doesn't exist or invalid
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not os.path.exists(file_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
with open(file_path, encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
except (json.JSONDecodeError, OSError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def write_json_file(file_path: str, data: dict, indent: int = 2) -> bool:
|
||||||
|
"""
|
||||||
|
Write data to a JSON file with proper formatting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to write the JSON file
|
||||||
|
data: Dictionary data to serialize
|
||||||
|
indent: JSON indentation level
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if successful, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||||
|
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(data, f, indent=indent, ensure_ascii=False)
|
||||||
|
return True
|
||||||
|
except (OSError, TypeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_size(file_path: str) -> int:
|
||||||
|
"""
|
||||||
|
Get file size in bytes with proper error handling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File size in bytes, or 0 if file doesn't exist or error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if os.path.exists(file_path) and os.path.isfile(file_path):
|
||||||
|
return os.path.getsize(file_path)
|
||||||
|
return 0
|
||||||
|
except OSError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_directory_exists(file_path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Ensure the parent directory of a file path exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to file (directory will be created for parent)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if directory exists or was created, False on error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
directory = os.path.dirname(file_path)
|
||||||
|
if directory:
|
||||||
|
os.makedirs(directory, exist_ok=True)
|
||||||
|
return True
|
||||||
|
except OSError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_text_file(file_path: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a file is likely a text file based on extension and content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if file appears to be text, False otherwise
|
||||||
|
"""
|
||||||
|
from .file_types import is_text_file as check_text_type
|
||||||
|
|
||||||
|
return check_text_type(file_path)
|
||||||
|
|
||||||
|
|
||||||
|
def read_file_safely(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Read a file with size limits and encoding handling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the file
|
||||||
|
max_size: Maximum file size in bytes (default 10MB)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content as string, or None if file too large or unreadable
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not os.path.exists(file_path) or not os.path.isfile(file_path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
file_size = os.path.getsize(file_path)
|
||||||
|
if file_size > max_size:
|
||||||
|
return None
|
||||||
|
|
||||||
|
with open(file_path, encoding="utf-8", errors="ignore") as f:
|
||||||
|
return f.read()
|
||||||
|
except OSError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def check_total_file_size(files: list[str], model_name: str) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Check if total file sizes would exceed token threshold before embedding.
|
||||||
|
|
||||||
|
IMPORTANT: This performs STRICT REJECTION at MCP boundary.
|
||||||
|
No partial inclusion - either all files fit or request is rejected.
|
||||||
|
This forces Claude to make better file selection decisions.
|
||||||
|
|
||||||
|
This function MUST be called with the effective model name (after resolution).
|
||||||
|
It should never receive 'auto' or None - model resolution happens earlier.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: List of file paths to check
|
||||||
|
model_name: The resolved model name for context-aware thresholds (required)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with `code_too_large` response if too large, None if acceptable
|
||||||
|
"""
|
||||||
|
if not files:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Validate we have a proper model name (not auto or None)
|
||||||
|
if not model_name or model_name.lower() == "auto":
|
||||||
|
raise ValueError(
|
||||||
|
f"check_total_file_size called with unresolved model: '{model_name}'. "
|
||||||
|
"Model must be resolved before file size checking."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"File size check: Using model '{model_name}' for token limit calculation")
|
||||||
|
|
||||||
|
from utils.model_context import ModelContext
|
||||||
|
|
||||||
|
model_context = ModelContext(model_name)
|
||||||
|
token_allocation = model_context.calculate_token_allocation()
|
||||||
|
|
||||||
|
# Dynamic threshold based on model capacity
|
||||||
|
context_window = token_allocation.total_tokens
|
||||||
|
if context_window >= 1_000_000: # Gemini-class models
|
||||||
|
threshold_percent = 0.8 # Can be more generous
|
||||||
|
elif context_window >= 500_000: # Mid-range models
|
||||||
|
threshold_percent = 0.7 # Moderate
|
||||||
|
else: # OpenAI-class models (200K)
|
||||||
|
threshold_percent = 0.6 # Conservative
|
||||||
|
|
||||||
|
max_file_tokens = int(token_allocation.file_tokens * threshold_percent)
|
||||||
|
|
||||||
|
# Use centralized file size checking (threshold already applied to max_file_tokens)
|
||||||
|
within_limit, total_estimated_tokens, file_count = check_files_size_limit(files, max_file_tokens)
|
||||||
|
|
||||||
|
if not within_limit:
|
||||||
|
return {
|
||||||
|
"status": "code_too_large",
|
||||||
|
"content": (
|
||||||
|
f"The selected files are too large for analysis "
|
||||||
|
f"(estimated {total_estimated_tokens:,} tokens, limit {max_file_tokens:,}). "
|
||||||
|
f"Please select fewer, more specific files that are most relevant "
|
||||||
|
f"to your question, then invoke the tool again."
|
||||||
|
),
|
||||||
|
"content_type": "text",
|
||||||
|
"metadata": {
|
||||||
|
"total_estimated_tokens": total_estimated_tokens,
|
||||||
|
"limit": max_file_tokens,
|
||||||
|
"file_count": file_count,
|
||||||
|
"threshold_percent": threshold_percent,
|
||||||
|
"model_context_window": context_window,
|
||||||
|
"model_name": model_name,
|
||||||
|
"instructions": "Reduce file selection and try again - all files must fit within budget. If this persists, please use a model with a larger context window where available.",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return None # Proceed with ALL files
|
||||||
94
utils/image_utils.py
Normal file
94
utils/image_utils.py
Normal file
|
|
@ -0,0 +1,94 @@
|
||||||
|
"""Utility helpers for validating image inputs."""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import binascii
|
||||||
|
import os
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
from utils.file_types import IMAGES, get_image_mime_type
|
||||||
|
|
||||||
|
DEFAULT_MAX_IMAGE_SIZE_MB = 20.0
|
||||||
|
|
||||||
|
__all__ = ["DEFAULT_MAX_IMAGE_SIZE_MB", "validate_image"]
|
||||||
|
|
||||||
|
|
||||||
|
def _valid_mime_types() -> Iterable[str]:
|
||||||
|
"""Return the MIME types permitted by the IMAGES whitelist."""
|
||||||
|
return (get_image_mime_type(ext) for ext in IMAGES)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_image(image_path: str, max_size_mb: float = None) -> tuple[bytes, str]:
|
||||||
|
"""Validate a user-supplied image path or data URL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: Either a filesystem path or a data URL.
|
||||||
|
max_size_mb: Optional size limit (defaults to ``DEFAULT_MAX_IMAGE_SIZE_MB``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple ``(image_bytes, mime_type)`` ready for upstream providers.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: When the image is missing, malformed, or exceeds limits.
|
||||||
|
"""
|
||||||
|
if max_size_mb is None:
|
||||||
|
max_size_mb = DEFAULT_MAX_IMAGE_SIZE_MB
|
||||||
|
|
||||||
|
if image_path.startswith("data:"):
|
||||||
|
return _validate_data_url(image_path, max_size_mb)
|
||||||
|
|
||||||
|
return _validate_file_path(image_path, max_size_mb)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_data_url(image_data_url: str, max_size_mb: float) -> tuple[bytes, str]:
|
||||||
|
"""Validate a data URL and return image bytes plus MIME type."""
|
||||||
|
try:
|
||||||
|
header, data = image_data_url.split(",", 1)
|
||||||
|
mime_type = header.split(";")[0].split(":")[1]
|
||||||
|
except (ValueError, IndexError) as exc:
|
||||||
|
raise ValueError(f"Invalid data URL format: {exc}")
|
||||||
|
|
||||||
|
valid_mime_types = list(_valid_mime_types())
|
||||||
|
if mime_type not in valid_mime_types:
|
||||||
|
raise ValueError(
|
||||||
|
"Unsupported image type: {mime}. Supported types: {supported}".format(
|
||||||
|
mime=mime_type, supported=", ".join(valid_mime_types)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_bytes = base64.b64decode(data)
|
||||||
|
except binascii.Error as exc:
|
||||||
|
raise ValueError(f"Invalid base64 data: {exc}")
|
||||||
|
|
||||||
|
_validate_size(image_bytes, max_size_mb)
|
||||||
|
return image_bytes, mime_type
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_file_path(file_path: str, max_size_mb: float) -> tuple[bytes, str]:
|
||||||
|
"""Validate an image loaded from the filesystem."""
|
||||||
|
try:
|
||||||
|
with open(file_path, "rb") as handle:
|
||||||
|
image_bytes = handle.read()
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise ValueError(f"Image file not found: {file_path}")
|
||||||
|
except OSError as exc:
|
||||||
|
raise ValueError(f"Failed to read image file: {exc}")
|
||||||
|
|
||||||
|
ext = os.path.splitext(file_path)[1].lower()
|
||||||
|
if ext not in IMAGES:
|
||||||
|
raise ValueError(
|
||||||
|
"Unsupported image format: {ext}. Supported formats: {supported}".format(
|
||||||
|
ext=ext, supported=", ".join(sorted(IMAGES))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
mime_type = get_image_mime_type(ext)
|
||||||
|
_validate_size(image_bytes, max_size_mb)
|
||||||
|
return image_bytes, mime_type
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_size(image_bytes: bytes, max_size_mb: float) -> None:
|
||||||
|
"""Ensure the image does not exceed the configured size limit."""
|
||||||
|
size_mb = len(image_bytes) / (1024 * 1024)
|
||||||
|
if size_mb > max_size_mb:
|
||||||
|
raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)")
|
||||||
180
utils/model_context.py
Normal file
180
utils/model_context.py
Normal file
|
|
@ -0,0 +1,180 @@
|
||||||
|
"""
|
||||||
|
Model context management for dynamic token allocation.
|
||||||
|
|
||||||
|
This module provides a clean abstraction for model-specific token management,
|
||||||
|
ensuring that token limits are properly calculated based on the current model
|
||||||
|
being used, not global constants.
|
||||||
|
|
||||||
|
CONVERSATION MEMORY INTEGRATION:
|
||||||
|
This module works closely with the conversation memory system to provide
|
||||||
|
optimal token allocation for multi-turn conversations:
|
||||||
|
|
||||||
|
1. DUAL PRIORITIZATION STRATEGY SUPPORT:
|
||||||
|
- Provides separate token budgets for conversation history vs. files
|
||||||
|
- Enables the conversation memory system to apply newest-first prioritization
|
||||||
|
- Ensures optimal balance between context preservation and new content
|
||||||
|
|
||||||
|
2. MODEL-SPECIFIC ALLOCATION:
|
||||||
|
- Dynamic allocation based on model capabilities (context window size)
|
||||||
|
- Conservative allocation for smaller models (O3: 200K context)
|
||||||
|
- Generous allocation for larger models (Gemini: 1M+ context)
|
||||||
|
- Adapts token distribution ratios based on model capacity
|
||||||
|
|
||||||
|
3. CROSS-TOOL CONSISTENCY:
|
||||||
|
- Provides consistent token budgets across different tools
|
||||||
|
- Enables seamless conversation continuation between tools
|
||||||
|
- Supports conversation reconstruction with proper budget management
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from config import DEFAULT_MODEL
|
||||||
|
from providers import ModelCapabilities, ModelProviderRegistry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenAllocation:
|
||||||
|
"""Token allocation strategy for a model."""
|
||||||
|
|
||||||
|
total_tokens: int
|
||||||
|
content_tokens: int
|
||||||
|
response_tokens: int
|
||||||
|
file_tokens: int
|
||||||
|
history_tokens: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available_for_prompt(self) -> int:
|
||||||
|
"""Tokens available for the actual prompt after allocations."""
|
||||||
|
return self.content_tokens - self.file_tokens - self.history_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class ModelContext:
|
||||||
|
"""
|
||||||
|
Encapsulates model-specific information and token calculations.
|
||||||
|
|
||||||
|
This class provides a single source of truth for all model-related
|
||||||
|
token calculations, ensuring consistency across the system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str, model_option: Optional[str] = None):
|
||||||
|
self.model_name = model_name
|
||||||
|
self.model_option = model_option # Store optional model option (e.g., "for", "against", etc.)
|
||||||
|
self._provider = None
|
||||||
|
self._capabilities = None
|
||||||
|
self._token_allocation = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider(self):
|
||||||
|
"""Get the model provider lazily."""
|
||||||
|
if self._provider is None:
|
||||||
|
self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name)
|
||||||
|
if not self._provider:
|
||||||
|
available_models = ModelProviderRegistry.get_available_model_names()
|
||||||
|
if available_models:
|
||||||
|
available_text = ", ".join(available_models)
|
||||||
|
else:
|
||||||
|
available_text = (
|
||||||
|
"No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option."
|
||||||
|
)
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{self.model_name}' is not available with current API keys. Available models: {available_text}."
|
||||||
|
)
|
||||||
|
return self._provider
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capabilities(self) -> ModelCapabilities:
|
||||||
|
"""Get model capabilities lazily."""
|
||||||
|
if self._capabilities is None:
|
||||||
|
self._capabilities = self.provider.get_capabilities(self.model_name)
|
||||||
|
return self._capabilities
|
||||||
|
|
||||||
|
def calculate_token_allocation(self, reserved_for_response: Optional[int] = None) -> TokenAllocation:
|
||||||
|
"""
|
||||||
|
Calculate token allocation based on model capacity and conversation requirements.
|
||||||
|
|
||||||
|
This method implements the core token budget calculation that supports the
|
||||||
|
dual prioritization strategy used in conversation memory and file processing:
|
||||||
|
|
||||||
|
TOKEN ALLOCATION STRATEGY:
|
||||||
|
1. CONTENT vs RESPONSE SPLIT:
|
||||||
|
- Smaller models (< 300K): 60% content, 40% response (conservative)
|
||||||
|
- Larger models (≥ 300K): 80% content, 20% response (generous)
|
||||||
|
|
||||||
|
2. CONTENT SUB-ALLOCATION:
|
||||||
|
- File tokens: 30-40% of content budget for newest file versions
|
||||||
|
- History tokens: 40-50% of content budget for conversation context
|
||||||
|
- Remaining: Available for tool-specific prompt content
|
||||||
|
|
||||||
|
3. CONVERSATION MEMORY INTEGRATION:
|
||||||
|
- History allocation enables conversation reconstruction in reconstruct_thread_context()
|
||||||
|
- File allocation supports newest-first file prioritization in tools
|
||||||
|
- Remaining budget passed to tools via _remaining_tokens parameter
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reserved_for_response: Override response token reservation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenAllocation with calculated budgets for dual prioritization strategy
|
||||||
|
"""
|
||||||
|
total_tokens = self.capabilities.context_window
|
||||||
|
|
||||||
|
# Dynamic allocation based on model capacity
|
||||||
|
if total_tokens < 300_000:
|
||||||
|
# Smaller context models (O3): Conservative allocation
|
||||||
|
content_ratio = 0.6 # 60% for content
|
||||||
|
response_ratio = 0.4 # 40% for response
|
||||||
|
file_ratio = 0.3 # 30% of content for files
|
||||||
|
history_ratio = 0.5 # 50% of content for history
|
||||||
|
else:
|
||||||
|
# Larger context models (Gemini): More generous allocation
|
||||||
|
content_ratio = 0.8 # 80% for content
|
||||||
|
response_ratio = 0.2 # 20% for response
|
||||||
|
file_ratio = 0.4 # 40% of content for files
|
||||||
|
history_ratio = 0.4 # 40% of content for history
|
||||||
|
|
||||||
|
# Calculate allocations
|
||||||
|
content_tokens = int(total_tokens * content_ratio)
|
||||||
|
response_tokens = reserved_for_response or int(total_tokens * response_ratio)
|
||||||
|
|
||||||
|
# Sub-allocations within content budget
|
||||||
|
file_tokens = int(content_tokens * file_ratio)
|
||||||
|
history_tokens = int(content_tokens * history_ratio)
|
||||||
|
|
||||||
|
allocation = TokenAllocation(
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
content_tokens=content_tokens,
|
||||||
|
response_tokens=response_tokens,
|
||||||
|
file_tokens=file_tokens,
|
||||||
|
history_tokens=history_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Token allocation for {self.model_name}:")
|
||||||
|
logger.debug(f" Total: {allocation.total_tokens:,}")
|
||||||
|
logger.debug(f" Content: {allocation.content_tokens:,} ({content_ratio:.0%})")
|
||||||
|
logger.debug(f" Response: {allocation.response_tokens:,} ({response_ratio:.0%})")
|
||||||
|
logger.debug(f" Files: {allocation.file_tokens:,} ({file_ratio:.0%} of content)")
|
||||||
|
logger.debug(f" History: {allocation.history_tokens:,} ({history_ratio:.0%} of content)")
|
||||||
|
|
||||||
|
return allocation
|
||||||
|
|
||||||
|
def estimate_tokens(self, text: str) -> int:
|
||||||
|
"""
|
||||||
|
Estimate token count for text using model-specific tokenizer.
|
||||||
|
|
||||||
|
For now, uses simple estimation. Can be enhanced with model-specific
|
||||||
|
tokenizers (tiktoken for OpenAI, etc.) in the future.
|
||||||
|
"""
|
||||||
|
# TODO: Integrate model-specific tokenizers
|
||||||
|
# For now, use conservative estimation
|
||||||
|
return len(text) // 3 # Conservative estimate
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_arguments(cls, arguments: dict[str, Any]) -> "ModelContext":
|
||||||
|
"""Create ModelContext from tool arguments."""
|
||||||
|
model_name = arguments.get("model") or DEFAULT_MODEL
|
||||||
|
return cls(model_name)
|
||||||
226
utils/model_restrictions.py
Normal file
226
utils/model_restrictions.py
Normal file
|
|
@ -0,0 +1,226 @@
|
||||||
|
"""
|
||||||
|
Model Restriction Service
|
||||||
|
|
||||||
|
This module provides centralized management of model usage restrictions
|
||||||
|
based on environment variables. It allows organizations to limit which
|
||||||
|
models can be used from each provider for cost control, compliance, or
|
||||||
|
standardization purposes.
|
||||||
|
|
||||||
|
Environment Variables:
|
||||||
|
- OPENAI_ALLOWED_MODELS: Comma-separated list of allowed OpenAI models
|
||||||
|
- GOOGLE_ALLOWED_MODELS: Comma-separated list of allowed Gemini models
|
||||||
|
- XAI_ALLOWED_MODELS: Comma-separated list of allowed X.AI GROK models
|
||||||
|
- OPENROUTER_ALLOWED_MODELS: Comma-separated list of allowed OpenRouter models
|
||||||
|
- DIAL_ALLOWED_MODELS: Comma-separated list of allowed DIAL models
|
||||||
|
|
||||||
|
Example:
|
||||||
|
OPENAI_ALLOWED_MODELS=o3-mini,o4-mini
|
||||||
|
GOOGLE_ALLOWED_MODELS=flash
|
||||||
|
XAI_ALLOWED_MODELS=grok-3,grok-3-fast
|
||||||
|
OPENROUTER_ALLOWED_MODELS=opus,sonnet,mistral
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from providers.shared import ProviderType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRestrictionService:
|
||||||
|
"""Central authority for environment-driven model allowlists.
|
||||||
|
|
||||||
|
Role
|
||||||
|
Interpret ``*_ALLOWED_MODELS`` environment variables, keep their
|
||||||
|
entries normalised (lowercase), and answer whether a provider/model
|
||||||
|
pairing is permitted.
|
||||||
|
|
||||||
|
Responsibilities
|
||||||
|
* Parse, cache, and expose per-provider restriction sets
|
||||||
|
* Validate configuration by cross-checking each entry against the
|
||||||
|
provider’s alias-aware model list
|
||||||
|
* Offer helper methods such as ``is_allowed`` and ``filter_models`` to
|
||||||
|
enforce policy everywhere model names appear (tool selection, CLI
|
||||||
|
commands, etc.).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Environment variable names
|
||||||
|
ENV_VARS = {
|
||||||
|
ProviderType.OPENAI: "OPENAI_ALLOWED_MODELS",
|
||||||
|
ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS",
|
||||||
|
ProviderType.XAI: "XAI_ALLOWED_MODELS",
|
||||||
|
ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS",
|
||||||
|
ProviderType.DIAL: "DIAL_ALLOWED_MODELS",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the restriction service by loading from environment."""
|
||||||
|
self.restrictions: dict[ProviderType, set[str]] = {}
|
||||||
|
self._load_from_env()
|
||||||
|
|
||||||
|
def _load_from_env(self) -> None:
|
||||||
|
"""Load restrictions from environment variables."""
|
||||||
|
for provider_type, env_var in self.ENV_VARS.items():
|
||||||
|
env_value = os.getenv(env_var)
|
||||||
|
|
||||||
|
if env_value is None or env_value == "":
|
||||||
|
# Not set or empty - no restrictions (allow all models)
|
||||||
|
logger.debug(f"{env_var} not set or empty - all {provider_type.value} models allowed")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse comma-separated list
|
||||||
|
models = set()
|
||||||
|
for model in env_value.split(","):
|
||||||
|
cleaned = model.strip().lower()
|
||||||
|
if cleaned:
|
||||||
|
models.add(cleaned)
|
||||||
|
|
||||||
|
if models:
|
||||||
|
self.restrictions[provider_type] = models
|
||||||
|
logger.info(f"{provider_type.value} allowed models: {sorted(models)}")
|
||||||
|
else:
|
||||||
|
# All entries were empty after cleaning - treat as no restrictions
|
||||||
|
logger.debug(f"{env_var} contains only whitespace - all {provider_type.value} models allowed")
|
||||||
|
|
||||||
|
def validate_against_known_models(self, provider_instances: dict[ProviderType, any]) -> None:
|
||||||
|
"""
|
||||||
|
Validate restrictions against known models from providers.
|
||||||
|
|
||||||
|
This should be called after providers are initialized to warn about
|
||||||
|
typos or invalid model names in the restriction lists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_instances: Dictionary of provider type to provider instance
|
||||||
|
"""
|
||||||
|
for provider_type, allowed_models in self.restrictions.items():
|
||||||
|
provider = provider_instances.get(provider_type)
|
||||||
|
if not provider:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get all supported models using the clean polymorphic interface
|
||||||
|
try:
|
||||||
|
# Gather canonical models and aliases with consistent formatting
|
||||||
|
all_models = provider.list_models(
|
||||||
|
respect_restrictions=False,
|
||||||
|
include_aliases=True,
|
||||||
|
lowercase=True,
|
||||||
|
unique=True,
|
||||||
|
)
|
||||||
|
supported_models = set(all_models)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not get model list from {provider_type.value} provider: {e}")
|
||||||
|
supported_models = set()
|
||||||
|
|
||||||
|
# Check each allowed model
|
||||||
|
for allowed_model in allowed_models:
|
||||||
|
if allowed_model not in supported_models:
|
||||||
|
logger.warning(
|
||||||
|
f"Model '{allowed_model}' in {self.ENV_VARS[provider_type]} "
|
||||||
|
f"is not a recognized {provider_type.value} model. "
|
||||||
|
f"Please check for typos. Known models: {sorted(supported_models)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_allowed(self, provider_type: ProviderType, model_name: str, original_name: Optional[str] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a model is allowed for a specific provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: The provider type (OPENAI, GOOGLE, etc.)
|
||||||
|
model_name: The canonical model name (after alias resolution)
|
||||||
|
original_name: The original model name before alias resolution (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if allowed (or no restrictions), False if restricted
|
||||||
|
"""
|
||||||
|
if provider_type not in self.restrictions:
|
||||||
|
# No restrictions for this provider
|
||||||
|
return True
|
||||||
|
|
||||||
|
allowed_set = self.restrictions[provider_type]
|
||||||
|
|
||||||
|
if len(allowed_set) == 0:
|
||||||
|
# Empty set - allowed
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check both the resolved name and original name (if different)
|
||||||
|
names_to_check = {model_name.lower()}
|
||||||
|
if original_name and original_name.lower() != model_name.lower():
|
||||||
|
names_to_check.add(original_name.lower())
|
||||||
|
|
||||||
|
# If any of the names is in the allowed set, it's allowed
|
||||||
|
return any(name in allowed_set for name in names_to_check)
|
||||||
|
|
||||||
|
def get_allowed_models(self, provider_type: ProviderType) -> Optional[set[str]]:
|
||||||
|
"""
|
||||||
|
Get the set of allowed models for a provider.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: The provider type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of allowed model names, or None if no restrictions
|
||||||
|
"""
|
||||||
|
return self.restrictions.get(provider_type)
|
||||||
|
|
||||||
|
def has_restrictions(self, provider_type: ProviderType) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a provider has any restrictions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: The provider type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if restrictions exist, False otherwise
|
||||||
|
"""
|
||||||
|
return provider_type in self.restrictions
|
||||||
|
|
||||||
|
def filter_models(self, provider_type: ProviderType, models: list[str]) -> list[str]:
|
||||||
|
"""
|
||||||
|
Filter a list of models based on restrictions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
provider_type: The provider type
|
||||||
|
models: List of model names to filter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered list containing only allowed models
|
||||||
|
"""
|
||||||
|
if not self.has_restrictions(provider_type):
|
||||||
|
return models
|
||||||
|
|
||||||
|
return [m for m in models if self.is_allowed(provider_type, m)]
|
||||||
|
|
||||||
|
def get_restriction_summary(self) -> dict[str, any]:
|
||||||
|
"""
|
||||||
|
Get a summary of all restrictions for logging/debugging.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with provider names and their restrictions
|
||||||
|
"""
|
||||||
|
summary = {}
|
||||||
|
for provider_type, allowed_set in self.restrictions.items():
|
||||||
|
if allowed_set:
|
||||||
|
summary[provider_type.value] = sorted(allowed_set)
|
||||||
|
else:
|
||||||
|
summary[provider_type.value] = "none (provider disabled)"
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
# Global instance (singleton pattern)
|
||||||
|
_restriction_service: Optional[ModelRestrictionService] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_restriction_service() -> ModelRestrictionService:
|
||||||
|
"""
|
||||||
|
Get the global restriction service instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The singleton ModelRestrictionService instance
|
||||||
|
"""
|
||||||
|
global _restriction_service
|
||||||
|
if _restriction_service is None:
|
||||||
|
_restriction_service = ModelRestrictionService()
|
||||||
|
return _restriction_service
|
||||||
104
utils/security_config.py
Normal file
104
utils/security_config.py
Normal file
|
|
@ -0,0 +1,104 @@
|
||||||
|
"""
|
||||||
|
Security configuration and path validation constants
|
||||||
|
|
||||||
|
This module contains security-related constants and configurations
|
||||||
|
for file access control.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Dangerous paths that should never be scanned
|
||||||
|
# These would give overly broad access and pose security risks
|
||||||
|
DANGEROUS_PATHS = {
|
||||||
|
"/",
|
||||||
|
"/etc",
|
||||||
|
"/usr",
|
||||||
|
"/bin",
|
||||||
|
"/var",
|
||||||
|
"/root",
|
||||||
|
"/home",
|
||||||
|
"C:\\",
|
||||||
|
"C:\\Windows",
|
||||||
|
"C:\\Program Files",
|
||||||
|
"C:\\Users",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Directories to exclude from recursive file search
|
||||||
|
# These typically contain generated code, dependencies, or build artifacts
|
||||||
|
EXCLUDED_DIRS = {
|
||||||
|
# Python
|
||||||
|
"__pycache__",
|
||||||
|
".venv",
|
||||||
|
"venv",
|
||||||
|
"env",
|
||||||
|
".env",
|
||||||
|
"*.egg-info",
|
||||||
|
".eggs",
|
||||||
|
"wheels",
|
||||||
|
".Python",
|
||||||
|
".mypy_cache",
|
||||||
|
".pytest_cache",
|
||||||
|
".tox",
|
||||||
|
"htmlcov",
|
||||||
|
".coverage",
|
||||||
|
"coverage",
|
||||||
|
# Node.js / JavaScript
|
||||||
|
"node_modules",
|
||||||
|
".next",
|
||||||
|
".nuxt",
|
||||||
|
"bower_components",
|
||||||
|
".sass-cache",
|
||||||
|
# Version Control
|
||||||
|
".git",
|
||||||
|
".svn",
|
||||||
|
".hg",
|
||||||
|
# Build Output
|
||||||
|
"build",
|
||||||
|
"dist",
|
||||||
|
"target",
|
||||||
|
"out",
|
||||||
|
# IDEs
|
||||||
|
".idea",
|
||||||
|
".vscode",
|
||||||
|
".sublime",
|
||||||
|
".atom",
|
||||||
|
".brackets",
|
||||||
|
# Temporary / Cache
|
||||||
|
".cache",
|
||||||
|
".temp",
|
||||||
|
".tmp",
|
||||||
|
"*.swp",
|
||||||
|
"*.swo",
|
||||||
|
"*~",
|
||||||
|
# OS-specific
|
||||||
|
".DS_Store",
|
||||||
|
"Thumbs.db",
|
||||||
|
# Java / JVM
|
||||||
|
".gradle",
|
||||||
|
".m2",
|
||||||
|
# Documentation build
|
||||||
|
"_build",
|
||||||
|
"site",
|
||||||
|
# Mobile development
|
||||||
|
".expo",
|
||||||
|
".flutter",
|
||||||
|
# Package managers
|
||||||
|
"vendor",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_dangerous_path(path: Path) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a path is in the dangerous paths list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Path to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the path is dangerous and should not be accessed
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
resolved = path.resolve()
|
||||||
|
return str(resolved) in DANGEROUS_PATHS or resolved.parent == resolved
|
||||||
|
except Exception:
|
||||||
|
return True # If we can't resolve, consider it dangerous
|
||||||
113
utils/storage_backend.py
Normal file
113
utils/storage_backend.py
Normal file
|
|
@ -0,0 +1,113 @@
|
||||||
|
"""
|
||||||
|
In-memory storage backend for conversation threads
|
||||||
|
|
||||||
|
This module provides a thread-safe, in-memory alternative to Redis for storing
|
||||||
|
conversation contexts. It's designed for ephemeral MCP server sessions where
|
||||||
|
conversations only need to persist during a single Claude session.
|
||||||
|
|
||||||
|
⚠️ PROCESS-SPECIFIC STORAGE: This storage is confined to a single Python process.
|
||||||
|
Data stored in one process is NOT accessible from other processes or subprocesses.
|
||||||
|
This is why simulator tests that run server.py as separate subprocesses cannot
|
||||||
|
share conversation state between tool calls.
|
||||||
|
|
||||||
|
Key Features:
|
||||||
|
- Thread-safe operations using locks
|
||||||
|
- TTL support with automatic expiration
|
||||||
|
- Background cleanup thread for memory management
|
||||||
|
- Singleton pattern for consistent state within a single process
|
||||||
|
- Drop-in replacement for Redis storage (for single-process scenarios)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryStorage:
|
||||||
|
"""Thread-safe in-memory storage for conversation threads"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._store: dict[str, tuple[str, float]] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
# Match Redis behavior: cleanup interval based on conversation timeout
|
||||||
|
# Run cleanup at 1/10th of timeout interval (e.g., 18 mins for 3 hour timeout)
|
||||||
|
timeout_hours = int(os.getenv("CONVERSATION_TIMEOUT_HOURS", "3"))
|
||||||
|
self._cleanup_interval = (timeout_hours * 3600) // 10
|
||||||
|
self._cleanup_interval = max(300, self._cleanup_interval) # Minimum 5 minutes
|
||||||
|
self._shutdown = False
|
||||||
|
|
||||||
|
# Start background cleanup thread
|
||||||
|
self._cleanup_thread = threading.Thread(target=self._cleanup_worker, daemon=True)
|
||||||
|
self._cleanup_thread.start()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"In-memory storage initialized with {timeout_hours}h timeout, cleanup every {self._cleanup_interval//60}m"
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_with_ttl(self, key: str, ttl_seconds: int, value: str) -> None:
|
||||||
|
"""Store value with expiration time"""
|
||||||
|
with self._lock:
|
||||||
|
expires_at = time.time() + ttl_seconds
|
||||||
|
self._store[key] = (value, expires_at)
|
||||||
|
logger.debug(f"Stored key {key} with TTL {ttl_seconds}s")
|
||||||
|
|
||||||
|
def get(self, key: str) -> Optional[str]:
|
||||||
|
"""Retrieve value if not expired"""
|
||||||
|
with self._lock:
|
||||||
|
if key in self._store:
|
||||||
|
value, expires_at = self._store[key]
|
||||||
|
if time.time() < expires_at:
|
||||||
|
logger.debug(f"Retrieved key {key}")
|
||||||
|
return value
|
||||||
|
else:
|
||||||
|
# Clean up expired entry
|
||||||
|
del self._store[key]
|
||||||
|
logger.debug(f"Key {key} expired and removed")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def setex(self, key: str, ttl_seconds: int, value: str) -> None:
|
||||||
|
"""Redis-compatible setex method"""
|
||||||
|
self.set_with_ttl(key, ttl_seconds, value)
|
||||||
|
|
||||||
|
def _cleanup_worker(self):
|
||||||
|
"""Background thread that periodically cleans up expired entries"""
|
||||||
|
while not self._shutdown:
|
||||||
|
time.sleep(self._cleanup_interval)
|
||||||
|
self._cleanup_expired()
|
||||||
|
|
||||||
|
def _cleanup_expired(self):
|
||||||
|
"""Remove all expired entries"""
|
||||||
|
with self._lock:
|
||||||
|
current_time = time.time()
|
||||||
|
expired_keys = [k for k, (_, exp) in self._store.items() if exp < current_time]
|
||||||
|
for key in expired_keys:
|
||||||
|
del self._store[key]
|
||||||
|
|
||||||
|
if expired_keys:
|
||||||
|
logger.debug(f"Cleaned up {len(expired_keys)} expired conversation threads")
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
"""Graceful shutdown of background thread"""
|
||||||
|
self._shutdown = True
|
||||||
|
if self._cleanup_thread.is_alive():
|
||||||
|
self._cleanup_thread.join(timeout=1)
|
||||||
|
|
||||||
|
|
||||||
|
# Global singleton instance
|
||||||
|
_storage_instance = None
|
||||||
|
_storage_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_storage_backend() -> InMemoryStorage:
|
||||||
|
"""Get the global storage instance (singleton pattern)"""
|
||||||
|
global _storage_instance
|
||||||
|
if _storage_instance is None:
|
||||||
|
with _storage_lock:
|
||||||
|
if _storage_instance is None:
|
||||||
|
_storage_instance = InMemoryStorage()
|
||||||
|
logger.info("Initialized in-memory conversation storage")
|
||||||
|
return _storage_instance
|
||||||
54
utils/token_utils.py
Normal file
54
utils/token_utils.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
"""
|
||||||
|
Token counting utilities for managing API context limits
|
||||||
|
|
||||||
|
This module provides functions for estimating token counts to ensure
|
||||||
|
requests stay within the Gemini API's context window limits.
|
||||||
|
|
||||||
|
Note: The estimation uses a simple character-to-token ratio which is
|
||||||
|
approximate. For production systems requiring precise token counts,
|
||||||
|
consider using the actual tokenizer for the specific model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default fallback for token limit (conservative estimate)
|
||||||
|
DEFAULT_CONTEXT_WINDOW = 200_000 # Conservative fallback for unknown models
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_tokens(text: str) -> int:
|
||||||
|
"""
|
||||||
|
Estimate token count using a character-based approximation.
|
||||||
|
|
||||||
|
This uses a rough heuristic where 1 token ≈ 4 characters, which is
|
||||||
|
a reasonable approximation for English text. The actual token count
|
||||||
|
may vary based on:
|
||||||
|
- Language (non-English text may have different ratios)
|
||||||
|
- Code vs prose (code often has more tokens per character)
|
||||||
|
- Special characters and formatting
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to estimate tokens for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Estimated number of tokens
|
||||||
|
"""
|
||||||
|
return len(text) // 4
|
||||||
|
|
||||||
|
|
||||||
|
def check_token_limit(text: str, context_window: int = DEFAULT_CONTEXT_WINDOW) -> tuple[bool, int]:
|
||||||
|
"""
|
||||||
|
Check if text exceeds the specified token limit.
|
||||||
|
|
||||||
|
This function is used to validate that prepared prompts will fit
|
||||||
|
within the model's context window, preventing API errors and ensuring
|
||||||
|
reliable operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to check
|
||||||
|
context_window: The model's context window size (defaults to conservative fallback)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, int]: (is_within_limit, estimated_tokens)
|
||||||
|
- is_within_limit: True if the text fits within context_window
|
||||||
|
- estimated_tokens: The estimated token count
|
||||||
|
"""
|
||||||
|
estimated = estimate_tokens(text)
|
||||||
|
return estimated <= context_window, estimated
|
||||||
Loading…
Reference in a new issue