From 371806488d73fc127f6e622a0b5bd63b9e30ce73 Mon Sep 17 00:00:00 2001 From: Ben Date: Fri, 7 Nov 2025 11:35:17 -0400 Subject: [PATCH] Initial commit: Zen-Marketing MCP Server v0.1.0 - Core architecture from zen-mcp-server - OpenRouter and Gemini provider configuration - Content variant generator tool (first marketing tool) - Chat tool for marketing strategy - Version and model listing tools - Configuration system with .env support - Logging infrastructure - Ready for Claude Desktop integration --- .env.example | 40 + .gitignore | 190 +++ CLAUDE.md | 566 +++++++++ LICENSE | 197 +++ PLAN.md | 513 ++++++++ README.md | 163 +++ config.py | 107 ++ providers/__init__.py | 20 + providers/base.py | 268 ++++ providers/custom.py | 196 +++ providers/dial.py | 473 +++++++ providers/gemini.py | 578 +++++++++ providers/openai_compatible.py | 826 ++++++++++++ providers/openai_provider.py | 296 +++++ providers/openrouter.py | 251 ++++ providers/openrouter_registry.py | 292 +++++ providers/registry.py | 397 ++++++ providers/shared/__init__.py | 21 + providers/shared/model_capabilities.py | 122 ++ providers/shared/model_response.py | 26 + providers/shared/provider_type.py | 16 + providers/shared/temperature.py | 188 +++ providers/xai.py | 157 +++ requirements.txt | 11 + run-server.sh | 66 + server.py | 352 ++++++ systemprompts/__init__.py | 6 + systemprompts/chat_prompt.py | 29 + systemprompts/contentvariant_prompt.py | 62 + tools/__init__.py | 39 + tools/chat.py | 189 +++ tools/contentvariant.py | 180 +++ tools/listmodels.py | 299 +++++ tools/models.py | 373 ++++++ tools/shared/__init__.py | 19 + tools/shared/base_models.py | 165 +++ tools/shared/base_tool.py | 1399 ++++++++++++++++++++ tools/shared/schema_builders.py | 159 +++ tools/simple/__init__.py | 18 + tools/simple/base.py | 985 ++++++++++++++ tools/version.py | 368 ++++++ tools/workflow/__init__.py | 22 + tools/workflow/base.py | 444 +++++++ tools/workflow/schema_builders.py | 174 +++ tools/workflow/workflow_mixin.py | 1619 
++++++++++++++++++++++++ utils/__init__.py | 21 + utils/client_info.py | 293 +++++ utils/conversation_memory.py | 1095 ++++++++++++++++ utils/file_types.py | 271 ++++ utils/file_utils.py | 864 +++++++++++++ utils/image_utils.py | 94 ++ utils/model_context.py | 180 +++ utils/model_restrictions.py | 226 ++++ utils/security_config.py | 104 ++ utils/storage_backend.py | 113 ++ utils/token_utils.py | 54 + 56 files changed, 16196 insertions(+) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 LICENSE create mode 100644 PLAN.md create mode 100644 README.md create mode 100644 config.py create mode 100644 providers/__init__.py create mode 100644 providers/base.py create mode 100644 providers/custom.py create mode 100644 providers/dial.py create mode 100644 providers/gemini.py create mode 100644 providers/openai_compatible.py create mode 100644 providers/openai_provider.py create mode 100644 providers/openrouter.py create mode 100644 providers/openrouter_registry.py create mode 100644 providers/registry.py create mode 100644 providers/shared/__init__.py create mode 100644 providers/shared/model_capabilities.py create mode 100644 providers/shared/model_response.py create mode 100644 providers/shared/provider_type.py create mode 100644 providers/shared/temperature.py create mode 100644 providers/xai.py create mode 100644 requirements.txt create mode 100755 run-server.sh create mode 100644 server.py create mode 100644 systemprompts/__init__.py create mode 100644 systemprompts/chat_prompt.py create mode 100644 systemprompts/contentvariant_prompt.py create mode 100644 tools/__init__.py create mode 100644 tools/chat.py create mode 100644 tools/contentvariant.py create mode 100644 tools/listmodels.py create mode 100644 tools/models.py create mode 100644 tools/shared/__init__.py create mode 100644 tools/shared/base_models.py create mode 100644 tools/shared/base_tool.py create mode 100644 tools/shared/schema_builders.py 
create mode 100644 tools/simple/__init__.py create mode 100644 tools/simple/base.py create mode 100644 tools/version.py create mode 100644 tools/workflow/__init__.py create mode 100644 tools/workflow/base.py create mode 100644 tools/workflow/schema_builders.py create mode 100644 tools/workflow/workflow_mixin.py create mode 100644 utils/__init__.py create mode 100644 utils/client_info.py create mode 100644 utils/conversation_memory.py create mode 100644 utils/file_types.py create mode 100644 utils/file_utils.py create mode 100644 utils/image_utils.py create mode 100644 utils/model_context.py create mode 100644 utils/model_restrictions.py create mode 100644 utils/security_config.py create mode 100644 utils/storage_backend.py create mode 100644 utils/token_utils.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..50e43af --- /dev/null +++ b/.env.example @@ -0,0 +1,40 @@ +# Zen-Marketing MCP Server Configuration + +# API Keys +# Required: At least one API key must be configured +OPENROUTER_API_KEY=your-openrouter-api-key-here +GEMINI_API_KEY=your-gemini-api-key-here + +# Model Configuration +# DEFAULT_MODEL: Primary model for analytical and strategic work +# Options: google/gemini-2.5-pro-latest, minimax/minimax-m2, or any OpenRouter model +DEFAULT_MODEL=google/gemini-2.5-pro-latest + +# FAST_MODEL: Model for quick generation (subject lines, variations) +FAST_MODEL=google/gemini-2.5-flash-preview-09-2025 + +# CREATIVE_MODEL: Model for creative content generation +CREATIVE_MODEL=minimax/minimax-m2 + +# Web Search +# Enable web search for fact-checking and current information +ENABLE_WEB_SEARCH=true + +# Tool Configuration +# Comma-separated list of tools to disable +# Available tools: contentvariant, platformadapt, subjectlines, styleguide, +# seooptimize, guestedit, linkstrategy, factcheck, +# voiceanalysis, campaignmap, chat, thinkdeep, planner +DISABLED_TOOLS= + +# Logging +# Options: DEBUG, INFO, WARNING, ERROR +LOG_LEVEL=INFO + +# 
Language/Locale +# Leave empty for English, or specify locale (e.g., fr-FR, es-ES, de-DE) +LOCALE= + +# Thinking Mode for Deep Analysis +# Options: minimal, low, medium, high, max +DEFAULT_THINKING_MODE_THINKDEEP=high diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..be60b01 --- /dev/null +++ b/.gitignore @@ -0,0 +1,190 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +Pipfile.lock + +# poetry +poetry.lock + +# pdm +.pdm.toml +.pdm-python +pdm.lock + +# PEP 582 +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.env~ +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +.idea/ + +# VS Code +.vscode/ + +# macOS +.DS_Store + +# API Keys and 
secrets +*.key +*.pem +.env.local +.env.*.local + +# Test outputs +test_output/ +*.test.log +.coverage +htmlcov/ +coverage.xml +.pytest_cache/ + +# Test simulation artifacts (dynamically created during testing) +test_simulation_files/.claude/ + +# Temporary test directories +test-setup/ + +# Scratch feature documentation files +FEATURE_*.md +# Temporary files +/tmp/ + +# Local user instructions +CLAUDE.local.md + +# Claude Code personal settings +.claude/settings.local.json + +# Standalone mode files +.zen_venv/ +.docker_cleaned +logs/ +*.backup +/.desktop_configured + +/worktrees/ +test_simulation_files/ +.mcp.json diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..c91413c --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,566 @@ +# Claude Development Guide for Zen-Marketing MCP Server + +This file contains essential commands and workflows for developing the Zen-Marketing MCP Server - a specialized marketing-focused fork of Zen MCP Server. + +## Project Context + +**What is Zen-Marketing?** +A Claude Desktop MCP server providing AI-powered marketing tools focused on: +- Content variation generation for A/B testing +- Cross-platform content adaptation +- Writing style enforcement +- SEO optimization for WordPress +- Guest content editing with voice preservation +- Technical fact verification +- Internal linking strategy +- Multi-channel campaign planning + +**Target User:** Solo marketing professionals managing technical B2B content, particularly in industries like HVAC, SaaS, and technical education. + +**Key Difference from Zen Code:** This is for marketing/content work, not software development. Tools generate content variations, enforce writing styles, and optimize for platforms like LinkedIn, newsletters, and WordPress - not code review or debugging. 
+ +## Quick Reference Commands + +### Initial Setup + +```bash +# Navigate to project directory +cd ~/mcp/zen-marketing + +# Copy core files from zen-mcp-server (if starting fresh) +# We'll do this in the new session + +# Create virtual environment +python3 -m venv .venv +source .venv/bin/activate + +# Install dependencies (once requirements.txt is created) +pip install -r requirements.txt + +# Create .env file +cp .env.example .env +# Edit .env with your API keys +``` + +### Development Workflow + +```bash +# Activate environment +source .venv/bin/activate + +# Run code quality checks (once implemented) +./code_quality_checks.sh + +# Run server locally for testing +python server.py + +# View logs +tail -f logs/mcp_server.log + +# Run tests +python -m pytest tests/ -v +``` + +### Claude Desktop Configuration + +Add to `~/.claude.json`: + +```json +{ + "mcpServers": { + "zen-marketing": { + "command": "/home/ben/mcp/zen-marketing/.venv/bin/python", + "args": ["/home/ben/mcp/zen-marketing/server.py"], + "env": { + "OPENROUTER_API_KEY": "your-openrouter-key", + "GEMINI_API_KEY": "your-gemini-key", + "DEFAULT_MODEL": "gemini-2.5-pro", + "FAST_MODEL": "gemini-flash", + "CREATIVE_MODEL": "minimax-m2", + "ENABLE_WEB_SEARCH": "true", + "DISABLED_TOOLS": "", + "LOG_LEVEL": "INFO" + } + } + } +} +``` + +**After modifying config:** Restart Claude Desktop for changes to take effect. 
+ +## Tool Development Guidelines + +### Tool Categories + +**Simple Tools** (single-shot, fast response): +- Inherit from `SimpleTool` base class +- Focus on speed and iteration +- Examples: `contentvariant`, `platformadapt`, `subjectlines`, `factcheck` +- Use fast models (gemini-flash) when possible + +**Workflow Tools** (multi-step processes): +- Inherit from `WorkflowTool` base class +- Systematic step-by-step workflows +- Track progress, confidence, findings +- Examples: `styleguide`, `seooptimize`, `guestedit`, `linkstrategy` + +### Temperature Guidelines for Marketing Tools + +- **High (0.7-0.8)**: Content variation, creative adaptation +- **Medium (0.5-0.6)**: Balanced tasks, campaign planning +- **Low (0.3-0.4)**: Analytical work, SEO optimization +- **Very Low (0.2)**: Fact-checking, technical verification + +### Model Selection Strategy + +**Gemini 2.5 Pro** (`gemini-2.5-pro`): +- Analytical and strategic work +- SEO optimization +- Guest editing +- Internal linking analysis +- Voice analysis +- Campaign planning +- Fact-checking + +**Gemini Flash** (`gemini-flash`): +- Fast bulk generation +- Subject line creation +- Quick variations +- Cost-effective iterations + +**Minimax M2** (`minimax-m2`): +- Creative content generation +- Platform adaptation +- Content repurposing +- Marketing copy variations + +### System Prompt Best Practices + +Marketing tool prompts should: +1. **Specify output format clearly** (JSON, markdown, numbered list) +2. **Include platform constraints** (character limits, formatting rules) +3. **Emphasize preservation** (voice, expertise, technical accuracy) +4. **Request rationale** (why certain variations work, what to test) +5. **Avoid code terminology** (use "content" not "implementation") + +Example prompt structure: +```python +CONTENTVARIANT_PROMPT = """ +You are a marketing content strategist specializing in A/B testing and variation generation. 
+ +TASK: Generate multiple variations of marketing content for testing different approaches. + +OUTPUT FORMAT: +Return variations as numbered list, each with: +1. The variation text +2. The testing angle (what makes it different) +3. Predicted audience response + +CONSTRAINTS: +- Maintain core message across variations +- Respect platform character limits if specified +- Preserve brand voice characteristics +- Generate genuinely different approaches, not just word swaps + +VARIATION TYPES: +- Hook variations: Different opening angles +- Length variations: Short, medium, long +- Tone variations: Professional, conversational, urgent +- Structure variations: Question, statement, story +- CTA variations: Different calls-to-action +""" +``` + +## Implementation Phases + +### Phase 1: Foundation ✓ (You Are Here) +- [x] Create project directory +- [x] Write implementation plan (PLAN.md) +- [x] Create development guide (CLAUDE.md) +- [ ] Copy core architecture from zen-mcp-server +- [ ] Configure minimax provider +- [ ] Remove code-specific tools +- [ ] Test basic chat functionality + +### Phase 2: Simple Tools (Priority: High) +Implementation order based on real-world usage: +1. **`contentvariant`** - Most frequently used (subject lines, social posts) +2. **`subjectlines`** - Specific workflow mentioned in project memories +3. **`platformadapt`** - Multi-channel content distribution +4. **`factcheck`** - Technical accuracy verification + +### Phase 3: Workflow Tools (Priority: Medium) +5. **`styleguide`** - Writing rule enforcement (no em-dashes, etc.) +6. **`seooptimize`** - WordPress SEO optimization +7. **`guestedit`** - Guest content editing workflow +8. **`linkstrategy`** - Internal linking and cross-platform integration + +### Phase 4: Advanced Features (Priority: Lower) +9. **`voiceanalysis`** - Voice extraction and consistency checking +10. 
**`campaignmap`** - Multi-touch campaign planning + +## Tool Implementation Checklist + +For each new tool: + +**Code Files:** +- [ ] Create tool file in `tools/` (e.g., `tools/contentvariant.py`) +- [ ] Create system prompt in `systemprompts/` (e.g., `systemprompts/contentvariant_prompt.py`) +- [ ] Create test file in `tests/` (e.g., `tests/test_contentvariant.py`) +- [ ] Register tool in `server.py` + +**Tool Class Requirements:** +- [ ] Inherit from `SimpleTool` or `WorkflowTool` +- [ ] Implement `get_name()` - tool name +- [ ] Implement `get_description()` - what it does +- [ ] Implement `get_system_prompt()` - behavior instructions +- [ ] Implement `get_default_temperature()` - creativity level +- [ ] Implement `get_model_category()` - FAST_RESPONSE or DEEP_THINKING +- [ ] Implement `get_request_model()` - Pydantic request schema +- [ ] Implement `get_input_schema()` - MCP tool schema +- [ ] Implement request/response formatting hooks + +**Testing:** +- [ ] Unit tests for request validation +- [ ] Unit tests for response formatting +- [ ] Integration test with real model (optional) +- [ ] Add to quality checks script + +**Documentation:** +- [ ] Add tool to README.md +- [ ] Create examples in docs/tools/ +- [ ] Update PLAN.md progress + +## Common Development Tasks + +### Adding a New Simple Tool + +```python +# tools/mynewtool.py +from typing import Optional +from pydantic import Field +from tools.shared.base_models import ToolRequest +from .simple.base import SimpleTool +from systemprompts import MYNEWTOOL_PROMPT +from config import TEMPERATURE_BALANCED + +class MyNewToolRequest(ToolRequest): + """Request model for MyNewTool""" + prompt: str = Field(..., description="What you want to accomplish") + files: Optional[list[str]] = Field(default_factory=list) + +class MyNewTool(SimpleTool): + def get_name(self) -> str: + return "mynewtool" + + def get_description(self) -> str: + return "Brief description of what this tool does" + + def get_system_prompt(self) -> 
str: + return MYNEWTOOL_PROMPT + + def get_default_temperature(self) -> float: + return TEMPERATURE_BALANCED + + def get_model_category(self) -> "ToolModelCategory": + from tools.models import ToolModelCategory + return ToolModelCategory.FAST_RESPONSE + + def get_request_model(self): + return MyNewToolRequest +``` + +### Adding a New Workflow Tool + +```python +# tools/mynewworkflow.py +from typing import Optional +from pydantic import Field +from tools.shared.base_models import WorkflowRequest +from .workflow.base import WorkflowTool +from systemprompts import MYNEWWORKFLOW_PROMPT + +class MyNewWorkflowRequest(WorkflowRequest): + """Request model for workflow tool""" + step: str = Field(description="Current step content") + step_number: int = Field(ge=1) + total_steps: int = Field(ge=1) + next_step_required: bool + findings: str = Field(description="What was discovered") + # Add workflow-specific fields + +class MyNewWorkflow(WorkflowTool): + # Implementation similar to Simple Tool + # but with workflow-specific logic +``` + +### Testing a Tool Manually + +```bash +# Start server with debug logging +LOG_LEVEL=DEBUG python server.py + +# In another terminal, watch logs +tail -f logs/mcp_server.log | grep -E "(TOOL_CALL|ERROR|MyNewTool)" + +# In Claude Desktop, test: +# "Use zen-marketing to generate 10 subject lines about HVAC maintenance" +``` + +## Project Structure + +``` +zen-marketing/ +├── server.py # Main MCP server entry point +├── config.py # Configuration constants +├── PLAN.md # Implementation plan (this doc) +├── CLAUDE.md # Development guide +├── README.md # User-facing documentation +├── requirements.txt # Python dependencies +├── .env.example # Environment variable template +├── .env # Local config (gitignored) +├── run-server.sh # Setup and run script +├── code_quality_checks.sh # Linting and testing +│ +├── tools/ # Tool implementations +│ ├── __init__.py +│ ├── contentvariant.py # Bulk variation generator +│ ├── platformadapt.py # Cross-platform 
adapter +│ ├── subjectlines.py # Email subject line generator +│ ├── styleguide.py # Writing style enforcer +│ ├── seooptimize.py # SEO optimizer +│ ├── guestedit.py # Guest content editor +│ ├── linkstrategy.py # Internal linking strategist +│ ├── factcheck.py # Technical fact checker +│ ├── voiceanalysis.py # Voice extractor/validator +│ ├── campaignmap.py # Campaign planner +│ ├── chat.py # General chat (from zen) +│ ├── thinkdeep.py # Deep thinking (from zen) +│ ├── planner.py # Planning (from zen) +│ ├── models.py # Shared models +│ ├── simple/ # Simple tool base classes +│ │ └── base.py +│ ├── workflow/ # Workflow tool base classes +│ │ └── base.py +│ └── shared/ # Shared utilities +│ └── base_models.py +│ +├── providers/ # AI provider implementations +│ ├── __init__.py +│ ├── base.py # Base provider interface +│ ├── gemini.py # Google Gemini +│ ├── minimax.py # Minimax (NEW) +│ ├── openrouter.py # OpenRouter fallback +│ ├── registry.py # Provider registry +│ └── shared/ +│ +├── systemprompts/ # System prompts for tools +│ ├── __init__.py +│ ├── contentvariant_prompt.py +│ ├── platformadapt_prompt.py +│ ├── subjectlines_prompt.py +│ ├── styleguide_prompt.py +│ ├── seooptimize_prompt.py +│ ├── guestedit_prompt.py +│ ├── linkstrategy_prompt.py +│ ├── factcheck_prompt.py +│ ├── voiceanalysis_prompt.py +│ ├── campaignmap_prompt.py +│ ├── chat_prompt.py # From zen +│ ├── thinkdeep_prompt.py # From zen +│ └── planner_prompt.py # From zen +│ +├── utils/ # Utility functions +│ ├── conversation_memory.py # Conversation continuity +│ ├── file_utils.py # File handling +│ └── web_search.py # Web search integration +│ +├── tests/ # Test suite +│ ├── __init__.py +│ ├── test_contentvariant.py +│ ├── test_platformadapt.py +│ ├── test_subjectlines.py +│ └── ... +│ +├── logs/ # Log files (gitignored) +│ ├── mcp_server.log +│ └── mcp_activity.log +│ +└── docs/ # Documentation + ├── getting-started.md + ├── tools/ + │ ├── contentvariant.md + │ ├── platformadapt.md + │ └── ... 
+ └── examples/ + └── marketing-workflows.md +``` + +## Key Concepts from Zen Architecture + +### Conversation Continuity +Every tool supports `continuation_id` to maintain context across interactions: + +```python +# First call +result1 = await tool.execute({ + "prompt": "Analyze this brand voice", + "files": ["brand_samples/post1.txt", "brand_samples/post2.txt"] +}) +# Returns: continuation_id: "abc123" + +# Follow-up call (remembers previous context) +result2 = await tool.execute({ + "prompt": "Now check if this new draft matches the voice", + "files": ["new_draft.txt"], + "continuation_id": "abc123" # Preserves context +}) +``` + +### File Handling +Tools automatically: +- Expand directories to individual files +- Deduplicate file lists +- Handle absolute paths +- Process images (screenshots, brand assets) + +### Web Search Integration +Tools can request Claude to perform web searches: +```python +# In system prompt: +"If you need current information about [topic], request a web search from Claude." + +# Claude will then use WebSearch tool and provide results +``` + +### Multi-Model Orchestration +Tools specify model category, server selects best available: +- `FAST_RESPONSE` → gemini-flash or equivalent +- `DEEP_THINKING` → gemini-2.5-pro or equivalent +- User can override with `model` parameter + +## Debugging Common Issues + +### Tool Not Appearing in Claude Desktop +1. Check `server.py` registers the tool +2. Verify tool is not in `DISABLED_TOOLS` env var +3. Restart Claude Desktop after config changes +4. Check logs: `tail -f logs/mcp_server.log` + +### Model Selection Issues +1. Verify API keys in `.env` +2. Check provider registration in `providers/registry.py` +3. Test with explicit model name: `"model": "gemini-2.5-pro"` +4. Check logs for provider errors + +### Response Formatting Issues +1. Validate system prompt specifies output format +2. Check response doesn't exceed token limits +3. Test with simpler input first +4. 
Review logs for truncation warnings + +### Conversation Continuity Not Working +1. Verify `continuation_id` is being passed correctly +2. Check conversation hasn't expired (default 6 hours) +3. Validate conversation memory storage +4. Review logs: `grep "continuation_id" logs/mcp_server.log` + +## Code Quality Standards + +Before committing: + +```bash +# Run all quality checks +./code_quality_checks.sh + +# Manual checks: +ruff check . --fix # Linting +black . # Formatting +isort . # Import sorting +pytest tests/ -v # Run tests +``` + +## Marketing-Specific Considerations + +### Character Limits by Platform +Tools should be aware of: +- **Twitter/Bluesky**: 280 characters +- **LinkedIn**: 3000 chars (1300 optimal) +- **Instagram**: 2200 characters +- **Facebook**: No hard limit (500 chars optimal) +- **Email subject**: 60 characters optimal +- **Email preview**: 90-100 characters +- **Meta description**: 156 characters +- **Page title**: 60 characters + +### Writing Style Rules from Project Memories +- No em-dashes (use periods or semicolons) +- No "This isn't X, it's Y" constructions +- Direct affirmative statements over negations +- Semantic variety in paragraph openings +- Concrete metrics over abstract claims +- Technical accuracy preserved +- Author voice maintained + +### Testing Angles for Variations +- Technical curiosity +- Contrarian/provocative +- Knowledge gap emphasis +- Urgency/timeliness +- Insider knowledge positioning +- Problem-solution framing +- Before-after transformation +- Social proof/credibility +- FOMO (fear of missing out) +- Educational value + +## Next Session Goals + +When you start the new session in `~/mcp/zen-marketing/`: + +1. **Copy Core Files from Zen** + - Copy base architecture preserving git history + - Remove code-specific tools + - Update imports and references + +2. **Configure Minimax Provider** + - Add minimax support to providers/ + - Register in provider registry + - Test basic model calls + +3. 
**Implement First Simple Tool** + - Start with `contentvariant` (highest priority) + - Create tool, system prompt, and tests + - Test end-to-end with Claude Desktop + +4. **Validate Architecture** + - Ensure conversation continuity works + - Verify file handling + - Test web search integration + +## Questions to Consider + +Before implementing each tool: +1. What real-world workflow does this solve? (Reference project memories) +2. What's the minimum viable version? +3. What can go wrong? (Character limits, API errors, invalid input) +4. How will users test variations? (Output format) +5. Does it need web search? (Current info, fact-checking) +6. What's the right temperature? (Creative vs analytical) +7. Simple or workflow tool? (Single-shot vs multi-step) + +## Resources + +- **Zen MCP Server Repo**: Source for architecture and patterns +- **MCP Protocol Docs**: https://modelcontextprotocol.io +- **Claude Desktop Config**: `~/.claude.json` +- **Project Memories**: See PLAN.md for user workflow examples +- **Platform Best Practices**: Research current 2025 guidelines + +--- + +**Ready to build?** Start the new session with: +```bash +cd ~/mcp/zen-marketing +# Then ask Claude to begin Phase 1 implementation +``` diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..2d18748 --- /dev/null +++ b/LICENSE @@ -0,0 +1,197 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship covered by this License, + whether in source or binary form, which is made available under the + License, as indicated by a copyright notice that is included in or + attached to the work. (The copyright notice requirement does not + apply to derivative works of the License holder.) + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based upon (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and derivative works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control + systems, and issue tracking systems that are managed by, or on behalf + of, the Licensor for the purpose of discussing and improving the Work, + but excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to use, reproduce, modify, distribute, and otherwise + transfer the Work as part of a Derivative Work. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright notice to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Support. You may choose to offer, and to + charge a fee for, warranty, support, indemnity or other liability + obligations and/or rights consistent with this License. However, + in accepting such obligations, You may act only on Your own behalf + and on Your sole responsibility, not on behalf of any other + Contributor, and only if You agree to indemnify, defend, and hold + each Contributor harmless for any liability incurred by, or claims + asserted against, such Contributor by reason of your accepting any + such warranty or support. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in comments for the + particular file format. An identification line is also useful. + + Copyright 2025 Beehive Innovations + https://github.com/BeehiveInnovations + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..07acabf --- /dev/null +++ b/PLAN.md @@ -0,0 +1,513 @@ +# Zen-Marketing MCP Server - Implementation Plan + +## Project Overview + +A specialized MCP server for Claude Desktop focused on marketing workflows, derived from the Zen MCP Server codebase. Designed for solo marketing professionals working in technical B2B content, particularly HVAC/technical education and software marketing. + +**Target User Profile:** Marketing professional managing: +- Technical newsletters with guest contributors +- Multi-platform social media campaigns +- Educational content with SEO optimization +- A/B testing and content variation generation +- WordPress content management with internal linking +- Brand voice consistency across channels + +## Core Principles + +1. **Variation at Scale**: Generate 5-20 variations of content for A/B testing +2. **Platform Intelligence**: Understand character limits, tone, and best practices per platform +3. **Technical Accuracy**: Verify facts via web search before publishing +4. **Voice Preservation**: Maintain authentic voices (guest authors, brand persona) +5. **Cross-Platform Integration**: Connect content across blog, social, email, video +6. **Style Enforcement**: Apply specific writing guidelines automatically + +## Planned Tools + +### 1. 
`contentvariant` (Simple Tool) +**Purpose:** Rapid generation of multiple content variations for testing + +**Real-world use case from project memory:** +- "Generate 15 email subject line variations testing different angles: technical curiosity, contrarian statements, FOMO, educational value" +- "Create 10 LinkedIn post variations of this announcement, mixing lengths and hooks" + +**Key features:** +- Generate 5-20 variations in one call +- Specify variation dimensions (tone, length, hook, CTA placement) +- Fast model (gemini-flash) for speed +- Output formatted for easy copy-paste testing + +**Parameters:** +- `content` (str): Base content to vary +- `variation_count` (int): Number of variations (default 10) +- `variation_types` (list): Types to vary - "tone", "length", "hook", "structure", "cta" +- `platform` (str): Target platform for context +- `constraints` (str): Character limits, style requirements + +**Model:** gemini-flash (fast, cost-effective for bulk generation) +**Temperature:** 0.8 (creative variation while maintaining message) + +--- + +### 2. `platformadapt` (Simple Tool) +**Purpose:** Adapt single content piece across multiple social platforms + +**Real-world use case from project memory:** +- "Take this blog post intro and create versions for LinkedIn (1300 chars), Twitter (280 chars), Instagram caption (2200 chars), Facebook (500 chars), and Bluesky (300 chars)" + +**Key features:** +- Single source → multiple platform versions +- Respects character limits automatically +- Platform-specific best practices (hashtags, tone, formatting) +- Preserves core message across adaptations + +**Parameters:** +- `source_content` (str): Original content +- `source_platform` (str): Where content originated +- `target_platforms` (list): ["linkedin", "twitter", "instagram", "facebook", "bluesky"] +- `preserve_urls` (bool): Keep links intact vs. 
adapt +- `brand_voice` (str): Voice guidelines to maintain + +**Model:** minimax/minimax-m2 (creative adaptations) +**Temperature:** 0.7 + +--- + +### 3. `subjectlines` (Simple Tool) +**Purpose:** Specialized tool for email subject line generation with angle testing + +**Real-world use case from project memory:** +- "Generate subject lines testing: technical curiosity hook, contrarian/provocative angle, knowledge gap emphasis, urgency/timeliness, insider knowledge positioning" + +**Key features:** +- Generates 15-25 subject lines by default +- Groups by psychological angle/hook +- Includes A/B testing rationale for each +- Character count validation (under 60 chars optimal) +- Emoji suggestions (optional) + +**Parameters:** +- `email_topic` (str): Main topic/content +- `target_audience` (str): Who receives this +- `angles_to_test` (list): Psychological hooks to explore +- `include_emoji` (bool): Add emoji options +- `preview_text` (str): Optional - generate matching preview text + +**Model:** gemini-flash (fast generation, creative but focused) +**Temperature:** 0.8 + +--- + +### 4. `styleguide` (Workflow Tool) +**Purpose:** Enforce specific writing guidelines and polish content + +**Real-world use case from project memory:** +- "Check this content for: em-dashes (remove), 'This isn't X, it's Y' constructions (rewrite), semantic variety in paragraph structure, direct affirmative statements instead of negations" + +**Key features:** +- Multi-step workflow: detection → flagging → rewriting → validation +- Tracks style violations with severity +- Provides before/after comparisons +- Custom rule definitions + +**Workflow steps:** +1. Analyze content for style violations +2. Flag issues with line numbers and severity +3. Generate corrected version +4. Validate improvements + +**Parameters:** +- `content` (str): Content to check +- `rules` (list): Style rules to enforce +- `rewrite_violations` (bool): Auto-fix vs. 
flag only +- `preserve_technical_terms` (bool): Don't change HVAC/technical terminology + +**Common rules (from project memory):** +- No em-dashes (use periods or semicolons) +- No "This isn't X, it's Y" constructions +- Direct affirmative statements over negations +- Semantic variety in paragraph openings +- No clichéd structures +- Concrete metrics over abstract claims + +**Model:** gemini-2.5-pro (analytical precision) +**Temperature:** 0.3 (consistency enforcement) + +--- + +### 5. `seooptimize` (Workflow Tool) +**Purpose:** Comprehensive SEO optimization for WordPress content + +**Real-world use case from project memory:** +- "Create SEO title under 60 chars, excerpt under 156 chars, suggest 5-7 WordPress tags, internal linking opportunities to foundational content" + +**Workflow steps:** +1. Analyze content for primary keywords and topics +2. Generate SEO title (under 60 chars) +3. Create meta description/excerpt (under 156 chars) +4. Suggest WordPress tags (5-10) +5. Identify internal linking opportunities +6. Validate character limits and keyword density + +**Parameters:** +- `content` (str): Full article content +- `existing_tags` (list): Current WordPress tag library +- `target_keywords` (list): Keywords to optimize for +- `internal_links_context` (str): Description of site structure for linking +- `platform` (str): "wordpress" (default), "ghost", "medium" + +**Model:** gemini-2.5-pro (analytical, search-aware) +**Temperature:** 0.4 +**Web search:** Enabled for keyword research + +--- + +### 6. `guestedit` (Workflow Tool) +**Purpose:** Edit guest content while preserving author expertise and voice + +**Real-world use case from project memory:** +- "Edit this guest article on PCB components: preserve technical authority, add key takeaways section, suggest internal links to foundational content, enhance educational flow, maintain expert's voice" + +**Workflow steps:** +1. Analyze guest author's voice and technical expertise level +2. 
Identify areas needing clarification vs. areas to preserve +3. Generate educational enhancements (key takeaways, callouts) +4. Suggest internal linking without fabricating URLs +5. Validate technical accuracy via web search +6. Present edits with rationale + +**Parameters:** +- `guest_content` (str): Original article +- `author_name` (str): Guest author for voice reference +- `expertise_level` (str): "expert", "practitioner", "educator" +- `enhancements_needed` (list): ["key_takeaways", "internal_links", "clarity", "seo"] +- `site_context` (str): Description of broader content ecosystem +- `verify_technical` (bool): Use web search for fact-checking + +**Model:** gemini-2.5-pro (analytical + creative balance) +**Temperature:** 0.5 +**Web search:** Enabled for technical verification + +--- + +### 7. `linkstrategy` (Workflow Tool) +**Purpose:** Internal linking and cross-platform content integration strategy + +**Real-world use case from project memory:** +- "Find opportunities to link this advanced PCB article to foundational HVAC knowledge. Connect blog with related podcast episode, Instagram post, and YouTube video from our content database." + +**Workflow steps:** +1. Analyze content topics and technical depth +2. Identify prerequisite concepts needing internal links +3. Search for related content across platforms (blog, podcast, video, social) +4. Generate linking strategy with anchor text suggestions +5. Validate URLs against actual content (no fabrication) +6. 
Create cross-platform promotion plan + +**Parameters:** +- `primary_content` (str): Main piece to link from/to +- `content_type` (str): "blog", "newsletter", "social_post" +- `available_platforms` (list): ["blog", "podcast", "youtube", "instagram"] +- `technical_depth` (str): "foundational", "intermediate", "advanced" +- `verify_links` (bool): Validate URLs exist before suggesting + +**Model:** gemini-2.5-pro (analytical, relationship mapping) +**Temperature:** 0.4 +**Web search:** Enabled for content discovery + +--- + +### 8. `factcheck` (Simple Tool) +**Purpose:** Quick technical fact verification via web search + +**Real-world use case from project memory:** +- "Verify these technical claims about voltage regulation chains in HVAC PCBs" +- "Check if this White-Rodgers universal control model number and compatibility claims are accurate" + +**Key features:** +- Web search-powered verification +- Cites sources for each claim +- Flags unsupported or questionable statements +- Quick turnaround for pre-publish checks + +**Parameters:** +- `content` (str): Content with claims to verify +- `technical_domain` (str): "hvac", "software", "general" +- `claim_type` (str): "product_specs", "technical_process", "statistics", "general" +- `confidence_threshold` (str): "high_confidence_only", "balanced", "comprehensive" + +**Model:** gemini-2.5-pro (search-augmented, analytical) +**Temperature:** 0.2 (precision for facts) +**Web search:** Required (core functionality) + +--- + +### 9. `voiceanalysis` (Workflow Tool) +**Purpose:** Extract and codify brand/author voice for consistency + +**Real-world use case from project memory:** +- "Analyze these 10 pieces of Gary McCreadie's content to extract voice patterns, then check if this new draft matches his authentic teaching voice" + +**Workflow steps:** +1. Analyze sample content for voice characteristics +2. Extract patterns: sentence structure, vocabulary, tone markers +3. Create voice profile with examples +4. 
Compare new content against profile +5. Flag inconsistencies with suggested rewrites +6. Track confidence in voice matching + +**Parameters:** +- `sample_content` (list): 5-15 pieces showing authentic voice +- `author_name` (str): Voice owner for reference +- `new_content` (str): Content to validate against voice +- `voice_aspects` (list): ["tone", "vocabulary", "structure", "educational_style"] +- `generate_profile` (bool): Create reusable voice profile + +**Model:** gemini-2.5-pro (pattern analysis) +**Temperature:** 0.3 + +--- + +### 10. `campaignmap` (Workflow Tool) +**Purpose:** Map content campaign across multiple touchpoints + +**Real-world use case from project memory:** +- "Plan content campaign for measureQuick National Championship: social teaser posts, email sequence, on-site promotion, post-event follow-up, blog recap with internal links" + +**Workflow steps:** +1. Define campaign goals and timeline +2. Map content across platforms and stages (awareness → consideration → action → retention) +3. Create content calendar with dependencies +4. Generate messaging for each touchpoint +5. Plan cross-promotion and internal linking +6. 
Set success metrics per channel + +**Parameters:** +- `campaign_goal` (str): Primary objective +- `timeline_days` (int): Campaign duration +- `platforms` (list): Channels to use +- `content_pillars` (list): Key messages to reinforce +- `target_audience` (str): Audience description +- `existing_assets` (list): Content to repurpose + +**Model:** gemini-2.5-pro (strategic planning) +**Temperature:** 0.6 + +--- + +## Tool Architecture + +### Simple Tools (4) +Fast, single-shot interactions for rapid iteration: +- `contentvariant` - Bulk variation generation +- `platformadapt` - Cross-platform adaptation +- `subjectlines` - Email subject line testing +- `factcheck` - Quick verification + +### Workflow Tools (6) +Multi-step systematic processes: +- `styleguide` - Style enforcement workflow +- `seooptimize` - Comprehensive SEO process +- `guestedit` - Guest content editing workflow +- `linkstrategy` - Internal linking analysis +- `voiceanalysis` - Voice extraction and validation +- `campaignmap` - Multi-touch campaign planning + +### Keep from Original Zen +- `chat` - General brainstorming +- `thinkdeep` - Deep strategic thinking +- `planner` - Project planning + +## Model Assignment Strategy + +**Primary Models:** +- **minimax/minimax-m2**: Creative content generation (contentvariant, platformadapt) +- **google/gemini-2.5-pro**: Analytical and strategic work (seooptimize, guestedit, linkstrategy, voiceanalysis, campaignmap, styleguide, factcheck) +- **google/gemini-flash**: Fast bulk generation (subjectlines) + +**Fallback Chain:** +1. Configured provider (minimax or gemini) +2. OpenRouter (if API key present) +3. 
Error (no fallback to Anthropic to avoid confusion) + +## Configuration Defaults + +```json +{ + "mcpServers": { + "zen-marketing": { + "command": "/home/ben/mcp/zen-marketing/.venv/bin/python", + "args": ["/home/ben/mcp/zen-marketing/server.py"], + "env": { + "OPENROUTER_API_KEY": "your-key", + "DEFAULT_MODEL": "gemini-2.5-pro", + "FAST_MODEL": "gemini-flash", + "CREATIVE_MODEL": "minimax-m2", + "DISABLED_TOOLS": "analyze,refactor,testgen,secaudit,docgen,tracer,precommit,codereview,debug,consensus,challenge", + "ENABLE_WEB_SEARCH": "true", + "LOG_LEVEL": "INFO" + } + } + } +} +``` + +## Tool Temperature Defaults + +| Tool | Temperature | Rationale | +|------|-------------|-----------| +| contentvariant | 0.8 | High creativity for diverse variations | +| platformadapt | 0.7 | Creative while maintaining message | +| subjectlines | 0.8 | Explore diverse psychological angles | +| styleguide | 0.3 | Precision for consistency | +| seooptimize | 0.4 | Balanced analytical + creative | +| guestedit | 0.5 | Balance preservation + enhancement | +| linkstrategy | 0.4 | Analytical relationship mapping | +| factcheck | 0.2 | Precision for accuracy | +| voiceanalysis | 0.3 | Pattern recognition precision | +| campaignmap | 0.6 | Strategic creativity | + +## Differences from Zen Code + +### What's Different +1. **Content-first**: Tools designed for writing, not coding +2. **Variation generation**: Built-in support for A/B testing at scale +3. **Platform awareness**: Character limits, formatting, best practices per channel +4. **Voice preservation**: Tools explicitly maintain author authenticity +5. **SEO native**: WordPress-specific optimization built in +6. **Verification emphasis**: Web search for fact-checking integrated +7. **No code tools**: Remove debug, codereview, precommit, refactor, testgen, secaudit, tracer + +### What's Kept +1. **Conversation continuity**: `continuation_id` across tools +2. **Multi-model orchestration**: Different models for different tasks +3. 
**Workflow architecture**: Multi-step systematic processes +4. **File handling**: Attach content files, brand guidelines +5. **Image support**: Analyze screenshots, brand assets +6. **Thinking modes**: Control depth vs. cost +7. **Web search**: Current information access + +### What's Enhanced +1. **Bulk operations**: Generate 5-20 variations in one call +2. **Style enforcement**: Custom writing rules +3. **Cross-platform thinking**: Native multi-channel workflows +4. **Technical verification**: Domain-specific fact-checking + +## Implementation Phases + +### Phase 1: Foundation (Week 1) +- Fork Zen MCP Server codebase +- Remove code-specific tools +- Configure minimax and gemini providers +- Update system prompts for marketing context +- Basic testing with chat tool + +### Phase 2: Simple Tools (Week 2) +- Implement `contentvariant` +- Implement `platformadapt` +- Implement `subjectlines` +- Implement `factcheck` +- Test variation generation workflows + +### Phase 3: Workflow Tools (Week 3-4) +- Implement `styleguide` +- Implement `seooptimize` +- Implement `guestedit` +- Implement `linkstrategy` +- Test multi-step workflows + +### Phase 4: Advanced Features (Week 5) +- Implement `voiceanalysis` +- Implement `campaignmap` +- Add voice profile storage +- Test complex campaign workflows + +### Phase 5: Polish & Documentation (Week 6) +- User documentation +- Example workflows +- Configuration guide +- Testing with real project memories + +## Success Metrics + +1. **Variation Quality**: Can generate 10+ usable subject lines in one call +2. **Platform Accuracy**: Respects character limits and formatting rules +3. **Voice Consistency**: Successfully preserves guest author expertise +4. **SEO Effectiveness**: Generated titles/descriptions pass WordPress SEO checks +5. **Cross-platform Integration**: Correctly maps content across blog/social/email +6. **Fact Accuracy**: Catches technical errors before publication +7. 
**Speed**: Simple tools respond in <10 seconds, workflow tools in <30 seconds + +## File Structure + +``` +zen-marketing/ +├── server.py # Main MCP server (adapted from zen) +├── config.py # Marketing-specific configuration +├── PLAN.md # This file +├── CLAUDE.md # Development guide +├── README.md # User documentation +├── requirements.txt # Python dependencies +├── .env.example # Environment template +├── run-server.sh # Setup and run script +├── tools/ +│ ├── contentvariant.py # Bulk variation generation +│ ├── platformadapt.py # Cross-platform adaptation +│ ├── subjectlines.py # Email subject lines +│ ├── styleguide.py # Writing style enforcement +│ ├── seooptimize.py # WordPress SEO workflow +│ ├── guestedit.py # Guest content editing +│ ├── linkstrategy.py # Internal linking strategy +│ ├── factcheck.py # Technical verification +│ ├── voiceanalysis.py # Voice extraction/validation +│ ├── campaignmap.py # Campaign planning +│ └── shared/ # Shared utilities from zen +├── providers/ +│ ├── gemini.py # Google Gemini provider +│ ├── minimax.py # Minimax provider (NEW) +│ ├── openrouter.py # OpenRouter fallback +│ └── registry.py # Provider registration +├── systemprompts/ +│ ├── contentvariant_prompt.py +│ ├── platformadapt_prompt.py +│ ├── subjectlines_prompt.py +│ ├── styleguide_prompt.py +│ ├── seooptimize_prompt.py +│ ├── guestedit_prompt.py +│ ├── linkstrategy_prompt.py +│ ├── factcheck_prompt.py +│ ├── voiceanalysis_prompt.py +│ └── campaignmap_prompt.py +└── tests/ + ├── test_contentvariant.py + ├── test_platformadapt.py + └── ... +``` + +## Development Priorities + +### High Priority (Essential) +1. `contentvariant` - Most used feature from project memories +2. `subjectlines` - Specific workflow mentioned multiple times +3. `platformadapt` - Core multi-channel need +4. `styleguide` - Explicit writing rules enforcement +5. `factcheck` - Technical accuracy verification + +### Medium Priority (Important) +6. `seooptimize` - WordPress-specific SEO needs +7. 
`guestedit` - Guest content workflow +8. `linkstrategy` - Internal linking mentioned frequently + +### Lower Priority (Nice to Have) +9. `voiceanalysis` - Advanced voice preservation +10. `campaignmap` - Strategic planning (can use planner tool initially) + +## Next Steps + +1. Create CLAUDE.md with development guide for zen-marketing +2. Start new session in ~/mcp/zen-marketing/ directory +3. Begin Phase 1 implementation: + - Copy core architecture from zen-mcp-server + - Configure minimax provider + - Remove code-specific tools + - Test with basic chat functionality +4. Implement Phase 2 simple tools starting with `contentvariant` diff --git a/README.md b/README.md new file mode 100644 index 0000000..ffb948d --- /dev/null +++ b/README.md @@ -0,0 +1,163 @@ +# Zen-Marketing MCP Server + +> **Status:** 🚧 In Development - Phase 1 + +AI-powered marketing tools for Claude Desktop, specialized for technical B2B content creation, multi-platform campaigns, and content variation testing. + +## What is This? + +A fork of the [Zen MCP Server](https://github.com/BeehiveInnovations/zen-mcp-server) optimized for marketing workflows instead of software development. Provides Claude Desktop with specialized tools for: + +- **Content Variation** - Generate 5-20 variations for A/B testing +- **Platform Adaptation** - Adapt content across LinkedIn, Twitter, newsletters, blogs +- **Style Enforcement** - Apply writing guidelines automatically +- **SEO Optimization** - WordPress-specific SEO workflows +- **Guest Editing** - Edit external content while preserving author voice +- **Fact Verification** - Technical accuracy checking via web search +- **Internal Linking** - Cross-platform content integration strategy +- **Campaign Planning** - Multi-touch campaign mapping + +## For Whom? 
+ +Solo marketing professionals managing: +- Technical newsletters and educational content +- Multi-platform social media campaigns +- WordPress blogs with SEO requirements +- Guest contributor content +- A/B testing and content experimentation +- B2B SaaS or technical product marketing + +## Key Differences from Zen Code + +| Zen Code | Zen Marketing | +|----------|---------------| +| Code review, debugging, testing | Content variation, SEO, style guide | +| Software architecture analysis | Campaign planning, voice analysis | +| Development workflows | Marketing workflows | +| Technical accuracy for code | Technical accuracy for content | +| GitHub/git integration | WordPress/platform integration | + +## Tools (Planned) + +### Simple Tools (Fast, Single-Shot) +- `contentvariant` - Generate 5-20 variations for testing +- `platformadapt` - Cross-platform content adaptation +- `subjectlines` - Email subject line generation +- `factcheck` - Technical fact verification + +### Workflow Tools (Multi-Step) +- `styleguide` - Writing style enforcement +- `seooptimize` - WordPress SEO optimization +- `guestedit` - Guest content editing +- `linkstrategy` - Internal linking strategy +- `voiceanalysis` - Voice extraction and validation +- `campaignmap` - Campaign planning + +### Kept from Zen +- `chat` - General brainstorming +- `thinkdeep` - Deep strategic thinking +- `planner` - Project planning + +## Models + +**Primary Models:** +- **Gemini 2.5 Pro** - Analytical and strategic work +- **Gemini Flash** - Fast bulk generation +- **Minimax M2** - Creative content generation + +**Fallback:** +- OpenRouter (if configured) + +## Quick Start + +> **Note:** Not yet functional - see [PLAN.md](PLAN.md) for implementation roadmap + +```bash +# Clone and setup +git clone [repo-url] +cd zen-marketing +./run-server.sh + +# Configure Claude Desktop (~/.claude.json) +{ + "mcpServers": { + "zen-marketing": { + "command": "/home/ben/mcp/zen-marketing/.venv/bin/python", + "args": 
["/home/ben/mcp/zen-marketing/server.py"], + "env": { + "GEMINI_API_KEY": "your-key", + "OPENROUTER_API_KEY": "your-key", + "DEFAULT_MODEL": "gemini-2.5-pro" + } + } + } +} + +# Restart Claude Desktop +``` + +## Example Usage + +``` +# Generate subject line variations +"Use zen-marketing contentvariant to generate 15 subject lines for an HVAC newsletter about PCB diagnostics. Test angles: technical curiosity, contrarian, knowledge gap, urgency." + +# Adapt content across platforms +"Use zen-marketing platformadapt to take this blog post intro and create versions for LinkedIn (1300 chars), Twitter (280), Instagram (2200), and newsletter." + +# Enforce writing style +"Use zen-marketing styleguide to check this draft for em-dashes, 'This isn't X, it's Y' constructions, and ensure direct affirmative statements." + +# SEO optimize +"Use zen-marketing seooptimize to create SEO title under 60 chars, excerpt under 156 chars, and suggest WordPress tags for this article about HVAC voltage regulation." +``` + +## Development Status + +### Phase 1: Foundation (Current) +- [x] Create directory structure +- [x] Write implementation plan +- [x] Create development guide +- [ ] Copy core architecture from zen-mcp-server +- [ ] Configure minimax provider +- [ ] Test basic functionality + +### Phase 2: Simple Tools (Next) +- [ ] Implement `contentvariant` +- [ ] Implement `subjectlines` +- [ ] Implement `platformadapt` +- [ ] Implement `factcheck` + +See [PLAN.md](PLAN.md) for complete roadmap. 
+ +## Documentation + +- **[PLAN.md](PLAN.md)** - Detailed implementation plan with tool designs +- **[CLAUDE.md](CLAUDE.md)** - Development guide for contributors +- **[Project Memories](PLAN.md#project-overview)** - Real-world usage examples + +## Architecture + +Based on Zen MCP Server's proven architecture: +- **Conversation continuity** via `continuation_id` +- **Multi-model orchestration** (Gemini, Minimax, OpenRouter) +- **Simple tools** for fast iteration +- **Workflow tools** for multi-step processes +- **Web search integration** for current information +- **File handling** for content and brand assets + +## License + +Apache 2.0 License (inherited from Zen MCP Server) + +## Acknowledgments + +Built on the foundation of: +- [Zen MCP Server](https://github.com/BeehiveInnovations/zen-mcp-server) by Fahad Gilani +- [Model Context Protocol](https://modelcontextprotocol.com) by Anthropic +- [Claude Desktop](https://claude.ai/download) - AI interface + +--- + +**Status:** Planning phase complete, ready for implementation +**Next:** Start new Claude session in `~/mcp/zen-marketing/` to begin Phase 1 diff --git a/config.py b/config.py new file mode 100644 index 0000000..ed4a7c9 --- /dev/null +++ b/config.py @@ -0,0 +1,107 @@ +""" +Configuration and constants for Zen-Marketing MCP Server + +This module centralizes all configuration settings for the Zen-Marketing MCP Server. +It defines model configurations, token limits, temperature defaults, and other +constants used throughout the application. + +Configuration values can be overridden by environment variables where appropriate. 
+""" + +import os + +# Version and metadata +__version__ = "0.1.0" +__updated__ = "2025-11-07" +__author__ = "Ben (based on Zen MCP Server by Fahad Gilani)" + +# Model configuration +# DEFAULT_MODEL: The default model used for all AI operations +# Can be overridden by setting DEFAULT_MODEL environment variable +DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "google/gemini-2.5-pro-latest") + +# Fast model for quick operations (variations, subject lines) +FAST_MODEL = os.getenv("FAST_MODEL", "google/gemini-2.5-flash-preview-09-2025") + +# Creative model for content generation +CREATIVE_MODEL = os.getenv("CREATIVE_MODEL", "minimax/minimax-m2") + +# Auto mode detection - when DEFAULT_MODEL is "auto", Claude picks the model +IS_AUTO_MODE = DEFAULT_MODEL.lower() == "auto" + +# Temperature defaults for different content types +# Temperature controls the randomness/creativity of model responses +# Lower values (0.0-0.3) produce more deterministic, focused responses +# Higher values (0.7-1.0) produce more creative, varied responses + +# TEMPERATURE_PRECISION: Used for fact-checking and technical verification +TEMPERATURE_PRECISION = 0.2 # For factcheck, technical verification + +# TEMPERATURE_ANALYTICAL: Used for style enforcement and SEO optimization +TEMPERATURE_ANALYTICAL = 0.3 # For styleguide, seooptimize, voiceanalysis + +# TEMPERATURE_BALANCED: Used for strategic planning +TEMPERATURE_BALANCED = 0.5 # For guestedit, linkstrategy, campaignmap + +# TEMPERATURE_CREATIVE: Used for content variation and adaptation +TEMPERATURE_CREATIVE = 0.7 # For platformadapt + +# TEMPERATURE_HIGHLY_CREATIVE: Used for bulk variation generation +TEMPERATURE_HIGHLY_CREATIVE = 0.8 # For contentvariant, subjectlines + +# Thinking Mode Defaults +DEFAULT_THINKING_MODE_THINKDEEP = os.getenv("DEFAULT_THINKING_MODE_THINKDEEP", "high") + +# MCP Protocol Transport Limits +def _calculate_mcp_prompt_limit() -> int: + """ + Calculate MCP prompt size limit based on MAX_MCP_OUTPUT_TOKENS environment 
variable. + + Returns: + Maximum character count for user input prompts + """ + max_tokens_str = os.getenv("MAX_MCP_OUTPUT_TOKENS") + + if max_tokens_str: + try: + max_tokens = int(max_tokens_str) + # Allocate 60% of tokens for input, convert to characters (~4 chars per token) + input_token_budget = int(max_tokens * 0.6) + character_limit = input_token_budget * 4 + return character_limit + except (ValueError, TypeError): + pass + + # Default fallback: 60,000 characters + return 60_000 + + +MCP_PROMPT_SIZE_LIMIT = _calculate_mcp_prompt_limit() + +# Language/Locale Configuration +LOCALE = os.getenv("LOCALE", "") + +# Platform character limits +PLATFORM_LIMITS = { + "twitter": 280, + "bluesky": 300, + "linkedin": 3000, + "linkedin_optimal": 1300, + "instagram": 2200, + "facebook": 500, # Optimal length + "email_subject": 60, + "email_preview": 100, + "meta_description": 156, + "page_title": 60, +} + +# Web search configuration +ENABLE_WEB_SEARCH = os.getenv("ENABLE_WEB_SEARCH", "true").lower() == "true" + +# Tool disabling +# Comma-separated list of tools to disable +DISABLED_TOOLS_STR = os.getenv("DISABLED_TOOLS", "") +DISABLED_TOOLS = set(tool.strip() for tool in DISABLED_TOOLS_STR.split(",") if tool.strip()) + +# Logging configuration +LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") diff --git a/providers/__init__.py b/providers/__init__.py new file mode 100644 index 0000000..311fafa --- /dev/null +++ b/providers/__init__.py @@ -0,0 +1,20 @@ +"""Model provider abstractions for supporting multiple AI providers.""" + +from .base import ModelProvider +from .gemini import GeminiModelProvider +from .openai_compatible import OpenAICompatibleProvider +from .openai_provider import OpenAIModelProvider +from .openrouter import OpenRouterProvider +from .registry import ModelProviderRegistry +from .shared import ModelCapabilities, ModelResponse + +__all__ = [ + "ModelProvider", + "ModelResponse", + "ModelCapabilities", + "ModelProviderRegistry", + "GeminiModelProvider", + 
"OpenAIModelProvider", + "OpenAICompatibleProvider", + "OpenRouterProvider", +] diff --git a/providers/base.py b/providers/base.py new file mode 100644 index 0000000..fd316c1 --- /dev/null +++ b/providers/base.py @@ -0,0 +1,268 @@ +"""Base interfaces and common behaviour for model providers.""" + +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +if TYPE_CHECKING: + from tools.models import ToolModelCategory + +from .shared import ModelCapabilities, ModelResponse, ProviderType + +logger = logging.getLogger(__name__) + + +class ModelProvider(ABC): + """Abstract base class for all model backends in the MCP server. + + Role + Defines the interface every provider must implement so the registry, + restriction service, and tools have a uniform surface for listing + models, resolving aliases, and executing requests. + + Responsibilities + * expose static capability metadata for each supported model via + :class:`ModelCapabilities` + * accept user prompts, forward them to the underlying SDK, and wrap + responses in :class:`ModelResponse` + * report tokenizer counts for budgeting and validation logic + * advertise provider identity (``ProviderType``) so restriction + policies can map environment configuration onto providers + * validate whether a model name or alias is recognised by the provider + + Shared helpers like temperature validation, alias resolution, and + restriction-aware ``list_models`` live here so concrete subclasses only + need to supply their catalogue and wire up SDK-specific behaviour. 
+ """ + + # All concrete providers must define their supported models + MODEL_CAPABILITIES: dict[str, Any] = {} + + def __init__(self, api_key: str, **kwargs): + """Initialize the provider with API key and optional configuration.""" + self.api_key = api_key + self.config = kwargs + + # ------------------------------------------------------------------ + # Provider identity & capability surface + # ------------------------------------------------------------------ + @abstractmethod + def get_provider_type(self) -> ProviderType: + """Return the concrete provider identity.""" + + def get_capabilities(self, model_name: str) -> ModelCapabilities: + """Resolve capability metadata for a model name. + + This centralises the alias resolution → lookup → restriction check + pipeline so providers only override the pieces they genuinely need to + customise. Subclasses usually only override ``_lookup_capabilities`` to + integrate a registry or dynamic source, or ``_finalise_capabilities`` to + tweak the returned object. 
+ """ + + resolved_name = self._resolve_model_name(model_name) + capabilities = self._lookup_capabilities(resolved_name, model_name) + + if capabilities is None: + self._raise_unsupported_model(model_name) + + self._ensure_model_allowed(capabilities, resolved_name, model_name) + return self._finalise_capabilities(capabilities, resolved_name, model_name) + + def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: + """Return statically declared capabilities when available.""" + + model_map = getattr(self, "MODEL_CAPABILITIES", None) + if isinstance(model_map, dict) and model_map: + return {k: v for k, v in model_map.items() if isinstance(v, ModelCapabilities)} + return {} + + def list_models( + self, + *, + respect_restrictions: bool = True, + include_aliases: bool = True, + lowercase: bool = False, + unique: bool = False, + ) -> list[str]: + """Return formatted model names supported by this provider.""" + + model_configs = self.get_all_model_capabilities() + if not model_configs: + return [] + + restriction_service = None + if respect_restrictions: + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() + + if restriction_service: + allowed_configs = {} + for model_name, config in model_configs.items(): + if restriction_service.is_allowed(self.get_provider_type(), model_name): + allowed_configs[model_name] = config + model_configs = allowed_configs + + if not model_configs: + return [] + + return ModelCapabilities.collect_model_names( + model_configs, + include_aliases=include_aliases, + lowercase=lowercase, + unique=unique, + ) + + # ------------------------------------------------------------------ + # Request execution + # ------------------------------------------------------------------ + @abstractmethod + def generate_content( + self, + prompt: str, + model_name: str, + system_prompt: Optional[str] = None, + temperature: float = 0.3, + max_output_tokens: Optional[int] = None, + **kwargs, 
+ ) -> ModelResponse: + """Generate content using the model.""" + + def count_tokens(self, text: str, model_name: str) -> int: + """Estimate token usage for a piece of text.""" + + resolved_model = self._resolve_model_name(model_name) + + if not text: + return 0 + + estimated = max(1, len(text) // 4) + logger.debug("Estimating %s tokens for model %s via character heuristic", estimated, resolved_model) + return estimated + + def close(self) -> None: + """Clean up any resources held by the provider.""" + + return + + # ------------------------------------------------------------------ + # Validation hooks + # ------------------------------------------------------------------ + def validate_model_name(self, model_name: str) -> bool: + """Return ``True`` when the model resolves to an allowed capability.""" + + try: + self.get_capabilities(model_name) + except ValueError: + return False + return True + + def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None: + """Validate model parameters against capabilities.""" + + capabilities = self.get_capabilities(model_name) + + if not capabilities.temperature_constraint.validate(temperature): + constraint_desc = capabilities.temperature_constraint.get_description() + raise ValueError(f"Temperature {temperature} is invalid for model {model_name}. 
{constraint_desc}") + + # ------------------------------------------------------------------ + # Preference / registry hooks + # ------------------------------------------------------------------ + def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]: + """Get the preferred model from this provider for a given category.""" + + return None + + def get_model_registry(self) -> Optional[dict[str, Any]]: + """Return the model registry backing this provider, if any.""" + + return None + + # ------------------------------------------------------------------ + # Capability lookup pipeline + # ------------------------------------------------------------------ + def _lookup_capabilities( + self, + canonical_name: str, + requested_name: Optional[str] = None, + ) -> Optional[ModelCapabilities]: + """Return ``ModelCapabilities`` for the canonical model name.""" + + return self.get_all_model_capabilities().get(canonical_name) + + def _ensure_model_allowed( + self, + capabilities: ModelCapabilities, + canonical_name: str, + requested_name: str, + ) -> None: + """Raise ``ValueError`` if the model violates restriction policy.""" + + try: + from utils.model_restrictions import get_restriction_service + except Exception: # pragma: no cover - only triggered if service import breaks + return + + restriction_service = get_restriction_service() + if not restriction_service: + return + + if restriction_service.is_allowed(self.get_provider_type(), canonical_name, requested_name): + return + + raise ValueError( + f"{self.get_provider_type().value} model '{canonical_name}' is not allowed by restriction policy." 
+ ) + + def _finalise_capabilities( + self, + capabilities: ModelCapabilities, + canonical_name: str, + requested_name: str, + ) -> ModelCapabilities: + """Allow subclasses to adjust capability metadata before returning.""" + + return capabilities + + def _raise_unsupported_model(self, model_name: str) -> None: + """Raise the canonical unsupported-model error.""" + + raise ValueError(f"Unsupported model '{model_name}' for provider {self.get_provider_type().value}.") + + def _resolve_model_name(self, model_name: str) -> str: + """Resolve model shorthand to full name. + + This implementation uses the hook methods to support different + model configuration sources. + + Args: + model_name: Model name that may be an alias + + Returns: + Resolved model name + """ + # Get model configurations from the hook method + model_configs = self.get_all_model_capabilities() + + # First check if it's already a base model name (case-sensitive exact match) + if model_name in model_configs: + return model_name + + # Check case-insensitively for both base models and aliases + model_name_lower = model_name.lower() + + # Check base model names case-insensitively + for base_model in model_configs: + if base_model.lower() == model_name_lower: + return base_model + + # Check aliases from the model configurations + alias_map = ModelCapabilities.collect_aliases(model_configs) + for base_model, aliases in alias_map.items(): + if any(alias.lower() == model_name_lower for alias in aliases): + return base_model + + # If not found, return as-is + return model_name diff --git a/providers/custom.py b/providers/custom.py new file mode 100644 index 0000000..4f7eb50 --- /dev/null +++ b/providers/custom.py @@ -0,0 +1,196 @@ +"""Custom API provider implementation.""" + +import logging +import os +from typing import Optional + +from .openai_compatible import OpenAICompatibleProvider +from .openrouter_registry import OpenRouterModelRegistry +from .shared import ModelCapabilities, ModelResponse, ProviderType 
+ + +class CustomProvider(OpenAICompatibleProvider): + """Adapter for self-hosted or local OpenAI-compatible endpoints. + + Role + Provide a uniform bridge between the MCP server and user-managed + OpenAI-compatible services (Ollama, vLLM, LM Studio, bespoke gateways). + By subclassing :class:`OpenAICompatibleProvider` it inherits request and + token handling, while the custom registry exposes locally defined model + metadata. + + Notable behaviour + * Uses :class:`OpenRouterModelRegistry` to load model definitions and + aliases so custom deployments share the same metadata pipeline as + OpenRouter itself. + * Normalises version-tagged model names (``model:latest``) and applies + restriction policies just like cloud providers, ensuring consistent + behaviour across environments. + """ + + FRIENDLY_NAME = "Custom API" + + # Model registry for managing configurations and aliases (shared with OpenRouter) + _registry: Optional[OpenRouterModelRegistry] = None + + def __init__(self, api_key: str = "", base_url: str = "", **kwargs): + """Initialize Custom provider for local/self-hosted models. + + This provider supports any OpenAI-compatible API endpoint including: + - Ollama (typically no API key required) + - vLLM (may require API key) + - LM Studio (may require API key) + - Text Generation WebUI (may require API key) + - Enterprise/self-hosted APIs (typically require API key) + + Args: + api_key: API key for the custom endpoint. Can be empty string for + providers that don't require authentication (like Ollama). + Falls back to CUSTOM_API_KEY environment variable if not provided. + base_url: Base URL for the custom API endpoint (e.g., 'http://localhost:11434/v1'). + Falls back to CUSTOM_API_URL environment variable if not provided. 
+ **kwargs: Additional configuration passed to parent OpenAI-compatible provider + + Raises: + ValueError: If no base_url is provided via parameter or environment variable + """ + # Fall back to environment variables only if not provided + if not base_url: + base_url = os.getenv("CUSTOM_API_URL", "") + if not api_key: + api_key = os.getenv("CUSTOM_API_KEY", "") + + if not base_url: + raise ValueError( + "Custom API URL must be provided via base_url parameter or CUSTOM_API_URL environment variable" + ) + + # For Ollama and other providers that don't require authentication, + # set a dummy API key to avoid OpenAI client header issues + if not api_key: + api_key = "dummy-key-for-unauthenticated-endpoint" + logging.debug("Using dummy API key for unauthenticated custom endpoint") + + logging.info(f"Initializing Custom provider with endpoint: {base_url}") + + super().__init__(api_key, base_url=base_url, **kwargs) + + # Initialize model registry (shared with OpenRouter for consistent aliases) + if CustomProvider._registry is None: + CustomProvider._registry = OpenRouterModelRegistry() + # Log loaded models and aliases only on first load + models = self._registry.list_models() + aliases = self._registry.list_aliases() + logging.info(f"Custom provider loaded {len(models)} models with {len(aliases)} aliases") + + # ------------------------------------------------------------------ + # Capability surface + # ------------------------------------------------------------------ + def _lookup_capabilities( + self, + canonical_name: str, + requested_name: Optional[str] = None, + ) -> Optional[ModelCapabilities]: + """Return capabilities for models explicitly marked as custom.""" + + builtin = super()._lookup_capabilities(canonical_name, requested_name) + if builtin is not None: + return builtin + + registry_entry = self._registry.resolve(canonical_name) + if registry_entry and getattr(registry_entry, "is_custom", False): + registry_entry.provider = ProviderType.CUSTOM + return 
registry_entry + + logging.debug( + "Custom provider cannot resolve model '%s'; ensure it is declared with 'is_custom': true in custom_models.json", + canonical_name, + ) + return None + + def get_provider_type(self) -> ProviderType: + """Identify this provider for restriction and logging logic.""" + + return ProviderType.CUSTOM + + # ------------------------------------------------------------------ + # Validation + # ------------------------------------------------------------------ + + # ------------------------------------------------------------------ + # Request execution + # ------------------------------------------------------------------ + + def generate_content( + self, + prompt: str, + model_name: str, + system_prompt: Optional[str] = None, + temperature: float = 0.3, + max_output_tokens: Optional[int] = None, + **kwargs, + ) -> ModelResponse: + """Generate content using the custom API. + + Args: + prompt: User prompt to send to the model + model_name: Name of the model to use + system_prompt: Optional system prompt for model behavior + temperature: Sampling temperature + max_output_tokens: Maximum tokens to generate + **kwargs: Additional provider-specific parameters + + Returns: + ModelResponse with generated content and metadata + """ + # Resolve model alias to actual model name + resolved_model = self._resolve_model_name(model_name) + + # Call parent method with resolved model name + return super().generate_content( + prompt=prompt, + model_name=resolved_model, + system_prompt=system_prompt, + temperature=temperature, + max_output_tokens=max_output_tokens, + **kwargs, + ) + + # ------------------------------------------------------------------ + # Registry helpers + # ------------------------------------------------------------------ + + def _resolve_model_name(self, model_name: str) -> str: + """Resolve registry aliases and strip version tags for local models.""" + + config = self._registry.resolve(model_name) + if config: + if config.model_name != 
model_name: + logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'") + return config.model_name + + if ":" in model_name: + base_model = model_name.split(":")[0] + logging.debug(f"Stripped version tag from '{model_name}' -> '{base_model}'") + + base_config = self._registry.resolve(base_model) + if base_config: + logging.info(f"Resolved base model '{base_model}' to '{base_config.model_name}'") + return base_config.model_name + return base_model + + logging.debug(f"Model '{model_name}' not found in registry, using as-is") + return model_name + + def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: + """Expose registry capabilities for models marked as custom.""" + + if not self._registry: + return {} + + capabilities: dict[str, ModelCapabilities] = {} + for model_name in self._registry.list_models(): + config = self._registry.resolve(model_name) + if config and getattr(config, "is_custom", False): + capabilities[model_name] = config + return capabilities diff --git a/providers/dial.py b/providers/dial.py new file mode 100644 index 0000000..db11417 --- /dev/null +++ b/providers/dial.py @@ -0,0 +1,473 @@ +"""DIAL (Data & AI Layer) model provider implementation.""" + +import logging +import os +import threading +import time +from typing import Optional + +from .openai_compatible import OpenAICompatibleProvider +from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint + +logger = logging.getLogger(__name__) + + +class DIALModelProvider(OpenAICompatibleProvider): + """Client for the DIAL (Data & AI Layer) aggregation service. + + DIAL exposes several third-party models behind a single OpenAI-compatible + endpoint. This provider wraps the service, publishes capability metadata + for the known deployments, and centralises retry/backoff settings tailored + to DIAL's latency characteristics. 
+ """ + + FRIENDLY_NAME = "DIAL" + + # Retry configuration for API calls + MAX_RETRIES = 4 + RETRY_DELAYS = [1, 3, 5, 8] # seconds + + # Model configurations using ModelCapabilities objects + MODEL_CAPABILITIES = { + "o3-2025-04-16": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="o3-2025-04-16", + friendly_name="DIAL (O3)", + context_window=200_000, + max_output_tokens=100_000, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # DIAL may not expose function calling + supports_json_mode=True, + supports_images=True, + max_image_size_mb=20.0, + supports_temperature=False, # O3 models don't accept temperature + temperature_constraint=TemperatureConstraint.create("fixed"), + description="OpenAI O3 via DIAL - Strong reasoning model", + aliases=["o3"], + ), + "o4-mini-2025-04-16": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="o4-mini-2025-04-16", + friendly_name="DIAL (O4-mini)", + context_window=200_000, + max_output_tokens=100_000, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # DIAL may not expose function calling + supports_json_mode=True, + supports_images=True, + max_image_size_mb=20.0, + supports_temperature=False, # O4 models don't accept temperature + temperature_constraint=TemperatureConstraint.create("fixed"), + description="OpenAI O4-mini via DIAL - Fast reasoning model", + aliases=["o4-mini"], + ), + "anthropic.claude-sonnet-4.1-20250805-v1:0": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="anthropic.claude-sonnet-4.1-20250805-v1:0", + friendly_name="DIAL (Sonnet 4.1)", + context_window=200_000, + max_output_tokens=64_000, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # Claude doesn't have function calling + supports_json_mode=False, # Claude doesn't have JSON mode + 
supports_images=True, + max_image_size_mb=5.0, + supports_temperature=True, + temperature_constraint=TemperatureConstraint.create("range"), + description="Claude Sonnet 4.1 via DIAL - Balanced performance", + aliases=["sonnet-4.1", "sonnet-4"], + ), + "anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="anthropic.claude-sonnet-4.1-20250805-v1:0-with-thinking", + friendly_name="DIAL (Sonnet 4.1 Thinking)", + context_window=200_000, + max_output_tokens=64_000, + supports_extended_thinking=True, # Thinking mode variant + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # Claude doesn't have function calling + supports_json_mode=False, # Claude doesn't have JSON mode + supports_images=True, + max_image_size_mb=5.0, + supports_temperature=True, + temperature_constraint=TemperatureConstraint.create("range"), + description="Claude Sonnet 4.1 with thinking mode via DIAL", + aliases=["sonnet-4.1-thinking", "sonnet-4-thinking"], + ), + "anthropic.claude-opus-4.1-20250805-v1:0": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="anthropic.claude-opus-4.1-20250805-v1:0", + friendly_name="DIAL (Opus 4.1)", + context_window=200_000, + max_output_tokens=64_000, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # Claude doesn't have function calling + supports_json_mode=False, # Claude doesn't have JSON mode + supports_images=True, + max_image_size_mb=5.0, + supports_temperature=True, + temperature_constraint=TemperatureConstraint.create("range"), + description="Claude Opus 4.1 via DIAL - Most capable Claude model", + aliases=["opus-4.1", "opus-4"], + ), + "anthropic.claude-opus-4.1-20250805-v1:0-with-thinking": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="anthropic.claude-opus-4.1-20250805-v1:0-with-thinking", + friendly_name="DIAL (Opus 4.1 Thinking)", + 
context_window=200_000, + max_output_tokens=64_000, + supports_extended_thinking=True, # Thinking mode variant + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # Claude doesn't have function calling + supports_json_mode=False, # Claude doesn't have JSON mode + supports_images=True, + max_image_size_mb=5.0, + supports_temperature=True, + temperature_constraint=TemperatureConstraint.create("range"), + description="Claude Opus 4.1 with thinking mode via DIAL", + aliases=["opus-4.1-thinking", "opus-4-thinking"], + ), + "gemini-2.5-pro-preview-03-25-google-search": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="gemini-2.5-pro-preview-03-25-google-search", + friendly_name="DIAL (Gemini 2.5 Pro Search)", + context_window=1_000_000, + max_output_tokens=65_536, + supports_extended_thinking=False, # DIAL doesn't expose thinking mode + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # DIAL may not expose function calling + supports_json_mode=True, + supports_images=True, + max_image_size_mb=20.0, + supports_temperature=True, + temperature_constraint=TemperatureConstraint.create("range"), + description="Gemini 2.5 Pro with Google Search via DIAL", + aliases=["gemini-2.5-pro-search"], + ), + "gemini-2.5-pro-preview-05-06": ModelCapabilities( + provider=ProviderType.DIAL, + model_name="gemini-2.5-pro-preview-05-06", + friendly_name="DIAL (Gemini 2.5 Pro)", + context_window=1_000_000, + max_output_tokens=65_536, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # DIAL may not expose function calling + supports_json_mode=True, + supports_images=True, + max_image_size_mb=20.0, + supports_temperature=True, + temperature_constraint=TemperatureConstraint.create("range"), + description="Gemini 2.5 Pro via DIAL - Deep reasoning", + aliases=["gemini-2.5-pro"], + ), + "gemini-2.5-flash-preview-05-20": 
ModelCapabilities( + provider=ProviderType.DIAL, + model_name="gemini-2.5-flash-preview-05-20", + friendly_name="DIAL (Gemini Flash 2.5)", + context_window=1_000_000, + max_output_tokens=65_536, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, # DIAL may not expose function calling + supports_json_mode=True, + supports_images=True, + max_image_size_mb=20.0, + supports_temperature=True, + temperature_constraint=TemperatureConstraint.create("range"), + description="Gemini 2.5 Flash via DIAL - Ultra-fast", + aliases=["gemini-2.5-flash"], + ), + } + + def __init__(self, api_key: str, **kwargs): + """Initialize DIAL provider with API key and host. + + Args: + api_key: DIAL API key for authentication + **kwargs: Additional configuration options + """ + # Get DIAL API host from environment or kwargs + dial_host = kwargs.get("base_url") or os.getenv("DIAL_API_HOST") or "https://core.dialx.ai" + + # DIAL uses /openai endpoint for OpenAI-compatible API + if not dial_host.endswith("/openai"): + dial_host = f"{dial_host.rstrip('/')}/openai" + + kwargs["base_url"] = dial_host + + # Get API version from environment or use default + self.api_version = os.getenv("DIAL_API_VERSION", "2024-12-01-preview") + + # Add DIAL-specific headers + # DIAL uses Api-Key header instead of Authorization: Bearer + # Reference: https://dialx.ai/dial_api#section/Authorization + self.DEFAULT_HEADERS = { + "Api-Key": api_key, + } + + # Store the actual API key for use in Api-Key header + self._dial_api_key = api_key + + # Pass a placeholder API key to OpenAI client - we'll override the auth header in httpx + # The actual authentication happens via the Api-Key header in the httpx client + super().__init__("placeholder-not-used", **kwargs) + + # Cache for deployment-specific clients to avoid recreating them on each request + self._deployment_clients = {} + # Lock to ensure thread-safe client creation + self._client_lock = 
threading.Lock() + + # Create a SINGLE shared httpx client for the provider instance + import httpx + + # Create custom event hooks to remove Authorization header + def remove_auth_header(request): + """Remove Authorization header that OpenAI client adds.""" + # httpx headers are case-insensitive, so we need to check all variations + headers_to_remove = [] + for header_name in request.headers: + if header_name.lower() == "authorization": + headers_to_remove.append(header_name) + + for header_name in headers_to_remove: + del request.headers[header_name] + + self._http_client = httpx.Client( + timeout=self.timeout_config, + verify=True, + follow_redirects=True, + headers=self.DEFAULT_HEADERS.copy(), # Include DIAL headers including Api-Key + limits=httpx.Limits( + max_keepalive_connections=5, + max_connections=10, + keepalive_expiry=30.0, + ), + event_hooks={"request": [remove_auth_header]}, + ) + + logger.info(f"Initialized DIAL provider with host: {dial_host} and api-version: {self.api_version}") + + def get_provider_type(self) -> ProviderType: + """Get the provider type.""" + return ProviderType.DIAL + + def _get_deployment_client(self, deployment: str): + """Get or create a cached client for a specific deployment. + + This avoids recreating OpenAI clients on every request, improving performance. + Reuses the shared HTTP client for connection pooling. 
+ + Args: + deployment: The deployment/model name + + Returns: + OpenAI client configured for the specific deployment + """ + # Check if client already exists without locking for performance + if deployment in self._deployment_clients: + return self._deployment_clients[deployment] + + # Use lock to ensure thread-safe client creation + with self._client_lock: + # Double-check pattern: check again inside the lock + if deployment not in self._deployment_clients: + from openai import OpenAI + + # Build deployment-specific URL + base_url = str(self.client.base_url) + if base_url.endswith("/"): + base_url = base_url[:-1] + + # Remove /openai suffix if present to reconstruct properly + if base_url.endswith("/openai"): + base_url = base_url[:-7] + + deployment_url = f"{base_url}/openai/deployments/{deployment}" + + # Create and cache the client, REUSING the shared http_client + # Use placeholder API key - Authorization header will be removed by http_client event hook + self._deployment_clients[deployment] = OpenAI( + api_key="placeholder-not-used", + base_url=deployment_url, + http_client=self._http_client, # Pass the shared client with Api-Key header + default_query={"api-version": self.api_version}, # Add api-version as query param + ) + + return self._deployment_clients[deployment] + + def generate_content( + self, + prompt: str, + model_name: str, + system_prompt: Optional[str] = None, + temperature: float = 0.3, + max_output_tokens: Optional[int] = None, + images: Optional[list[str]] = None, + **kwargs, + ) -> ModelResponse: + """Generate content using DIAL's deployment-specific endpoint. 
+ + DIAL uses Azure OpenAI-style deployment endpoints: + /openai/deployments/{deployment}/chat/completions + + Args: + prompt: User prompt + model_name: Model name or alias + system_prompt: Optional system prompt + temperature: Sampling temperature + max_output_tokens: Maximum tokens to generate + **kwargs: Additional provider-specific parameters + + Returns: + ModelResponse with generated content and metadata + """ + # Validate model name against allow-list + if not self.validate_model_name(model_name): + raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}") + + # Validate parameters and fetch capabilities + self.validate_parameters(model_name, temperature) + capabilities = self.get_capabilities(model_name) + + # Prepare messages + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + # Build user message content + user_message_content = [] + if prompt: + user_message_content.append({"type": "text", "text": prompt}) + + if images and capabilities.supports_images: + for img_path in images: + processed_image = self._process_image(img_path) + if processed_image: + user_message_content.append(processed_image) + elif images: + logger.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)") + + # Add user message. If only text, content will be a string, otherwise a list. 
+ if len(user_message_content) == 1 and user_message_content[0]["type"] == "text": + messages.append({"role": "user", "content": prompt}) + else: + messages.append({"role": "user", "content": user_message_content}) + + # Resolve model name + resolved_model = self._resolve_model_name(model_name) + + # Build completion parameters + completion_params = { + "model": resolved_model, + "messages": messages, + } + + # Determine temperature support from capabilities + supports_temperature = capabilities.supports_temperature + + # Add temperature parameter if supported + if supports_temperature: + completion_params["temperature"] = temperature + + # Add max tokens if specified and model supports it + if max_output_tokens and supports_temperature: + completion_params["max_tokens"] = max_output_tokens + + # Add additional parameters + for key, value in kwargs.items(): + if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]: + if not supports_temperature and key in ["top_p", "frequency_penalty", "presence_penalty"]: + continue + completion_params[key] = value + + # DIAL-specific: Get cached client for deployment endpoint + deployment_client = self._get_deployment_client(resolved_model) + + # Retry logic with progressive delays + last_exception = None + + for attempt in range(self.MAX_RETRIES): + try: + # Generate completion using deployment-specific client + response = deployment_client.chat.completions.create(**completion_params) + + # Extract content and usage + content = response.choices[0].message.content + usage = self._extract_usage(response) + + return ModelResponse( + content=content, + usage=usage, + model_name=model_name, + friendly_name=self.FRIENDLY_NAME, + provider=self.get_provider_type(), + metadata={ + "finish_reason": response.choices[0].finish_reason, + "model": response.model, + "id": response.id, + "created": response.created, + }, + ) + + except Exception as e: + last_exception = e + + # Check if this is a retryable error + 
is_retryable = self._is_error_retryable(e) + + if not is_retryable: + # Non-retryable error, raise immediately + raise ValueError(f"DIAL API error for model {model_name}: {str(e)}") + + # If this isn't the last attempt and error is retryable, wait and retry + if attempt < self.MAX_RETRIES - 1: + delay = self.RETRY_DELAYS[attempt] + logger.info( + f"DIAL API error (attempt {attempt + 1}/{self.MAX_RETRIES}), " f"retrying in {delay}s: {str(e)}" + ) + time.sleep(delay) + continue + + # All retries exhausted + raise ValueError( + f"DIAL API error for model {model_name} after {self.MAX_RETRIES} attempts: {str(last_exception)}" + ) + + def close(self) -> None: + """Clean up HTTP clients when provider is closed.""" + logger.info("Closing DIAL provider HTTP clients...") + + # Clear the deployment clients cache + # Note: We don't need to close individual OpenAI clients since they + # use the shared httpx.Client which we close separately + self._deployment_clients.clear() + + # Close the shared HTTP client + if hasattr(self, "_http_client"): + try: + self._http_client.close() + logger.debug("Closed shared HTTP client") + except Exception as e: + logger.warning(f"Error closing shared HTTP client: {e}") + + # Also close the client created by the superclass (OpenAICompatibleProvider) + # as it holds its own httpx.Client instance that is not used by DIAL's generate_content + if hasattr(self, "client") and self.client and hasattr(self.client, "close"): + try: + self.client.close() + logger.debug("Closed superclass's OpenAI client") + except Exception as e: + logger.warning(f"Error closing superclass's OpenAI client: {e}") diff --git a/providers/gemini.py b/providers/gemini.py new file mode 100644 index 0000000..de3fa4d --- /dev/null +++ b/providers/gemini.py @@ -0,0 +1,578 @@ +"""Gemini model provider implementation.""" + +import base64 +import logging +import time +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from tools.models import ToolModelCategory + +from 
class GeminiModelProvider(ModelProvider):
    """First-party Gemini integration built on the official Google SDK.

    The provider advertises detailed thinking-mode budgets, handles optional
    custom endpoints, and performs image pre-processing before forwarding a
    request to the Gemini APIs.
    """

    # Model configurations using ModelCapabilities objects
    MODEL_CAPABILITIES = {
        "gemini-2.5-pro": ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name="gemini-2.5-pro",
            friendly_name="Gemini (Pro 2.5)",
            context_window=1_048_576,  # 1M tokens
            max_output_tokens=65_536,
            supports_extended_thinking=True,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # Vision capability
            max_image_size_mb=32.0,  # Higher limit for Pro model
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            max_thinking_tokens=32768,  # Max thinking tokens for Pro model
            description="Deep reasoning + thinking mode (1M context) - Complex problems, architecture, deep analysis",
            aliases=["pro", "gemini pro", "gemini-pro"],
        ),
        "gemini-2.0-flash": ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name="gemini-2.0-flash",
            friendly_name="Gemini (Flash 2.0)",
            context_window=1_048_576,  # 1M tokens
            max_output_tokens=65_536,
            supports_extended_thinking=True,  # Experimental thinking mode
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # Vision capability
            max_image_size_mb=20.0,  # Conservative 20MB limit for reliability
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            max_thinking_tokens=24576,  # Same as 2.5 flash for consistency
            description="Gemini 2.0 Flash (1M context) - Latest fast model with experimental thinking, supports audio/video input",
            aliases=["flash-2.0", "flash2"],
        ),
        "gemini-2.0-flash-lite": ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name="gemini-2.0-flash-lite",
            # FIX: was "Gemin (Flash Lite 2.0)" — user-facing typo in the brand name.
            friendly_name="Gemini (Flash Lite 2.0)",
            context_window=1_048_576,  # 1M tokens
            max_output_tokens=65_536,
            supports_extended_thinking=False,  # Not supported per user request
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=False,  # Does not support images
            max_image_size_mb=0.0,  # No image support
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            description="Gemini 2.0 Flash Lite (1M context) - Lightweight fast model, text-only",
            aliases=["flashlite", "flash-lite"],
        ),
        "gemini-2.5-flash": ModelCapabilities(
            provider=ProviderType.GOOGLE,
            model_name="gemini-2.5-flash",
            friendly_name="Gemini (Flash 2.5)",
            context_window=1_048_576,  # 1M tokens
            max_output_tokens=65_536,
            supports_extended_thinking=True,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=True,
            supports_images=True,  # Vision capability
            max_image_size_mb=20.0,  # Conservative 20MB limit for reliability
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            max_thinking_tokens=24576,  # Flash 2.5 thinking budget limit
            description="Ultra-fast (1M context) - Quick analysis, simple queries, rapid iterations",
            aliases=["flash", "flash2.5"],
        ),
    }

    # Thinking mode configurations - percentages of model's max_thinking_tokens
    # These percentages work across all models that support thinking
    THINKING_BUDGETS = {
        "minimal": 0.005,  # 0.5% of max - minimal thinking for fast responses
        "low": 0.08,  # 8% of max - light reasoning tasks
        "medium": 0.33,  # 33% of max - balanced reasoning (default)
        "high": 0.67,  # 67% of max - complex analysis
        "max": 1.0,  # 100% of max - full thinking budget
    }

    # Model-specific thinking token limits
    # NOTE(review): this duplicates the per-model max_thinking_tokens values in
    # MODEL_CAPABILITIES above; keep the two in sync when adding models.
    MAX_THINKING_TOKENS = {
        "gemini-2.0-flash": 24576,  # Same as 2.5 flash for consistency
        "gemini-2.0-flash-lite": 0,  # No thinking support
        "gemini-2.5-flash": 24576,  # Flash 2.5 thinking budget limit
        "gemini-2.5-pro": 32768,  # Pro 2.5 thinking budget limit
    }

    def __init__(self, api_key: str, **kwargs):
        """Initialize Gemini provider with API key and optional base URL."""
        super().__init__(api_key, **kwargs)
        self._client = None
        self._token_counters = {}  # Cache for token counting
        self._base_url = kwargs.get("base_url", None)  # Optional custom endpoint

    # ------------------------------------------------------------------
    # Capability surface
    # ------------------------------------------------------------------

    # ------------------------------------------------------------------
    # Client access
    # ------------------------------------------------------------------

    @property
    def client(self):
        """Lazy initialization of Gemini client."""
        if self._client is None:
            # Check if custom base URL is provided
            if self._base_url:
                # Use HttpOptions to set custom endpoint
                # NOTE(review): the google-genai Python SDK canonically uses the
                # snake_case field `base_url`; the camelCase keyword relies on
                # pydantic alias support — confirm against the installed SDK version.
                http_options = types.HttpOptions(baseUrl=self._base_url)
                logger.debug(f"Initializing Gemini client with custom endpoint: {self._base_url}")
                self._client = genai.Client(api_key=self.api_key, http_options=http_options)
            else:
                # Use default Google endpoint
                self._client = genai.Client(api_key=self.api_key)
        return self._client

    # ------------------------------------------------------------------
    # Request execution
    # ------------------------------------------------------------------

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        max_output_tokens: Optional[int] = None,
        thinking_mode: str = "medium",
        images: Optional[list[str]] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using Gemini model."""
        # Validate parameters and fetch capabilities
        resolved_name = self._resolve_model_name(model_name)
        self.validate_parameters(model_name, temperature)
        capabilities = self.get_capabilities(model_name)

        # Prepare content parts (text and potentially images)
        parts = []

        # Add system and user prompts as text
        if system_prompt:
            full_prompt = f"{system_prompt}\n\n{prompt}"
        else:
            full_prompt = prompt

        parts.append({"text": full_prompt})

        # Add images if provided and model supports vision
        if images and capabilities.supports_images:
            for image_path in images:
                try:
                    image_part = self._process_image(image_path)
                    if image_part:
                        parts.append(image_part)
                except Exception as e:
                    logger.warning(f"Failed to process image {image_path}: {e}")
                    # Continue with other images and text
                    continue
        elif images and not capabilities.supports_images:
            logger.warning(f"Model {resolved_name} does not support images, ignoring {len(images)} image(s)")

        # Create contents structure
        contents = [{"parts": parts}]

        # Prepare generation config
        generation_config = types.GenerateContentConfig(
            temperature=temperature,
            candidate_count=1,
        )

        # Add max output tokens if specified
        if max_output_tokens:
            generation_config.max_output_tokens = max_output_tokens

        # Add thinking configuration for models that support it
        if capabilities.supports_extended_thinking and thinking_mode in self.THINKING_BUDGETS:
            # Get model's max thinking tokens and calculate actual budget
            model_config = self.MODEL_CAPABILITIES.get(resolved_name)
            if model_config and model_config.max_thinking_tokens > 0:
                max_thinking_tokens = model_config.max_thinking_tokens
                actual_thinking_budget = int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])
                generation_config.thinking_config = types.ThinkingConfig(thinking_budget=actual_thinking_budget)

        # Retry logic with progressive delays
        max_retries = 4  # Total of 4 attempts
        retry_delays = [1, 3, 5, 8]  # Progressive delays: 1s, 3s, 5s, 8s

        last_exception = None

        for attempt in range(max_retries):
            try:
                # Generate content
                response = self.client.models.generate_content(
                    model=resolved_name,
                    contents=contents,
                    config=generation_config,
                )

                # Extract usage information if available
                usage = self._extract_usage(response)

                # Intelligently determine finish reason and safety blocks
                finish_reason_str = "UNKNOWN"
                is_blocked_by_safety = False
                safety_feedback_details = None

                if response.candidates:
                    candidate = response.candidates[0]

                    # Safely get finish reason
                    try:
                        finish_reason_enum = candidate.finish_reason
                        if finish_reason_enum:
                            # Handle both enum objects and string values
                            try:
                                finish_reason_str = finish_reason_enum.name
                            except AttributeError:
                                finish_reason_str = str(finish_reason_enum)
                        else:
                            finish_reason_str = "STOP"
                    except AttributeError:
                        finish_reason_str = "STOP"

                    # If content is empty, check safety ratings for the definitive cause
                    if not response.text:
                        try:
                            safety_ratings = candidate.safety_ratings
                            if safety_ratings:  # Check it's not None or empty
                                for rating in safety_ratings:
                                    try:
                                        if rating.blocked:
                                            is_blocked_by_safety = True
                                            # Provide details for logging/debugging
                                            category_name = "UNKNOWN"
                                            probability_name = "UNKNOWN"

                                            try:
                                                category_name = rating.category.name
                                            except (AttributeError, TypeError):
                                                pass

                                            try:
                                                probability_name = rating.probability.name
                                            except (AttributeError, TypeError):
                                                pass

                                            safety_feedback_details = (
                                                f"Category: {category_name}, Probability: {probability_name}"
                                            )
                                            break
                                    except (AttributeError, TypeError):
                                        # Individual rating doesn't have expected attributes
                                        continue
                        except (AttributeError, TypeError):
                            # candidate doesn't have safety_ratings or it's not iterable
                            pass

                # Also check for prompt-level blocking (request rejected entirely)
                elif response.candidates is not None and len(response.candidates) == 0:
                    # No candidates is the primary indicator of a prompt-level block
                    is_blocked_by_safety = True
                    finish_reason_str = "SAFETY"
                    safety_feedback_details = "Prompt blocked, reason unavailable"  # Default message

                    try:
                        prompt_feedback = response.prompt_feedback
                        if prompt_feedback and prompt_feedback.block_reason:
                            try:
                                block_reason_name = prompt_feedback.block_reason.name
                            except AttributeError:
                                block_reason_name = str(prompt_feedback.block_reason)
                            safety_feedback_details = f"Prompt blocked, reason: {block_reason_name}"
                    except (AttributeError, TypeError):
                        # prompt_feedback doesn't exist or has unexpected attributes; stick with the default message
                        pass

                return ModelResponse(
                    content=response.text,
                    usage=usage,
                    model_name=resolved_name,
                    friendly_name="Gemini",
                    provider=ProviderType.GOOGLE,
                    metadata={
                        "thinking_mode": thinking_mode if capabilities.supports_extended_thinking else None,
                        "finish_reason": finish_reason_str,
                        "is_blocked_by_safety": is_blocked_by_safety,
                        "safety_feedback": safety_feedback_details,
                    },
                )

            except Exception as e:
                last_exception = e

                # Check if this is a retryable error using structured error codes
                is_retryable = self._is_error_retryable(e)

                # If this is the last attempt or not retryable, give up
                if attempt == max_retries - 1 or not is_retryable:
                    break

                # Get progressive delay
                delay = retry_delays[attempt]

                # Log retry attempt
                logger.warning(
                    f"Gemini API error for model {resolved_name}, attempt {attempt + 1}/{max_retries}: {str(e)}. Retrying in {delay}s..."
                )
                time.sleep(delay)

        # If we get here, all retries failed
        actual_attempts = attempt + 1  # Convert from 0-based index to human-readable count
        error_msg = f"Gemini API error for model {resolved_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}"
        raise RuntimeError(error_msg) from last_exception

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.GOOGLE

    def get_thinking_budget(self, model_name: str, thinking_mode: str) -> int:
        """Get actual thinking token budget for a model and thinking mode."""
        resolved_name = self._resolve_model_name(model_name)
        model_config = self.MODEL_CAPABILITIES.get(resolved_name)

        if not model_config or not model_config.supports_extended_thinking:
            return 0

        if thinking_mode not in self.THINKING_BUDGETS:
            return 0

        max_thinking_tokens = model_config.max_thinking_tokens
        if max_thinking_tokens == 0:
            return 0

        return int(max_thinking_tokens * self.THINKING_BUDGETS[thinking_mode])

    def _extract_usage(self, response) -> dict[str, int]:
        """Extract token usage from Gemini response."""
        usage = {}

        # Try to extract usage metadata from response
        # Note: The actual structure depends on the SDK version and response format
        try:
            metadata = response.usage_metadata
            if metadata:
                # Extract token counts with explicit None checks
                input_tokens = None
                output_tokens = None

                try:
                    value = metadata.prompt_token_count
                    if value is not None:
                        input_tokens = value
                        usage["input_tokens"] = value
                except (AttributeError, TypeError):
                    pass

                try:
                    value = metadata.candidates_token_count
                    if value is not None:
                        output_tokens = value
                        usage["output_tokens"] = value
                except (AttributeError, TypeError):
                    pass

                # Calculate total only if both values are available and valid
                if input_tokens is not None and output_tokens is not None:
                    usage["total_tokens"] = input_tokens + output_tokens
        except (AttributeError, TypeError):
            # response doesn't have usage_metadata
            pass

        return usage

    def _is_error_retryable(self, error: Exception) -> bool:
        """Determine if an error should be retried based on structured error codes.

        Uses Gemini API error structure instead of text pattern matching for reliability.

        Args:
            error: Exception from Gemini API call

        Returns:
            True if error should be retried, False otherwise
        """
        error_str = str(error).lower()

        # Check for 429 errors first - these need special handling
        if "429" in error_str or "quota" in error_str or "resource_exhausted" in error_str:
            # For Gemini, check for specific non-retryable error indicators
            # These typically indicate permanent failures or quota/size limits
            non_retryable_indicators = [
                "quota exceeded",
                "resource exhausted",
                "context length",
                "token limit",
                "request too large",
                "invalid request",
                "quota_exceeded",
                "resource_exhausted",
            ]

            # Also check if this is a structured error from Gemini SDK
            try:
                # Try to access error details if available
                error_details = None
                try:
                    error_details = error.details
                except AttributeError:
                    try:
                        error_details = error.reason
                    except AttributeError:
                        pass

                if error_details:
                    error_details_str = str(error_details).lower()
                    # Check for non-retryable error codes/reasons
                    if any(indicator in error_details_str for indicator in non_retryable_indicators):
                        logger.debug(f"Non-retryable Gemini error: {error_details}")
                        return False
            except Exception:
                pass

            # Check main error string for non-retryable patterns
            if any(indicator in error_str for indicator in non_retryable_indicators):
                logger.debug(f"Non-retryable Gemini error based on message: {error_str[:200]}...")
                return False

            # If it's a 429/quota error but doesn't match non-retryable patterns, it might be retryable rate limiting
            logger.debug(f"Retryable Gemini rate limiting error: {error_str[:100]}...")
            return True

        # For non-429 errors, check if they're retryable
        retryable_indicators = [
            "timeout",
            "connection",
            "network",
            "temporary",
            "unavailable",
            "retry",
            "internal error",
            "408",  # Request timeout
            "500",  # Internal server error
            "502",  # Bad gateway
            "503",  # Service unavailable
            "504",  # Gateway timeout
            "ssl",  # SSL errors
            "handshake",  # Handshake failures
        ]

        return any(indicator in error_str for indicator in retryable_indicators)

    def _process_image(self, image_path: str) -> Optional[dict]:
        """Process an image for Gemini API."""
        try:
            # Use base class validation
            image_bytes, mime_type = validate_image(image_path)

            # For data URLs, extract the base64 data directly
            if image_path.startswith("data:"):
                # Extract base64 data from data URL
                _, data = image_path.split(",", 1)
                return {"inline_data": {"mime_type": mime_type, "data": data}}
            else:
                # For file paths, encode the bytes
                image_data = base64.b64encode(image_bytes).decode()
                return {"inline_data": {"mime_type": mime_type, "data": image_data}}

        except ValueError as e:
            logger.warning(str(e))
            return None
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            return None

    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
        """Get Gemini's preferred model for a given category from allowed models.

        Args:
            category: The tool category requiring a model
            allowed_models: Pre-filtered list of models allowed by restrictions

        Returns:
            Preferred model name or None
        """
        from tools.models import ToolModelCategory

        if not allowed_models:
            return None

        # Helper to find best model from candidates
        def find_best(candidates: list[str]) -> Optional[str]:
            """Return best model from candidates (sorted for consistency)."""
            return sorted(candidates, reverse=True)[0] if candidates else None

        if category == ToolModelCategory.EXTENDED_REASONING:
            # For extended reasoning, prefer models with thinking support
            # First try Pro models that support thinking
            pro_thinking = [
                m
                for m in allowed_models
                if "pro" in m and m in self.MODEL_CAPABILITIES and self.MODEL_CAPABILITIES[m].supports_extended_thinking
            ]
            if pro_thinking:
                return find_best(pro_thinking)

            # Then any model that supports thinking
            any_thinking = [
                m
                for m in allowed_models
                if m in self.MODEL_CAPABILITIES and self.MODEL_CAPABILITIES[m].supports_extended_thinking
            ]
            if any_thinking:
                return find_best(any_thinking)

            # Finally, just prefer Pro models even without thinking
            pro_models = [m for m in allowed_models if "pro" in m]
            if pro_models:
                return find_best(pro_models)

        elif category == ToolModelCategory.FAST_RESPONSE:
            # Prefer Flash models for speed
            flash_models = [m for m in allowed_models if "flash" in m]
            if flash_models:
                return find_best(flash_models)

        # Default for BALANCED or as fallback
        # Prefer Flash for balanced use, then Pro, then anything
        flash_models = [m for m in allowed_models if "flash" in m]
        if flash_models:
            return find_best(flash_models)

        pro_models = [m for m in allowed_models if "pro" in m]
        if pro_models:
            return find_best(pro_models)

        # Ultimate fallback to best available model
        return find_best(allowed_models)
class OpenAICompatibleProvider(ModelProvider):
    """Shared implementation for OpenAI API lookalikes.

    The class owns HTTP client configuration (timeouts, proxy hardening,
    custom headers) and normalises the OpenAI SDK responses into
    :class:`~providers.shared.ModelResponse`. Concrete subclasses only need to
    provide capability metadata and any provider-specific request tweaks.
    """

    # Extra HTTP headers subclasses may supply (e.g. OpenRouter attribution).
    DEFAULT_HEADERS = {}
    # Human-readable provider name used in logs and ModelResponse metadata.
    FRIENDLY_NAME = "OpenAI Compatible"

    def __init__(self, api_key: str, base_url: Optional[str] = None, **kwargs):
        """Initialize the provider with API key and optional base URL.

        Args:
            api_key: API key for authentication
            base_url: Base URL for the API endpoint
            **kwargs: Additional configuration options including timeout
        """
        super().__init__(api_key, **kwargs)
        self._client = None  # Lazily built by the `client` property
        self.base_url = base_url
        self.organization = kwargs.get("organization")
        self.allowed_models = self._parse_allowed_models()

        # Configure timeouts - especially important for custom/local endpoints
        self.timeout_config = self._configure_timeouts(**kwargs)

        # Validate base URL for security
        if self.base_url:
            self._validate_base_url()

        # Warn if using external URL without authentication
        if self.base_url and not self._is_localhost_url() and not api_key:
            logging.warning(
                f"Using external URL '{self.base_url}' without API key. "
                "This may be insecure. Consider setting an API key for authentication."
            )

    def _ensure_model_allowed(
        self,
        capabilities: ModelCapabilities,
        canonical_name: str,
        requested_name: str,
    ) -> None:
        """Respect provider-specific allowlists before default restriction checks."""

        # Base-class restriction policy runs first; this adds the env-var allowlist.
        super()._ensure_model_allowed(capabilities, canonical_name, requested_name)

        if self.allowed_models is not None:
            # Compare case-insensitively; both the alias the caller used and the
            # resolved canonical name may appear in the allowlist.
            requested = requested_name.lower()
            canonical = canonical_name.lower()

            if requested not in self.allowed_models and canonical not in self.allowed_models:
                raise ValueError(
                    f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
                )

    def _parse_allowed_models(self) -> Optional[set[str]]:
        """Parse allowed models from environment variable.

        Reads ``{PROVIDER}_ALLOWED_MODELS`` (comma-separated) for this provider.

        Returns:
            Set of allowed model names (lowercase) or None if not configured
        """
        # Get provider-specific allowed models
        provider_type = self.get_provider_type().value.upper()
        env_var = f"{provider_type}_ALLOWED_MODELS"
        models_str = os.getenv(env_var, "")

        if models_str:
            # Parse and normalize to lowercase for case-insensitive comparison
            models = {m.strip().lower() for m in models_str.split(",") if m.strip()}
            if models:
                logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(models)}")
                return models

        # Log info if no allow-list configured for proxy providers
        if self.get_provider_type() not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
            logging.info(
                f"Model allow-list not configured for {self.FRIENDLY_NAME} - all models permitted. "
                f"To restrict access, set {env_var} with comma-separated model names."
            )

        return None

    def _configure_timeouts(self, **kwargs):
        """Configure timeout settings based on provider type and custom settings.

        Custom URLs and local models often need longer timeouts due to:
        - Network latency on local networks
        - Extended thinking models taking longer to respond
        - Local inference being slower than cloud APIs

        Returns:
            httpx.Timeout object with appropriate timeout settings
        """
        import httpx

        # Default timeouts - more generous for custom/local endpoints
        default_connect = 30.0  # 30 seconds for connection (vs OpenAI's 5s)
        default_read = 600.0  # 10 minutes for reading (same as OpenAI default)
        default_write = 600.0  # 10 minutes for writing
        default_pool = 600.0  # 10 minutes for pool

        # For custom/local URLs, use even longer timeouts
        if self.base_url and self._is_localhost_url():
            default_connect = 60.0  # 1 minute for local connections
            default_read = 1800.0  # 30 minutes for local models (extended thinking)
            default_write = 1800.0  # 30 minutes for local models
            default_pool = 1800.0  # 30 minutes for local models
            logging.info(f"Using extended timeouts for local endpoint: {self.base_url}")
        elif self.base_url:
            default_connect = 45.0  # 45 seconds for custom remote endpoints
            default_read = 900.0  # 15 minutes for custom remote endpoints
            default_write = 900.0  # 15 minutes for custom remote endpoints
            default_pool = 900.0  # 15 minutes for custom remote endpoints
            logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")

        # Allow override via kwargs or environment variables in future, for now...
        # kwargs take precedence over CUSTOM_*_TIMEOUT env vars, which override defaults.
        connect_timeout = kwargs.get("connect_timeout", float(os.getenv("CUSTOM_CONNECT_TIMEOUT", default_connect)))
        read_timeout = kwargs.get("read_timeout", float(os.getenv("CUSTOM_READ_TIMEOUT", default_read)))
        write_timeout = kwargs.get("write_timeout", float(os.getenv("CUSTOM_WRITE_TIMEOUT", default_write)))
        pool_timeout = kwargs.get("pool_timeout", float(os.getenv("CUSTOM_POOL_TIMEOUT", default_pool)))

        timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)

        logging.debug(
            f"Configured timeouts - Connect: {connect_timeout}s, Read: {read_timeout}s, "
            f"Write: {write_timeout}s, Pool: {pool_timeout}s"
        )

        return timeout

    def _is_localhost_url(self) -> bool:
        """Check if the base URL points to localhost or local network.

        Returns:
            True if URL is localhost or local network, False otherwise
        """
        if not self.base_url:
            return False

        try:
            parsed = urlparse(self.base_url)
            hostname = parsed.hostname

            # Check for common localhost patterns
            if hostname in ["localhost", "127.0.0.1", "::1"]:
                return True

            # Check for private network ranges (local network)
            if hostname:
                try:
                    ip = ipaddress.ip_address(hostname)
                    return ip.is_private or ip.is_loopback
                except ValueError:
                    # Not an IP address, might be a hostname
                    pass

            return False
        except Exception:
            return False

    def _validate_base_url(self) -> None:
        """Validate base URL for security (SSRF protection).

        Raises:
            ValueError: If URL is invalid or potentially unsafe
        """
        if not self.base_url:
            return

        try:
            parsed = urlparse(self.base_url)

            # Check URL scheme - only allow http/https
            if parsed.scheme not in ("http", "https"):
                raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")

            # Check hostname exists
            if not parsed.hostname:
                raise ValueError("URL must include a hostname")

            # Check port is valid (if specified)
            port = parsed.port
            if port is not None and (port < 1 or port > 65535):
                raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
        except Exception as e:
            # Re-raise our own ValueErrors untouched; wrap parser errors.
            if isinstance(e, ValueError):
                raise
            raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")

    @property
    def client(self):
        """Lazy initialization of OpenAI client with security checks and timeout configuration."""
        if self._client is None:
            import os

            import httpx

            # Temporarily disable proxy environment variables to prevent httpx from detecting them
            original_env = {}
            proxy_env_vars = ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]

            for var in proxy_env_vars:
                if var in os.environ:
                    original_env[var] = os.environ[var]
                    del os.environ[var]

            try:
                # Create a custom httpx client that explicitly avoids proxy parameters
                timeout_config = (
                    self.timeout_config
                    if hasattr(self, "timeout_config") and self.timeout_config
                    else httpx.Timeout(30.0)
                )

                # Create httpx client with minimal config to avoid proxy conflicts
                # Note: proxies parameter was removed in httpx 0.28.0
                # Check for test transport injection
                if hasattr(self, "_test_transport"):
                    # Use custom transport for testing (HTTP recording/replay)
                    http_client = httpx.Client(
                        transport=self._test_transport,
                        timeout=timeout_config,
                        follow_redirects=True,
                    )
                else:
                    # Normal production client
                    http_client = httpx.Client(
                        timeout=timeout_config,
                        follow_redirects=True,
                    )

                # Keep client initialization minimal to avoid proxy parameter conflicts
                client_kwargs = {
                    "api_key": self.api_key,
                    "http_client": http_client,
                }

                if self.base_url:
                    client_kwargs["base_url"] = self.base_url

                if self.organization:
                    client_kwargs["organization"] = self.organization

                # Add default headers if any
                if self.DEFAULT_HEADERS:
                    client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

                logging.debug(f"OpenAI client initialized with custom httpx client and timeout: {timeout_config}")

                # Create OpenAI client with custom httpx client
                self._client = OpenAI(**client_kwargs)

            except Exception as e:
                # If all else fails, try absolute minimal client without custom httpx
                logging.warning(f"Failed to create client with custom httpx, falling back to minimal config: {e}")
                try:
                    minimal_kwargs = {"api_key": self.api_key}
                    if self.base_url:
                        minimal_kwargs["base_url"] = self.base_url
                    self._client = OpenAI(**minimal_kwargs)
                except Exception as fallback_error:
                    logging.error(f"Even minimal OpenAI client creation failed: {fallback_error}")
                    raise
            finally:
                # Restore original proxy environment variables
                # NOTE(review): this save/strip/restore is process-global, so it is
                # not safe if another thread constructs a client concurrently.
                for var, value in original_env.items():
                    os.environ[var] = value

        return self._client

    def _sanitize_for_logging(self, params: dict) -> dict:
        """Sanitize sensitive data from parameters before logging.

        Args:
            params: Dictionary of API parameters

        Returns:
            dict: Sanitized copy of parameters safe for logging
        """
        # Deep copy so truncation never mutates the request actually sent.
        sanitized = copy.deepcopy(params)

        # Sanitize messages content
        if "input" in sanitized:
            for msg in sanitized.get("input", []):
                if isinstance(msg, dict) and "content" in msg:
                    for content_item in msg.get("content", []):
                        if isinstance(content_item, dict) and "text" in content_item:
                            # Truncate long text and add ellipsis
                            text = content_item["text"]
                            if len(text) > 100:
                                content_item["text"] = text[:100] + "... [truncated]"

        # Remove any API keys that might be in headers/auth
        sanitized.pop("api_key", None)
        sanitized.pop("authorization", None)

        return sanitized
+ + Args: + response: Response object from OpenAI SDK + + Returns: + str: The output text content + + Raises: + ValueError: If output_text is missing, None, or not a string + """ + logging.debug(f"Response object type: {type(response)}") + logging.debug(f"Response attributes: {dir(response)}") + + if not hasattr(response, "output_text"): + raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}") + + content = response.output_text + logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})") + + if content is None: + raise ValueError("o3-pro returned None for output_text") + + if not isinstance(content, str): + raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}") + + return content + + def _generate_with_responses_endpoint( + self, + model_name: str, + messages: list, + temperature: float, + max_output_tokens: Optional[int] = None, + **kwargs, + ) -> ModelResponse: + """Generate content using the /v1/responses endpoint for o3-pro via OpenAI library.""" + # Convert messages to the correct format for responses endpoint + input_messages = [] + + for message in messages: + role = message.get("role", "") + content = message.get("content", "") + + if role == "system": + # For o3-pro, system messages should be handled carefully to avoid policy violations + # Instead of prefixing with "System:", we'll include the system content naturally + input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]}) + elif role == "user": + input_messages.append({"role": "user", "content": [{"type": "input_text", "text": content}]}) + elif role == "assistant": + input_messages.append({"role": "assistant", "content": [{"type": "output_text", "text": content}]}) + + # Prepare completion parameters for responses endpoint + # Based on OpenAI documentation, use nested reasoning object for responses endpoint + completion_params = { + "model": model_name, + 
"input": input_messages, + "reasoning": {"effort": "medium"}, # Use nested object for responses endpoint + "store": True, + } + + # Add max tokens if specified (using max_completion_tokens for responses endpoint) + if max_output_tokens: + completion_params["max_completion_tokens"] = max_output_tokens + + # For responses endpoint, we only add parameters that are explicitly supported + # Remove unsupported chat completion parameters that may cause API errors + + # Retry logic with progressive delays + max_retries = 4 + retry_delays = [1, 3, 5, 8] + last_exception = None + actual_attempts = 0 + + for attempt in range(max_retries): + try: # Log sanitized payload for debugging + import json + + sanitized_params = self._sanitize_for_logging(completion_params) + logging.info( + f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}" + ) + + # Use OpenAI client's responses endpoint + response = self.client.responses.create(**completion_params) + + # Extract content from responses endpoint format + # Use validation helper to safely extract output_text + content = self._safe_extract_output_text(response) + + # Try to extract usage information + usage = None + if hasattr(response, "usage"): + usage = self._extract_usage(response) + elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"): + # Safely extract token counts with None handling + input_tokens = getattr(response, "input_tokens", 0) or 0 + output_tokens = getattr(response, "output_tokens", 0) or 0 + usage = { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + } + + return ModelResponse( + content=content, + usage=usage, + model_name=model_name, + friendly_name=self.FRIENDLY_NAME, + provider=self.get_provider_type(), + metadata={ + "model": getattr(response, "model", model_name), + "id": getattr(response, "id", ""), + "created": getattr(response, "created_at", 0), + "endpoint": "responses", + 
}, + ) + + except Exception as e: + last_exception = e + + # Check if this is a retryable error using structured error codes + is_retryable = self._is_error_retryable(e) + + if is_retryable and attempt < max_retries - 1: + delay = retry_delays[attempt] + logging.warning( + f"Retryable error for o3-pro responses endpoint, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..." + ) + time.sleep(delay) + else: + break + + # If we get here, all retries failed + error_msg = f"o3-pro responses endpoint error after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}" + logging.error(error_msg) + raise RuntimeError(error_msg) from last_exception + + def generate_content( + self, + prompt: str, + model_name: str, + system_prompt: Optional[str] = None, + temperature: float = 0.3, + max_output_tokens: Optional[int] = None, + images: Optional[list[str]] = None, + **kwargs, + ) -> ModelResponse: + """Generate content using the OpenAI-compatible API. + + Args: + prompt: User prompt to send to the model + model_name: Name of the model to use + system_prompt: Optional system prompt for model behavior + temperature: Sampling temperature + max_output_tokens: Maximum tokens to generate + **kwargs: Additional provider-specific parameters + + Returns: + ModelResponse with generated content and metadata + """ + # Validate model name against allow-list + if not self.validate_model_name(model_name): + raise ValueError(f"Model '{model_name}' not in allowed models list. 
Allowed models: {self.allowed_models}") + + capabilities: Optional[ModelCapabilities] + try: + capabilities = self.get_capabilities(model_name) + except Exception as exc: + logging.debug(f"Falling back to generic capabilities for {model_name}: {exc}") + capabilities = None + + # Get effective temperature for this model from capabilities when available + if capabilities: + effective_temperature = capabilities.get_effective_temperature(temperature) + if effective_temperature is not None and effective_temperature != temperature: + logging.debug( + f"Adjusting temperature from {temperature} to {effective_temperature} for model {model_name}" + ) + else: + effective_temperature = temperature + + # Only validate if temperature is not None (meaning the model supports it) + if effective_temperature is not None: + # Validate parameters with the effective temperature + self.validate_parameters(model_name, effective_temperature) + + # Prepare messages + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + + # Prepare user message with text and potentially images + user_content = [] + user_content.append({"type": "text", "text": prompt}) + + # Add images if provided and model supports vision + if images and capabilities and capabilities.supports_images: + for image_path in images: + try: + image_content = self._process_image(image_path) + if image_content: + user_content.append(image_content) + except Exception as e: + logging.warning(f"Failed to process image {image_path}: {e}") + # Continue with other images and text + continue + elif images and (not capabilities or not capabilities.supports_images): + logging.warning(f"Model {model_name} does not support images, ignoring {len(images)} image(s)") + + # Add user message + if len(user_content) == 1: + # Only text content, use simple string format for compatibility + messages.append({"role": "user", "content": prompt}) + else: + # Text + images, use content array format + 
messages.append({"role": "user", "content": user_content}) + + # Prepare completion parameters + completion_params = { + "model": model_name, + "messages": messages, + } + + # Check model capabilities once to determine parameter support + resolved_model = self._resolve_model_name(model_name) + + # Use the effective temperature we calculated earlier + supports_sampling = effective_temperature is not None + + if supports_sampling: + completion_params["temperature"] = effective_temperature + + # Add max tokens if specified and model supports it + # O3/O4 models that don't support temperature also don't support max_tokens + if max_output_tokens and supports_sampling: + completion_params["max_tokens"] = max_output_tokens + + # Add any additional OpenAI-specific parameters + # Use capabilities to filter parameters for reasoning models + for key, value in kwargs.items(): + if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]: + # Reasoning models (those that don't support temperature) also don't support these parameters + if not supports_sampling and key in ["top_p", "frequency_penalty", "presence_penalty"]: + continue # Skip unsupported parameters for reasoning models + completion_params[key] = value + + # Check if this is o3-pro and needs the responses endpoint + if resolved_model == "o3-pro": + # This model requires the /v1/responses endpoint + # If it fails, we should not fall back to chat/completions + return self._generate_with_responses_endpoint( + model_name=resolved_model, + messages=messages, + temperature=temperature, + max_output_tokens=max_output_tokens, + **kwargs, + ) + + # Retry logic with progressive delays + max_retries = 4 # Total of 4 attempts + retry_delays = [1, 3, 5, 8] # Progressive delays: 1s, 3s, 5s, 8s + + last_exception = None + actual_attempts = 0 + + for attempt in range(max_retries): + actual_attempts = attempt + 1 # Convert from 0-based index to human-readable count + try: + # Generate completion + response 
= self.client.chat.completions.create(**completion_params) + + # Extract content and usage + content = response.choices[0].message.content + usage = self._extract_usage(response) + + return ModelResponse( + content=content, + usage=usage, + model_name=model_name, + friendly_name=self.FRIENDLY_NAME, + provider=self.get_provider_type(), + metadata={ + "finish_reason": response.choices[0].finish_reason, + "model": response.model, # Actual model used + "id": response.id, + "created": response.created, + }, + ) + + except Exception as e: + last_exception = e + + # Check if this is a retryable error using structured error codes + is_retryable = self._is_error_retryable(e) + + # If this is the last attempt or not retryable, give up + if attempt == max_retries - 1 or not is_retryable: + break + + # Get progressive delay + delay = retry_delays[attempt] + + # Log retry attempt + logging.warning( + f"{self.FRIENDLY_NAME} error for model {model_name}, attempt {actual_attempts}/{max_retries}: {str(e)}. Retrying in {delay}s..." + ) + time.sleep(delay) + + # If we get here, all retries failed + error_msg = f"{self.FRIENDLY_NAME} API error for model {model_name} after {actual_attempts} attempt{'s' if actual_attempts > 1 else ''}: {str(last_exception)}" + logging.error(error_msg) + raise RuntimeError(error_msg) from last_exception + + def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None: + """Validate model parameters. + + For proxy providers, this may use generic capabilities. + + Args: + model_name: Model to validate for + temperature: Temperature to validate + **kwargs: Additional parameters to validate + """ + try: + capabilities = self.get_capabilities(model_name) + + # Check if we're using generic capabilities + if hasattr(capabilities, "_is_generic"): + logging.debug( + f"Using generic parameter validation for {model_name}. Actual model constraints may differ." 
+ ) + + # Validate temperature using parent class method + super().validate_parameters(model_name, temperature, **kwargs) + + except Exception as e: + # For proxy providers, we might not have accurate capabilities + # Log warning but don't fail + logging.warning(f"Parameter validation limited for {model_name}: {e}") + + def _extract_usage(self, response) -> dict[str, int]: + """Extract token usage from OpenAI response. + + Args: + response: OpenAI API response object + + Returns: + Dictionary with usage statistics + """ + usage = {} + + if hasattr(response, "usage") and response.usage: + # Safely extract token counts with None handling + usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) or 0 + usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) or 0 + usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) or 0 + + return usage + + def count_tokens(self, text: str, model_name: str) -> int: + """Count tokens using OpenAI-compatible tokenizer tables when available.""" + + resolved_model = self._resolve_model_name(model_name) + + try: + import tiktoken + + try: + encoding = tiktoken.encoding_for_model(resolved_model) + except KeyError: + encoding = tiktoken.get_encoding("cl100k_base") + + return len(encoding.encode(text)) + + except (ImportError, Exception) as exc: + logging.debug("tiktoken unavailable for %s: %s", resolved_model, exc) + + return super().count_tokens(text, model_name) + + def _is_error_retryable(self, error: Exception) -> bool: + """Determine if an error should be retried based on structured error codes. + + Uses OpenAI API error structure instead of text pattern matching for reliability. 
+ + Args: + error: Exception from OpenAI API call + + Returns: + True if error should be retried, False otherwise + """ + error_str = str(error).lower() + + # Check for 429 errors first - these need special handling + if "429" in error_str: + # Try to extract structured error information + error_type = None + error_code = None + + # Parse structured error from OpenAI API response + # Format: "Error code: 429 - {'error': {'type': 'tokens', 'code': 'rate_limit_exceeded', ...}}" + try: + import ast + import json + import re + + # Extract JSON part from error string using regex + # Look for pattern: {...} (from first { to last }) + json_match = re.search(r"\{.*\}", str(error)) + if json_match: + json_like_str = json_match.group(0) + + # First try: parse as Python literal (handles single quotes safely) + try: + error_data = ast.literal_eval(json_like_str) + except (ValueError, SyntaxError): + # Fallback: try JSON parsing with simple quote replacement + # (for cases where it's already valid JSON or simple replacements work) + json_str = json_like_str.replace("'", '"') + error_data = json.loads(json_str) + + if "error" in error_data: + error_info = error_data["error"] + error_type = error_info.get("type") + error_code = error_info.get("code") + + except (json.JSONDecodeError, ValueError, SyntaxError, AttributeError): + # Fall back to checking hasattr for OpenAI SDK exception objects + if hasattr(error, "response") and hasattr(error.response, "json"): + try: + response_data = error.response.json() + if "error" in response_data: + error_info = response_data["error"] + error_type = error_info.get("type") + error_code = error_info.get("code") + except Exception: + pass + + # Determine if 429 is retryable based on structured error codes + if error_type == "tokens": + # Token-related 429s are typically non-retryable (request too large) + logging.debug(f"Non-retryable 429: token-related error (type={error_type}, code={error_code})") + return False + elif error_code in 
["invalid_request_error", "context_length_exceeded"]: + # These are permanent failures + logging.debug(f"Non-retryable 429: permanent failure (type={error_type}, code={error_code})") + return False + else: + # Other 429s (like requests per minute) are retryable + logging.debug(f"Retryable 429: rate limiting (type={error_type}, code={error_code})") + return True + + # For non-429 errors, check if they're retryable + retryable_indicators = [ + "timeout", + "connection", + "network", + "temporary", + "unavailable", + "retry", + "408", # Request timeout + "500", # Internal server error + "502", # Bad gateway + "503", # Service unavailable + "504", # Gateway timeout + "ssl", # SSL errors + "handshake", # Handshake failures + ] + + return any(indicator in error_str for indicator in retryable_indicators) + + def _process_image(self, image_path: str) -> Optional[dict]: + """Process an image for OpenAI-compatible API.""" + try: + if image_path.startswith("data:"): + # Validate the data URL + validate_image(image_path) + # Handle data URL: data:image/png;base64,iVBORw0... 
+ return {"type": "image_url", "image_url": {"url": image_path}} + else: + # Use base class validation + image_bytes, mime_type = validate_image(image_path) + + # Read and encode the image + import base64 + + image_data = base64.b64encode(image_bytes).decode() + logging.debug(f"Processing image '{image_path}' as MIME type '{mime_type}'") + + # Create data URL for OpenAI API + data_url = f"data:{mime_type};base64,{image_data}" + + return {"type": "image_url", "image_url": {"url": data_url}} + + except ValueError as e: + logging.warning(str(e)) + return None + except Exception as e: + logging.error(f"Error processing image {image_path}: {e}") + return None diff --git a/providers/openai_provider.py b/providers/openai_provider.py new file mode 100644 index 0000000..a032756 --- /dev/null +++ b/providers/openai_provider.py @@ -0,0 +1,296 @@ +"""OpenAI model provider implementation.""" + +import logging +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from tools.models import ToolModelCategory + +from .openai_compatible import OpenAICompatibleProvider +from .shared import ModelCapabilities, ModelResponse, ProviderType, TemperatureConstraint + +logger = logging.getLogger(__name__) + + +class OpenAIModelProvider(OpenAICompatibleProvider): + """Implementation that talks to api.openai.com using rich model metadata. + + In addition to the built-in catalogue, the provider can surface models + defined in ``conf/custom_models.json`` (for organisations running their own + OpenAI-compatible gateways) while still respecting restriction policies. 
+ """ + + # Model configurations using ModelCapabilities objects + MODEL_CAPABILITIES = { + "gpt-5": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="gpt-5", + friendly_name="OpenAI (GPT-5)", + context_window=400_000, # 400K tokens + max_output_tokens=128_000, # 128K max output tokens + supports_extended_thinking=True, # Supports reasoning tokens + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # GPT-5 supports vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=True, # Regular models accept temperature parameter + temperature_constraint=TemperatureConstraint.create("fixed"), + description="GPT-5 (400K context, 128K output) - Advanced model with reasoning support", + aliases=["gpt5"], + ), + "gpt-5-mini": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="gpt-5-mini", + friendly_name="OpenAI (GPT-5-mini)", + context_window=400_000, # 400K tokens + max_output_tokens=128_000, # 128K max output tokens + supports_extended_thinking=True, # Supports reasoning tokens + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # GPT-5-mini supports vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=True, + temperature_constraint=TemperatureConstraint.create("fixed"), + description="GPT-5-mini (400K context, 128K output) - Efficient variant with reasoning support", + aliases=["gpt5-mini", "gpt5mini", "mini"], + ), + "gpt-5-nano": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="gpt-5-nano", + friendly_name="OpenAI (GPT-5 nano)", + context_window=400_000, + max_output_tokens=128_000, + supports_extended_thinking=True, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, + max_image_size_mb=20.0, + 
supports_temperature=True, + temperature_constraint=TemperatureConstraint.create("fixed"), + description="GPT-5 nano (400K context) - Fastest, cheapest version of GPT-5 for summarization and classification tasks", + aliases=["gpt5nano", "gpt5-nano", "nano"], + ), + "o3": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="o3", + friendly_name="OpenAI (O3)", + context_window=200_000, # 200K tokens + max_output_tokens=65536, # 64K max output tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # O3 models support vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=False, # O3 models don't accept temperature parameter + temperature_constraint=TemperatureConstraint.create("fixed"), + description="Strong reasoning (200K context) - Logical problems, code generation, systematic analysis", + aliases=[], + ), + "o3-mini": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="o3-mini", + friendly_name="OpenAI (O3-mini)", + context_window=200_000, # 200K tokens + max_output_tokens=65536, # 64K max output tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # O3 models support vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=False, # O3 models don't accept temperature parameter + temperature_constraint=TemperatureConstraint.create("fixed"), + description="Fast O3 variant (200K context) - Balanced performance/speed, moderate complexity", + aliases=["o3mini"], + ), + "o3-pro": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="o3-pro", + friendly_name="OpenAI (O3-Pro)", + context_window=200_000, # 200K tokens + max_output_tokens=65536, # 64K max output tokens + supports_extended_thinking=False, + 
supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # O3 models support vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=False, # O3 models don't accept temperature parameter + temperature_constraint=TemperatureConstraint.create("fixed"), + description="Professional-grade reasoning (200K context) - EXTREMELY EXPENSIVE: Only for the most complex problems requiring universe-scale complexity analysis OR when the user explicitly asks for this model. Use sparingly for critical architectural decisions or exceptionally complex debugging that other models cannot handle.", + aliases=["o3pro"], + ), + "o4-mini": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="o4-mini", + friendly_name="OpenAI (O4-mini)", + context_window=200_000, # 200K tokens + max_output_tokens=65536, # 64K max output tokens + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # O4 models support vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=False, # O4 models don't accept temperature parameter + temperature_constraint=TemperatureConstraint.create("fixed"), + description="Latest reasoning model (200K context) - Optimized for shorter contexts, rapid reasoning", + aliases=["o4mini"], + ), + "gpt-4.1": ModelCapabilities( + provider=ProviderType.OPENAI, + model_name="gpt-4.1", + friendly_name="OpenAI (GPT 4.1)", + context_window=1_000_000, # 1M tokens + max_output_tokens=32_768, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=True, + supports_json_mode=True, + supports_images=True, # GPT-4.1 supports vision + max_image_size_mb=20.0, # 20MB per OpenAI docs + supports_temperature=True, # Regular models accept temperature parameter + 
temperature_constraint=TemperatureConstraint.create("range"), + description="GPT-4.1 (1M context) - Advanced reasoning model with large context window", + aliases=["gpt4.1"], + ), + } + + def __init__(self, api_key: str, **kwargs): + """Initialize OpenAI provider with API key.""" + # Set default OpenAI base URL, allow override for regions/custom endpoints + kwargs.setdefault("base_url", "https://api.openai.com/v1") + super().__init__(api_key, **kwargs) + + # ------------------------------------------------------------------ + # Capability surface + # ------------------------------------------------------------------ + + def _lookup_capabilities( + self, + canonical_name: str, + requested_name: Optional[str] = None, + ) -> Optional[ModelCapabilities]: + """Look up OpenAI capabilities from built-ins or the custom registry.""" + + builtin = super()._lookup_capabilities(canonical_name, requested_name) + if builtin is not None: + return builtin + + try: + from .openrouter_registry import OpenRouterModelRegistry + + registry = OpenRouterModelRegistry() + config = registry.get_model_config(canonical_name) + + if config and config.provider == ProviderType.OPENAI: + return config + + except Exception as exc: # pragma: no cover - registry failures are non-critical + logger.debug(f"Could not resolve custom OpenAI model '{canonical_name}': {exc}") + + return None + + def _finalise_capabilities( + self, + capabilities: ModelCapabilities, + canonical_name: str, + requested_name: str, + ) -> ModelCapabilities: + """Ensure registry-sourced models report the correct provider type.""" + + if capabilities.provider != ProviderType.OPENAI: + capabilities.provider = ProviderType.OPENAI + return capabilities + + def _raise_unsupported_model(self, model_name: str) -> None: + raise ValueError(f"Unsupported OpenAI model: {model_name}") + + # ------------------------------------------------------------------ + # Provider identity + # 
------------------------------------------------------------------ + + def get_provider_type(self) -> ProviderType: + """Get the provider type.""" + return ProviderType.OPENAI + + # ------------------------------------------------------------------ + # Request execution + # ------------------------------------------------------------------ + + def generate_content( + self, + prompt: str, + model_name: str, + system_prompt: Optional[str] = None, + temperature: float = 0.3, + max_output_tokens: Optional[int] = None, + **kwargs, + ) -> ModelResponse: + """Generate content using OpenAI API with proper model name resolution.""" + # Resolve model alias before making API call + resolved_model_name = self._resolve_model_name(model_name) + + # Call parent implementation with resolved model name + return super().generate_content( + prompt=prompt, + model_name=resolved_model_name, + system_prompt=system_prompt, + temperature=temperature, + max_output_tokens=max_output_tokens, + **kwargs, + ) + + # ------------------------------------------------------------------ + # Provider preferences + # ------------------------------------------------------------------ + + def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]: + """Get OpenAI's preferred model for a given category from allowed models. 

        Args:
            category: The tool category requiring a model
            allowed_models: Pre-filtered list of models allowed by restrictions

        Returns:
            Preferred model name or None
        """
        from tools.models import ToolModelCategory

        if not allowed_models:
            return None

        # Helper to find first available from preference list
        def find_first(preferences: list[str]) -> Optional[str]:
            """Return first available model from preference list."""
            for model in preferences:
                if model in allowed_models:
                    return model
            return None

        if category == ToolModelCategory.EXTENDED_REASONING:
            # Prefer the dedicated o3 reasoning models first, then GPT-5.
            # NOTE(review): the capability table marks o3/o3-pro as
            # supports_extended_thinking=False while gpt-5 is True, yet gpt-5
            # is tried last here — confirm this ordering is intentional.
            preferred = find_first(["o3", "o3-pro", "gpt-5"])
            return preferred if preferred else allowed_models[0]

        elif category == ToolModelCategory.FAST_RESPONSE:
            # Prefer fast, cost-efficient models
            # NOTE(review): the flagship gpt-5 is tried before gpt-5-mini —
            # confirm this matches the "fast, cost-efficient" intent.
            preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
            return preferred if preferred else allowed_models[0]

        else:  # BALANCED or default
            # Prefer balanced performance/cost models (same order as FAST_RESPONSE)
            preferred = find_first(["gpt-5", "gpt-5-mini", "o4-mini", "o3-mini"])
            return preferred if preferred else allowed_models[0]
diff --git a/providers/openrouter.py b/providers/openrouter.py
new file mode 100644
index 0000000..ddb7745
--- /dev/null
+++ b/providers/openrouter.py
@@ -0,0 +1,251 @@
"""OpenRouter provider implementation."""

import logging
import os
from typing import Optional

from .openai_compatible import OpenAICompatibleProvider
from .openrouter_registry import OpenRouterModelRegistry
from .shared import (
    ModelCapabilities,
    ModelResponse,
    ProviderType,
    RangeTemperatureConstraint,
)


class OpenRouterProvider(OpenAICompatibleProvider):
    """Client for OpenRouter's multi-model aggregation service.

    Role
        Surface OpenRouter’s dynamic catalogue through the same interface as
        native providers so tools can reference OpenRouter models and aliases
        without special cases.
+ + Characteristics + * Pulls live model definitions from :class:`OpenRouterModelRegistry` + (aliases, provider-specific metadata, capability hints) + * Applies alias-aware restriction checks before exposing models to the + registry or tooling + * Reuses :class:`OpenAICompatibleProvider` infrastructure for request + execution so OpenRouter endpoints behave like standard OpenAI-style + APIs. + """ + + FRIENDLY_NAME = "OpenRouter" + + # Custom headers required by OpenRouter + DEFAULT_HEADERS = { + "HTTP-Referer": os.getenv("OPENROUTER_REFERER", "https://github.com/BeehiveInnovations/zen-mcp-server"), + "X-Title": os.getenv("OPENROUTER_TITLE", "Zen MCP Server"), + } + + # Model registry for managing configurations and aliases + _registry: Optional[OpenRouterModelRegistry] = None + + def __init__(self, api_key: str, **kwargs): + """Initialize OpenRouter provider. + + Args: + api_key: OpenRouter API key + **kwargs: Additional configuration + """ + base_url = "https://openrouter.ai/api/v1" + super().__init__(api_key, base_url=base_url, **kwargs) + + # Initialize model registry + if OpenRouterProvider._registry is None: + OpenRouterProvider._registry = OpenRouterModelRegistry() + # Log loaded models and aliases only on first load + models = self._registry.list_models() + aliases = self._registry.list_aliases() + logging.info(f"OpenRouter loaded {len(models)} models with {len(aliases)} aliases") + + # ------------------------------------------------------------------ + # Capability surface + # ------------------------------------------------------------------ + + def _lookup_capabilities( + self, + canonical_name: str, + requested_name: Optional[str] = None, + ) -> Optional[ModelCapabilities]: + """Fetch OpenRouter capabilities from the registry or build a generic fallback.""" + + capabilities = self._registry.get_capabilities(canonical_name) + if capabilities: + return capabilities + + base_identifier = canonical_name.split(":", 1)[0] + if "/" in base_identifier: + 
logging.debug( + "Using generic OpenRouter capabilities for %s (provider/model format detected)", canonical_name + ) + generic = ModelCapabilities( + provider=ProviderType.OPENROUTER, + model_name=canonical_name, + friendly_name=self.FRIENDLY_NAME, + context_window=32_768, + max_output_tokens=32_768, + supports_extended_thinking=False, + supports_system_prompts=True, + supports_streaming=True, + supports_function_calling=False, + temperature_constraint=RangeTemperatureConstraint(0.0, 2.0, 1.0), + ) + generic._is_generic = True + return generic + + logging.debug( + "Rejecting unknown OpenRouter model '%s' (no provider prefix); requires explicit configuration", + canonical_name, + ) + return None + + # ------------------------------------------------------------------ + # Provider identity + # ------------------------------------------------------------------ + + def get_provider_type(self) -> ProviderType: + """Identify this provider for restrictions and logging.""" + return ProviderType.OPENROUTER + + # ------------------------------------------------------------------ + # Request execution + # ------------------------------------------------------------------ + + def generate_content( + self, + prompt: str, + model_name: str, + system_prompt: Optional[str] = None, + temperature: float = 0.3, + max_output_tokens: Optional[int] = None, + **kwargs, + ) -> ModelResponse: + """Generate content using the OpenRouter API. 
+ + Args: + prompt: User prompt to send to the model + model_name: Name of the model (or alias) to use + system_prompt: Optional system prompt for model behavior + temperature: Sampling temperature + max_output_tokens: Maximum tokens to generate + **kwargs: Additional provider-specific parameters + + Returns: + ModelResponse with generated content and metadata + """ + # Resolve model alias to actual OpenRouter model name + resolved_model = self._resolve_model_name(model_name) + + # Always disable streaming for OpenRouter + # MCP doesn't use streaming, and this avoids issues with O3 model access + if "stream" not in kwargs: + kwargs["stream"] = False + + # Call parent method with resolved model name + return super().generate_content( + prompt=prompt, + model_name=resolved_model, + system_prompt=system_prompt, + temperature=temperature, + max_output_tokens=max_output_tokens, + **kwargs, + ) + + # ------------------------------------------------------------------ + # Registry helpers + # ------------------------------------------------------------------ + + def list_models( + self, + *, + respect_restrictions: bool = True, + include_aliases: bool = True, + lowercase: bool = False, + unique: bool = False, + ) -> list[str]: + """Return formatted OpenRouter model names, respecting alias-aware restrictions.""" + + if not self._registry: + return [] + + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() if respect_restrictions else None + allowed_configs: dict[str, ModelCapabilities] = {} + + for model_name in self._registry.list_models(): + config = self._registry.resolve(model_name) + if not config: + continue + + # Custom models belong to CustomProvider; skip them here so the two + # providers don't race over the same registrations (important for tests + # that stub the registry with minimal objects lacking attrs). 
+ if hasattr(config, "is_custom") and config.is_custom is True: + continue + + if restriction_service: + allowed = restriction_service.is_allowed(self.get_provider_type(), model_name) + + if not allowed and config.aliases: + for alias in config.aliases: + if restriction_service.is_allowed(self.get_provider_type(), alias): + allowed = True + break + + if not allowed: + continue + + allowed_configs[model_name] = config + + if not allowed_configs: + return [] + + # When restrictions are in place, don't include aliases to avoid confusion + # Only return the canonical model names that are actually allowed + actual_include_aliases = include_aliases and not respect_restrictions + + return ModelCapabilities.collect_model_names( + allowed_configs, + include_aliases=actual_include_aliases, + lowercase=lowercase, + unique=unique, + ) + + # ------------------------------------------------------------------ + # Registry helpers + # ------------------------------------------------------------------ + + def _resolve_model_name(self, model_name: str) -> str: + """Resolve aliases defined in the OpenRouter registry.""" + + config = self._registry.resolve(model_name) + if config: + if config.model_name != model_name: + logging.info(f"Resolved model alias '{model_name}' to '{config.model_name}'") + return config.model_name + + logging.debug(f"Model '{model_name}' not found in registry, using as-is") + return model_name + + def get_all_model_capabilities(self) -> dict[str, ModelCapabilities]: + """Expose registry-backed OpenRouter capabilities.""" + + if not self._registry: + return {} + + capabilities: dict[str, ModelCapabilities] = {} + for model_name in self._registry.list_models(): + config = self._registry.resolve(model_name) + if not config: + continue + + # See note in list_models: respect the CustomProvider boundary. 
"""OpenRouter model registry for managing model configurations and aliases."""

import importlib.resources
import json
import logging
import os
from pathlib import Path
from typing import Optional

from utils.file_utils import read_json_file

from .shared import (
    ModelCapabilities,
    ProviderType,
    TemperatureConstraint,
)


class OpenRouterModelRegistry:
    """In-memory view of OpenRouter and custom model metadata.

    Role
        Parse the packaged ``conf/custom_models.json`` (or user-specified
        overrides), construct alias and capability maps, and serve those
        structures to providers that rely on OpenRouter semantics (both the
        OpenRouter provider itself and the Custom provider).

    Key duties
        * Load :class:`ModelCapabilities` definitions from configuration files
        * Maintain a case-insensitive alias → canonical name map for fast
          resolution
        * Provide helpers to list models, list aliases, and resolve an arbitrary
          name to its capability object without repeatedly touching the file
          system.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the registry.

        Args:
            config_path: Path to config file. If None, uses default locations
                (``CUSTOM_MODELS_CONFIG_PATH`` env var, then packaged
                resources, then well-known file-system paths).
        """
        self.alias_map: dict[str, str] = {}  # alias (lowercased) -> canonical model_name
        self.model_map: dict[str, ModelCapabilities] = {}  # model_name -> config

        # Determine config path and loading strategy.
        self.use_resources = False
        if config_path:
            # Direct config_path parameter wins.
            self.config_path = Path(config_path)
        else:
            # Environment variable takes precedence over packaged defaults.
            env_path = os.getenv("CUSTOM_MODELS_CONFIG_PATH")
            if env_path:
                self.config_path = Path(env_path)
            else:
                # Try importlib.resources for robust packaging support.
                self.config_path = None
                self.use_resources = False

                try:
                    resource_traversable = importlib.resources.files("conf").joinpath("custom_models.json")
                    if hasattr(resource_traversable, "read_text"):
                        self.use_resources = True
                    else:
                        raise AttributeError("read_text not available")
                except Exception:
                    # Resources unavailable (e.g. not packaged); fall through
                    # to file-system discovery below.
                    pass

                if not self.use_resources:
                    # Fallback to file system paths relative to the package,
                    # then the current working directory.
                    potential_paths = [
                        Path(__file__).parent.parent / "conf" / "custom_models.json",
                        Path.cwd() / "conf" / "custom_models.json",
                    ]

                    for path in potential_paths:
                        if path.exists():
                            self.config_path = path
                            break

                    if self.config_path is None:
                        # Keep a deterministic path for error messages even if
                        # nothing exists yet.
                        self.config_path = potential_paths[0]

        # Load configuration eagerly so the registry is usable immediately.
        self.reload()

    def reload(self) -> None:
        """Reload configuration from disk (or package resources).

        On failure the maps are cleared; only duplicate-alias errors are
        re-raised because they indicate a broken configuration file.
        """
        try:
            configs = self._read_config()
            self._build_maps(configs)

            # Best-effort caller diagnostics for the debug log only; any
            # failure here is swallowed deliberately.
            caller_info = ""
            try:
                import inspect

                caller_frame = inspect.currentframe().f_back
                if caller_frame:
                    caller_name = caller_frame.f_code.co_name
                    caller_file = (
                        caller_frame.f_code.co_filename.split("/")[-1] if caller_frame.f_code.co_filename else "unknown"
                    )
                    # Walk up the stack looking for a tool context (an object
                    # exposing get_name()).
                    while caller_frame:
                        frame_locals = caller_frame.f_locals
                        if "self" in frame_locals and hasattr(frame_locals["self"], "get_name"):
                            tool_name = frame_locals["self"].get_name()
                            caller_info = f" (called from {tool_name} tool)"
                            break
                        caller_frame = caller_frame.f_back
                    if not caller_info:
                        caller_info = f" (called from {caller_name} in {caller_file})"
            except Exception:
                # If frame inspection fails, just continue without caller info.
                pass

            logging.debug(
                f"Loaded {len(self.model_map)} OpenRouter models with {len(self.alias_map)} aliases{caller_info}"
            )
        except ValueError as e:
            # Re-raise ValueError only for duplicate aliases (critical config errors).
            logging.error(f"Failed to load OpenRouter model configuration: {e}")
            # Initialize with empty maps on failure.
            self.alias_map = {}
            self.model_map = {}
            if "Duplicate alias" in str(e):
                raise
        except Exception as e:
            logging.error(f"Failed to load OpenRouter model configuration: {e}")
            # Initialize with empty maps on failure.
            self.alias_map = {}
            self.model_map = {}

    def _read_config(self) -> list[ModelCapabilities]:
        """Read configuration from file or package resources.

        Returns:
            List of model configurations.

        Raises:
            ValueError: When the config cannot be read/parsed or contains
                invalid entries.
        """
        try:
            if self.use_resources:
                # Use importlib.resources for packaged environments.
                try:
                    resource_path = importlib.resources.files("conf").joinpath("custom_models.json")
                    if hasattr(resource_path, "read_text"):
                        # Python 3.9+
                        config_text = resource_path.read_text(encoding="utf-8")
                    else:
                        # Python 3.8 fallback
                        with resource_path.open("r", encoding="utf-8") as f:
                            config_text = f.read()

                    data = json.loads(config_text)
                    logging.debug("Loaded OpenRouter config from package resources")
                except Exception as e:
                    logging.warning(f"Failed to load config from resources: {e}")
                    return []
            else:
                # Use file path loading.
                if not self.config_path.exists():
                    logging.warning(f"OpenRouter model config not found at {self.config_path}")
                    return []

                # Use centralized JSON reading utility.
                data = read_json_file(str(self.config_path))
                logging.debug(f"Loaded OpenRouter config from file: {self.config_path}")

            if data is None:
                location = "resources" if self.use_resources else str(self.config_path)
                raise ValueError(f"Could not read or parse JSON from {location}")

            # Parse models into ModelCapabilities objects.
            configs = []
            for model_data in data.get("models", []):
                # Convert the optional string hint into a constraint object;
                # missing/None hints default to the "range" constraint.
                temp_constraint_str = model_data.get("temperature_constraint")
                temp_constraint = TemperatureConstraint.create(temp_constraint_str or "range")

                # Set provider-specific defaults based on the is_custom flag.
                is_custom = model_data.get("is_custom", False)
                if is_custom:
                    model_data.setdefault("provider", ProviderType.CUSTOM)
                    model_data.setdefault("friendly_name", f"Custom ({model_data.get('model_name', 'Unknown')})")
                else:
                    model_data.setdefault("provider", ProviderType.OPENROUTER)
                    model_data.setdefault("friendly_name", f"OpenRouter ({model_data.get('model_name', 'Unknown')})")

                # Replace the string hint with the constraint object. (The
                # previous re-check for a str value here was dead code — this
                # assignment already guarantees a constraint object.)
                model_data["temperature_constraint"] = temp_constraint

                config = ModelCapabilities(**model_data)
                configs.append(config)

            return configs
        except ValueError:
            # Re-raise ValueError for specific config errors.
            raise
        except Exception as e:
            location = "resources" if self.use_resources else str(self.config_path)
            raise ValueError(f"Error reading config from {location}: {e}") from e

    def _build_maps(self, configs: list[ModelCapabilities]) -> None:
        """Build alias and model maps from configurations.

        Args:
            configs: List of model configurations.

        Raises:
            ValueError: On duplicate aliases or case-insensitively duplicate
                model names across different models.
        """
        alias_map = {}
        model_map = {}

        for config in configs:
            # Add to model map.
            model_map[config.model_name] = config

            # Add the model_name itself as an alias for case-insensitive lookup,
            # but only if it's not already in the aliases list.
            model_name_lower = config.model_name.lower()
            aliases_lower = [alias.lower() for alias in config.aliases]

            if model_name_lower not in aliases_lower:
                if model_name_lower in alias_map:
                    existing_model = alias_map[model_name_lower]
                    if existing_model != config.model_name:
                        raise ValueError(
                            f"Duplicate model name '{config.model_name}' (case-insensitive) found for models "
                            f"'{existing_model}' and '{config.model_name}'"
                        )
                else:
                    alias_map[model_name_lower] = config.model_name

            # Add aliases.
            for alias in config.aliases:
                alias_lower = alias.lower()
                if alias_lower in alias_map:
                    existing_model = alias_map[alias_lower]
                    raise ValueError(
                        f"Duplicate alias '{alias}' found for models '{existing_model}' and '{config.model_name}'"
                    )
                alias_map[alias_lower] = config.model_name

        # Atomic update: swap both maps only after full validation succeeded.
        self.alias_map = alias_map
        self.model_map = model_map

    def resolve(self, name_or_alias: str) -> Optional[ModelCapabilities]:
        """Resolve a model name or alias to configuration.

        Args:
            name_or_alias: Model name or alias to resolve.

        Returns:
            Model configuration if found, None otherwise.
        """
        # Try alias lookup (case-insensitive) - this includes model names too.
        alias_lower = name_or_alias.lower()
        if alias_lower in self.alias_map:
            model_name = self.alias_map[alias_lower]
            return self.model_map.get(model_name)

        return None

    def get_capabilities(self, name_or_alias: str) -> Optional[ModelCapabilities]:
        """Get model capabilities for a name or alias.

        Args:
            name_or_alias: Model name or alias.

        Returns:
            ModelCapabilities if found, None otherwise.
        """
        # Registry returns ModelCapabilities directly.
        return self.resolve(name_or_alias)

    def get_model_config(self, name_or_alias: str) -> Optional[ModelCapabilities]:
        """Backward-compatible wrapper used by providers and older tests."""
        return self.resolve(name_or_alias)

    def list_models(self) -> list[str]:
        """List all available canonical model names."""
        return list(self.model_map.keys())

    def list_aliases(self) -> list[str]:
        """List all available aliases (lowercased, includes model names)."""
        return list(self.alias_map.keys())
+ """ + + _instance = None + + # Provider priority order for model selection + # Native APIs first, then custom endpoints, then catch-all providers + PROVIDER_PRIORITY_ORDER = [ + ProviderType.GOOGLE, # Direct Gemini access + ProviderType.OPENAI, # Direct OpenAI access + ProviderType.XAI, # Direct X.AI GROK access + ProviderType.DIAL, # DIAL unified API access + ProviderType.CUSTOM, # Local/self-hosted models + ProviderType.OPENROUTER, # Catch-all for cloud models + ] + + def __new__(cls): + """Singleton pattern for registry.""" + if cls._instance is None: + logging.debug("REGISTRY: Creating new registry instance") + cls._instance = super().__new__(cls) + # Initialize instance dictionaries on first creation + cls._instance._providers = {} + cls._instance._initialized_providers = {} + logging.debug(f"REGISTRY: Created instance {cls._instance}") + return cls._instance + + @classmethod + def register_provider(cls, provider_type: ProviderType, provider_class: type[ModelProvider]) -> None: + """Register a new provider class. + + Args: + provider_type: Type of the provider (e.g., ProviderType.GOOGLE) + provider_class: Class that implements ModelProvider interface + """ + instance = cls() + instance._providers[provider_type] = provider_class + # Invalidate any cached instance so subsequent lookups use the new registration + instance._initialized_providers.pop(provider_type, None) + + @classmethod + def get_provider(cls, provider_type: ProviderType, force_new: bool = False) -> Optional[ModelProvider]: + """Get an initialized provider instance. 
+ + Args: + provider_type: Type of provider to get + force_new: Force creation of new instance instead of using cached + + Returns: + Initialized ModelProvider instance or None if not available + """ + instance = cls() + + # Return cached instance if available and not forcing new + if not force_new and provider_type in instance._initialized_providers: + return instance._initialized_providers[provider_type] + + # Check if provider class is registered + if provider_type not in instance._providers: + return None + + # Get API key from environment + api_key = cls._get_api_key_for_provider(provider_type) + + # Get provider class or factory function + provider_class = instance._providers[provider_type] + + # For custom providers, handle special initialization requirements + if provider_type == ProviderType.CUSTOM: + # Check if it's a factory function (callable but not a class) + if callable(provider_class) and not isinstance(provider_class, type): + # Factory function - call it with api_key parameter + provider = provider_class(api_key=api_key) + else: + # Regular class - need to handle URL requirement + custom_url = os.getenv("CUSTOM_API_URL", "") + if not custom_url: + if api_key: # Key is set but URL is missing + logging.warning("CUSTOM_API_KEY set but CUSTOM_API_URL missing – skipping Custom provider") + return None + # Use empty string as API key for custom providers that don't need auth (e.g., Ollama) + # This allows the provider to be created even without CUSTOM_API_KEY being set + api_key = api_key or "" + # Initialize custom provider with both API key and base URL + provider = provider_class(api_key=api_key, base_url=custom_url) + elif provider_type == ProviderType.GOOGLE: + # For Gemini, check if custom base URL is configured + if not api_key: + return None + gemini_base_url = os.getenv("GEMINI_BASE_URL") + provider_kwargs = {"api_key": api_key} + if gemini_base_url: + provider_kwargs["base_url"] = gemini_base_url + logging.info(f"Initialized Gemini provider 
with custom endpoint: {gemini_base_url}") + provider = provider_class(**provider_kwargs) + else: + if not api_key: + return None + # Initialize non-custom provider with just API key + provider = provider_class(api_key=api_key) + + # Cache the instance + instance._initialized_providers[provider_type] = provider + + return provider + + @classmethod + def get_provider_for_model(cls, model_name: str) -> Optional[ModelProvider]: + """Get provider instance for a specific model name. + + Provider priority order: + 1. Native APIs (GOOGLE, OPENAI) - Most direct and efficient + 2. CUSTOM - For local/private models with specific endpoints + 3. OPENROUTER - Catch-all for cloud models via unified API + + Args: + model_name: Name of the model (e.g., "gemini-2.5-flash", "gpt5") + + Returns: + ModelProvider instance that supports this model + """ + logging.debug(f"get_provider_for_model called with model_name='{model_name}'") + + # Check providers in priority order + instance = cls() + logging.debug(f"Registry instance: {instance}") + logging.debug(f"Available providers in registry: {list(instance._providers.keys())}") + + for provider_type in cls.PROVIDER_PRIORITY_ORDER: + if provider_type in instance._providers: + logging.debug(f"Found {provider_type} in registry") + # Get or create provider instance + provider = cls.get_provider(provider_type) + if provider and provider.validate_model_name(model_name): + logging.debug(f"{provider_type} validates model {model_name}") + return provider + else: + logging.debug(f"{provider_type} does not validate model {model_name}") + else: + logging.debug(f"{provider_type} not found in registry") + + logging.debug(f"No provider found for model {model_name}") + return None + + @classmethod + def get_available_providers(cls) -> list[ProviderType]: + """Get list of registered provider types.""" + instance = cls() + return list(instance._providers.keys()) + + @classmethod + def get_available_models(cls, respect_restrictions: bool = True) -> dict[str, 
ProviderType]: + """Get mapping of all available models to their providers. + + Args: + respect_restrictions: If True, filter out models not allowed by restrictions + + Returns: + Dict mapping model names to provider types + """ + # Import here to avoid circular imports + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() if respect_restrictions else None + models: dict[str, ProviderType] = {} + instance = cls() + + for provider_type in instance._providers: + provider = cls.get_provider(provider_type) + if not provider: + continue + + try: + available = provider.list_models(respect_restrictions=respect_restrictions) + except NotImplementedError: + logging.warning("Provider %s does not implement list_models", provider_type) + continue + + for model_name in available: + # ===================================================================================== + # CRITICAL: Prevent double restriction filtering (Fixed Issue #98) + # ===================================================================================== + # Previously, both the provider AND registry applied restrictions, causing + # double-filtering that resulted in "no models available" errors. + # + # Logic: If respect_restrictions=True, provider already filtered models, + # so registry should NOT filter them again. 
+ # TEST COVERAGE: tests/test_provider_routing_bugs.py::TestOpenRouterAliasRestrictions + # ===================================================================================== + if ( + restriction_service + and not respect_restrictions # Only filter if provider didn't already filter + and not restriction_service.is_allowed(provider_type, model_name) + ): + logging.debug("Model %s filtered by restrictions", model_name) + continue + models[model_name] = provider_type + + return models + + @classmethod + def get_available_model_names(cls, provider_type: Optional[ProviderType] = None) -> list[str]: + """Get list of available model names, optionally filtered by provider. + + This respects model restrictions automatically. + + Args: + provider_type: Optional provider to filter by + + Returns: + List of available model names + """ + available_models = cls.get_available_models(respect_restrictions=True) + + if provider_type: + # Filter by specific provider + return [name for name, ptype in available_models.items() if ptype == provider_type] + else: + # Return all available models + return list(available_models.keys()) + + @classmethod + def _get_api_key_for_provider(cls, provider_type: ProviderType) -> Optional[str]: + """Get API key for a provider from environment variables. 
+ + Args: + provider_type: Provider type to get API key for + + Returns: + API key string or None if not found + """ + key_mapping = { + ProviderType.GOOGLE: "GEMINI_API_KEY", + ProviderType.OPENAI: "OPENAI_API_KEY", + ProviderType.XAI: "XAI_API_KEY", + ProviderType.OPENROUTER: "OPENROUTER_API_KEY", + ProviderType.CUSTOM: "CUSTOM_API_KEY", # Can be empty for providers that don't need auth + ProviderType.DIAL: "DIAL_API_KEY", + } + + env_var = key_mapping.get(provider_type) + if not env_var: + return None + + return os.getenv(env_var) + + @classmethod + def _get_allowed_models_for_provider(cls, provider: ModelProvider, provider_type: ProviderType) -> list[str]: + """Get a list of allowed canonical model names for a given provider. + + Args: + provider: The provider instance to get models for + provider_type: The provider type for restriction checking + + Returns: + List of model names that are both supported and allowed + """ + from utils.model_restrictions import get_restriction_service + + restriction_service = get_restriction_service() + + allowed_models = [] + + # Get the provider's supported models + try: + # Use list_models to get all supported models (handles both regular and custom providers) + supported_models = provider.list_models(respect_restrictions=False) + except (NotImplementedError, AttributeError): + # Fallback to provider-declared capability maps if list_models not implemented + model_map = getattr(provider, "MODEL_CAPABILITIES", None) + supported_models = list(model_map.keys()) if isinstance(model_map, dict) else [] + + # Filter by restrictions + for model_name in supported_models: + if restriction_service.is_allowed(provider_type, model_name): + allowed_models.append(model_name) + + return allowed_models + + @classmethod + def get_preferred_fallback_model(cls, tool_category: Optional["ToolModelCategory"] = None) -> str: + """Get the preferred fallback model based on provider priority and tool category. 
+ + This method orchestrates model selection by: + 1. Getting allowed models for each provider (respecting restrictions) + 2. Asking providers for their preference from the allowed list + 3. Falling back to first available model if no preference given + + Args: + tool_category: Optional category to influence model selection + + Returns: + Model name string for fallback use + """ + from tools.models import ToolModelCategory + + effective_category = tool_category or ToolModelCategory.BALANCED + first_available_model = None + + # Ask each provider for their preference in priority order + for provider_type in cls.PROVIDER_PRIORITY_ORDER: + provider = cls.get_provider(provider_type) + if provider: + # 1. Registry filters the models first + allowed_models = cls._get_allowed_models_for_provider(provider, provider_type) + + if not allowed_models: + continue + + # 2. Keep track of the first available model as fallback + if not first_available_model: + first_available_model = sorted(allowed_models)[0] + + # 3. Ask provider to pick from allowed list + preferred_model = provider.get_preferred_model(effective_category, allowed_models) + + if preferred_model: + logging.debug( + f"Provider {provider_type.value} selected '{preferred_model}' for category '{effective_category.value}'" + ) + return preferred_model + + # If no provider returned a preference, use first available model + if first_available_model: + logging.debug(f"No provider preference, using first available: {first_available_model}") + return first_available_model + + # Ultimate fallback if no providers have models + logging.warning("No models available from any provider, using default fallback") + return "gemini-2.5-flash" + + @classmethod + def get_available_providers_with_keys(cls) -> list[ProviderType]: + """Get list of provider types that have valid API keys. 
+ + Returns: + List of ProviderType values for providers with valid API keys + """ + available = [] + instance = cls() + for provider_type in instance._providers: + if cls.get_provider(provider_type) is not None: + available.append(provider_type) + return available + + @classmethod + def clear_cache(cls) -> None: + """Clear cached provider instances.""" + instance = cls() + instance._initialized_providers.clear() + + @classmethod + def reset_for_testing(cls) -> None: + """Reset the registry to a clean state for testing. + + This provides a safe, public API for tests to clean up registry state + without directly manipulating private attributes. + """ + cls._instance = None + if hasattr(cls, "_providers"): + cls._providers = {} + + @classmethod + def unregister_provider(cls, provider_type: ProviderType) -> None: + """Unregister a provider (mainly for testing).""" + instance = cls() + instance._providers.pop(provider_type, None) + instance._initialized_providers.pop(provider_type, None) diff --git a/providers/shared/__init__.py b/providers/shared/__init__.py new file mode 100644 index 0000000..aa7c613 --- /dev/null +++ b/providers/shared/__init__.py @@ -0,0 +1,21 @@ +"""Shared data structures and helpers for model providers.""" + +from .model_capabilities import ModelCapabilities +from .model_response import ModelResponse +from .provider_type import ProviderType +from .temperature import ( + DiscreteTemperatureConstraint, + FixedTemperatureConstraint, + RangeTemperatureConstraint, + TemperatureConstraint, +) + +__all__ = [ + "ModelCapabilities", + "ModelResponse", + "ProviderType", + "TemperatureConstraint", + "FixedTemperatureConstraint", + "RangeTemperatureConstraint", + "DiscreteTemperatureConstraint", +] diff --git a/providers/shared/model_capabilities.py b/providers/shared/model_capabilities.py new file mode 100644 index 0000000..02c2d1c --- /dev/null +++ b/providers/shared/model_capabilities.py @@ -0,0 +1,122 @@ +"""Dataclass describing the feature set of a 
@dataclass
class ModelCapabilities:
    """Static description of what a model can do within a provider.

    Role
        Acts as the canonical record for everything the server needs to know
        about a model—its provider, token limits, feature switches, aliases,
        and temperature rules. Providers populate these objects so tools and
        higher-level services can rely on a consistent schema.

    Typical usage
        * Provider subclasses declare `MODEL_CAPABILITIES` maps containing these
          objects (for example ``OpenAIModelProvider``)
        * Helper utilities (e.g. restriction validation, alias expansion) read
          these objects to build model lists for tooling and policy enforcement
        * Tool selection logic inspects attributes such as
          ``supports_extended_thinking`` or ``context_window`` to choose an
          appropriate model for a task.
    """

    provider: ProviderType
    model_name: str
    friendly_name: str
    description: str = ""
    aliases: list[str] = field(default_factory=list)

    # Capacity limits / resource budgets
    context_window: int = 0
    max_output_tokens: int = 0
    max_thinking_tokens: int = 0

    # Capability flags
    supports_extended_thinking: bool = False
    supports_system_prompts: bool = True
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_images: bool = False
    supports_json_mode: bool = False
    supports_temperature: bool = True

    # Additional attributes
    max_image_size_mb: float = 0.0
    is_custom: bool = False
    temperature_constraint: TemperatureConstraint = field(
        default_factory=lambda: RangeTemperatureConstraint(0.0, 2.0, 0.3)
    )

    def get_effective_temperature(self, requested_temperature: float) -> Optional[float]:
        """Return the temperature that should be sent to the provider.

        ``None`` means the model rejects the parameter entirely and callers
        must omit it; otherwise the configured constraint clamps the request
        into a provider-safe value.
        """
        if not self.supports_temperature:
            return None
        return self.temperature_constraint.get_corrected_value(requested_temperature)

    @staticmethod
    def collect_aliases(model_configs: dict[str, "ModelCapabilities"]) -> dict[str, list[str]]:
        """Map each canonical model name to its alias list, skipping models with none."""
        alias_index: dict[str, list[str]] = {}
        for canonical, caps in model_configs.items():
            if caps.aliases:
                alias_index[canonical] = caps.aliases
        return alias_index

    @staticmethod
    def collect_model_names(
        model_configs: dict[str, "ModelCapabilities"],
        *,
        include_aliases: bool = True,
        lowercase: bool = False,
        unique: bool = False,
    ) -> list[str]:
        """Build an ordered list of model names and aliases.

        Args:
            model_configs: Mapping of canonical model names to capabilities.
            include_aliases: When True, include aliases for each model.
            lowercase: When True, normalize names to lowercase.
            unique: When True, ensure each returned name appears once (after formatting).

        Returns:
            Ordered list of model names (and optionally aliases) formatted per options.
        """
        names: list[str] = []
        emitted: set[str] | None = set() if unique else None

        def _emit(raw: str) -> None:
            candidate = raw.lower() if lowercase else raw
            if emitted is not None:
                if candidate in emitted:
                    return
                emitted.add(candidate)
            names.append(candidate)

        for canonical, caps in model_configs.items():
            _emit(canonical)
            if include_aliases:
                for alias in caps.aliases:
                    _emit(alias)

        return names


@dataclass
class ModelResponse:
    """Portable representation of a provider completion."""

    content: str  # generated text returned by the model
    usage: dict[str, int] = field(default_factory=dict)  # token accounting as reported by the SDK
    model_name: str = ""
    friendly_name: str = ""
    provider: ProviderType = ProviderType.GOOGLE
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def total_tokens(self) -> int:
        """Total token count, or 0 when the provider reported no usage data."""
        return self.usage.get("total_tokens", 0)
"""Helper types for validating model temperature parameters."""

from abc import ABC, abstractmethod
from typing import Optional

__all__ = [
    "TemperatureConstraint",
    "FixedTemperatureConstraint",
    "RangeTemperatureConstraint",
    "DiscreteTemperatureConstraint",
]

# Common heuristics for determining temperature support when explicit
# capabilities are unavailable (e.g., custom/local models).
_TEMP_UNSUPPORTED_PATTERNS = {
    "o1",
    "o3",
    "o4",  # OpenAI O-series reasoning models
    "deepseek-reasoner",
    "deepseek-r1",
    "r1",  # DeepSeek reasoner variants
}

_TEMP_UNSUPPORTED_KEYWORDS = {
    "reasoner",  # Catch additional DeepSeek-style naming patterns
}


class TemperatureConstraint(ABC):
    """Contract for temperature validation used by `ModelCapabilities`.

    Concrete providers describe their temperature behaviour by creating
    subclasses that expose three operations:
    * `validate` – decide whether a requested temperature is acceptable.
    * `get_corrected_value` – coerce out-of-range values into a safe default.
    * `get_description` – provide a human readable error message for users.

    Providers call these hooks before sending traffic to the underlying API so
    that unsupported temperatures never reach the remote service.
    """

    @abstractmethod
    def validate(self, temperature: float) -> bool:
        """Return ``True`` when the temperature may be sent to the backend."""

    @abstractmethod
    def get_corrected_value(self, temperature: float) -> float:
        """Return a valid substitute for an out-of-range temperature."""

    @abstractmethod
    def get_description(self) -> str:
        """Describe the acceptable range to include in error messages."""

    @abstractmethod
    def get_default(self) -> float:
        """Return the default temperature for the model."""

    @staticmethod
    def infer_support(model_name: str) -> tuple[bool, str]:
        """Heuristically determine whether a model supports temperature.

        Args:
            model_name: Canonical model identifier or alias.

        Returns:
            Tuple ``(supports_temperature, reason)`` where ``reason`` records
            which heuristic produced the verdict.
        """

        model_lower = model_name.lower()

        for pattern in _TEMP_UNSUPPORTED_PATTERNS:
            # Match the pattern only as a whole name segment (standalone,
            # prefix, suffix, or provider-qualified) rather than as an
            # arbitrary substring, so e.g. "o1" does not match "foo1".
            conditions = (
                pattern == model_lower,
                model_lower.startswith(f"{pattern}-"),
                model_lower.startswith(f"openai/{pattern}"),
                model_lower.startswith(f"deepseek/{pattern}"),
                model_lower.endswith(f"-{pattern}"),
                f"/{pattern}" in model_lower,
                f"-{pattern}-" in model_lower,
            )
            if any(conditions):
                return False, f"detected pattern '{pattern}'"

        for keyword in _TEMP_UNSUPPORTED_KEYWORDS:
            if keyword in model_lower:
                return False, f"detected keyword '{keyword}'"

        return True, "default assumption for models without explicit metadata"

    @staticmethod
    def resolve_settings(
        model_name: str,
        constraint_hint: Optional[str] = None,
    ) -> tuple[bool, "TemperatureConstraint", str]:
        """Derive temperature support and constraint for a model.

        Args:
            model_name: Canonical model identifier or alias.
            constraint_hint: Optional configuration hint (``"fixed"``,
                ``"range"``, ``"discrete"``). When provided, the hint is
                honoured directly.

        Returns:
            Tuple ``(supports_temperature, constraint, diagnosis)`` describing
            whether temperature may be tuned, the constraint object that should
            be attached to :class:`ModelCapabilities`, and the reasoning behind
            the decision.
        """

        if constraint_hint:
            constraint = TemperatureConstraint.create(constraint_hint)
            # Only "fixed" pins a single value; "range"/"discrete" still allow tuning.
            supports_temperature = constraint_hint != "fixed"
            return supports_temperature, constraint, f"constraint hint '{constraint_hint}'"

        supports_temperature, reason = TemperatureConstraint.infer_support(model_name)
        if supports_temperature:
            constraint: TemperatureConstraint = RangeTemperatureConstraint(0.0, 2.0, 0.7)
        else:
            constraint = FixedTemperatureConstraint(1.0)

        return supports_temperature, constraint, reason

    @staticmethod
    def create(constraint_type: str) -> "TemperatureConstraint":
        """Factory that yields the appropriate constraint for a configuration hint."""

        if constraint_type == "fixed":
            # Fixed temperature models (O3/O4) only support temperature=1.0
            return FixedTemperatureConstraint(1.0)
        if constraint_type == "discrete":
            # For models with specific allowed values - using common OpenAI values as default
            return DiscreteTemperatureConstraint([0.0, 0.3, 0.7, 1.0, 1.5, 2.0], 0.3)
        # Default range constraint (for "range" or None)
        return RangeTemperatureConstraint(0.0, 2.0, 0.3)


class FixedTemperatureConstraint(TemperatureConstraint):
    """Constraint for models that enforce an exact temperature (for example O3)."""

    def __init__(self, value: float):
        self.value = value

    def validate(self, temperature: float) -> bool:
        return abs(temperature - self.value) < 1e-6  # Handle floating point precision

    def get_corrected_value(self, temperature: float) -> float:
        return self.value

    def get_description(self) -> str:
        return f"Only supports temperature={self.value}"

    def get_default(self) -> float:
        return self.value


class RangeTemperatureConstraint(TemperatureConstraint):
    """Constraint for providers that expose a continuous min/max temperature range."""

    def __init__(self, min_temp: float, max_temp: float, default: Optional[float] = None):
        self.min_temp = min_temp
        self.max_temp = max_temp
        # Compare against None rather than using `default or ...`: an explicit
        # default of 0.0 is falsy and would otherwise be silently replaced by
        # the midpoint.
        self.default_temp = default if default is not None else (min_temp + max_temp) / 2

    def validate(self, temperature: float) -> bool:
        return self.min_temp <= temperature <= self.max_temp

    def get_corrected_value(self, temperature: float) -> float:
        return max(self.min_temp, min(self.max_temp, temperature))

    def get_description(self) -> str:
        return f"Supports temperature range [{self.min_temp}, {self.max_temp}]"

    def get_default(self) -> float:
        return self.default_temp


class DiscreteTemperatureConstraint(TemperatureConstraint):
    """Constraint for models that permit a discrete list of temperature values."""

    def __init__(self, allowed_values: list[float], default: Optional[float] = None):
        self.allowed_values = sorted(allowed_values)
        if default is not None:
            self.default_temp = default
        else:
            # Fall back to the median of the *sorted* values; indexing the
            # caller's unsorted list would make the default depend on input
            # order. (Also avoids the `default or ...` falsy-zero bug.)
            self.default_temp = self.allowed_values[len(self.allowed_values) // 2]

    def validate(self, temperature: float) -> bool:
        return any(abs(temperature - val) < 1e-6 for val in self.allowed_values)

    def get_corrected_value(self, temperature: float) -> float:
        return min(self.allowed_values, key=lambda x: abs(x - temperature))

    def get_description(self) -> str:
        return f"Supports temperatures: {self.allowed_values}"

    def get_default(self) -> float:
        return self.default_temp
class XAIModelProvider(OpenAICompatibleProvider):
    """Integration for X.AI's GROK models exposed over an OpenAI-style API.

    Publishes capability metadata for the officially supported deployments and
    maps tool-category preferences to the appropriate GROK model.
    """

    FRIENDLY_NAME = "X.AI"

    # Model configurations using ModelCapabilities objects.
    # NOTE: `aliases` must not repeat the canonical key — alias-expanded name
    # listings append the base name and then every alias, so a self-alias
    # produces duplicate entries.
    MODEL_CAPABILITIES = {
        "grok-4": ModelCapabilities(
            provider=ProviderType.XAI,
            model_name="grok-4",
            friendly_name="X.AI (Grok 4)",
            context_window=256_000,  # 256K tokens
            max_output_tokens=256_000,  # 256K tokens max output
            supports_extended_thinking=True,  # Grok-4 supports reasoning mode
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,  # Function calling supported
            supports_json_mode=True,  # Structured outputs supported
            supports_images=True,  # Multimodal capabilities
            max_image_size_mb=20.0,  # Standard image size limit
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            description="GROK-4 (256K context) - Frontier multimodal reasoning model with advanced capabilities",
            aliases=["grok", "grok4"],
        ),
        "grok-3": ModelCapabilities(
            provider=ProviderType.XAI,
            model_name="grok-3",
            friendly_name="X.AI (Grok 3)",
            context_window=131_072,  # 131K tokens
            max_output_tokens=131_072,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=False,  # Assuming GROK doesn't have JSON mode yet
            supports_images=False,  # Assuming GROK is text-only for now
            max_image_size_mb=0.0,
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            description="GROK-3 (131K context) - Advanced reasoning model from X.AI, excellent for complex analysis",
            aliases=["grok3"],
        ),
        "grok-3-fast": ModelCapabilities(
            provider=ProviderType.XAI,
            model_name="grok-3-fast",
            friendly_name="X.AI (Grok 3 Fast)",
            context_window=131_072,  # 131K tokens
            max_output_tokens=131_072,
            supports_extended_thinking=False,
            supports_system_prompts=True,
            supports_streaming=True,
            supports_function_calling=True,
            supports_json_mode=False,  # Assuming GROK doesn't have JSON mode yet
            supports_images=False,  # Assuming GROK is text-only for now
            max_image_size_mb=0.0,
            supports_temperature=True,
            temperature_constraint=TemperatureConstraint.create("range"),
            description="GROK-3 Fast (131K context) - Higher performance variant, faster processing but more expensive",
            aliases=["grok3fast", "grokfast", "grok3-fast"],
        ),
    }

    def __init__(self, api_key: str, **kwargs):
        """Initialize X.AI provider with API key."""
        # Set X.AI base URL unless the caller supplied an override.
        kwargs.setdefault("base_url", "https://api.x.ai/v1")
        super().__init__(api_key, **kwargs)

    def get_provider_type(self) -> ProviderType:
        """Get the provider type."""
        return ProviderType.XAI

    def generate_content(
        self,
        prompt: str,
        model_name: str,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
        max_output_tokens: Optional[int] = None,
        **kwargs,
    ) -> ModelResponse:
        """Generate content using X.AI API with proper model name resolution.

        Aliases (e.g. "grok4") are resolved to canonical names before the
        request is sent so the remote API always receives a valid model id.
        """
        resolved_model_name = self._resolve_model_name(model_name)

        return super().generate_content(
            prompt=prompt,
            model_name=resolved_model_name,
            system_prompt=system_prompt,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            **kwargs,
        )

    def get_preferred_model(self, category: "ToolModelCategory", allowed_models: list[str]) -> Optional[str]:
        """Get XAI's preferred model for a given category from allowed models.

        Args:
            category: The tool category requiring a model
            allowed_models: Pre-filtered list of models allowed by restrictions

        Returns:
            Preferred model name, or None when no models are allowed.
        """
        from tools.models import ToolModelCategory

        if not allowed_models:
            return None

        # Per-category preference order; the first allowed candidate wins.
        if category == ToolModelCategory.EXTENDED_REASONING:
            # Prefer GROK-4 for advanced reasoning with thinking mode.
            preference = ["grok-4", "grok-3"]
        elif category == ToolModelCategory.FAST_RESPONSE:
            # Prefer GROK-3-Fast for speed, then GROK-4.
            preference = ["grok-3-fast", "grok-4"]
        else:  # BALANCED or default
            # GROK-4 has the best overall capabilities.
            preference = ["grok-4", "grok-3", "grok-3-fast"]

        for candidate in preference:
            if candidate in allowed_models:
                return candidate

        # Fall back to any available model.
        return allowed_models[0]
"==================================" +echo "Zen-Marketing MCP Server Setup" +echo "==================================" + +# Check Python version +echo "Checking Python installation..." +if ! command -v python3 &> /dev/null; then + echo "Error: Python 3 is required but not installed" + exit 1 +fi + +PYTHON_VERSION=$(python3 --version 2>&1 | awk '{print $2}') +echo "Found Python $PYTHON_VERSION" + +# Create virtual environment if it doesn't exist +if [ ! -d ".venv" ]; then + echo "Creating virtual environment..." + python3 -m venv .venv +else + echo "Virtual environment already exists" +fi + +# Activate virtual environment +echo "Activating virtual environment..." +source .venv/bin/activate + +# Upgrade pip +echo "Upgrading pip..." +pip install --upgrade pip --quiet + +# Install requirements +echo "Installing requirements..." +pip install -r requirements.txt --quiet + +# Check for .env file +if [ ! -f ".env" ]; then + echo "" + echo "WARNING: No .env file found!" + echo "Please create a .env file based on .env.example:" + echo " cp .env.example .env" + echo " # Then edit .env with your API keys" + echo "" + read -p "Continue anyway? (y/N) " -n 1 -r + echo + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + exit 1 + fi +fi + +# Run the server +echo "" +echo "==================================" +echo "Starting Zen-Marketing MCP Server" +echo "==================================" +echo "" + +python server.py diff --git a/server.py b/server.py new file mode 100644 index 0000000..2899267 --- /dev/null +++ b/server.py @@ -0,0 +1,352 @@ +""" +Zen-Marketing MCP Server - Main server implementation + +AI-powered marketing tools for Claude Desktop, specialized for technical B2B content +creation, multi-platform campaigns, and content variation testing. + +Based on Zen MCP Server architecture by Fahad Gilani. 
+""" + +import asyncio +import atexit +import logging +import os +import sys +import time +from logging.handlers import RotatingFileHandler +from pathlib import Path +from typing import Any + +# Load environment variables from .env file +try: + from dotenv import load_dotenv + + script_dir = Path(__file__).parent + env_file = script_dir / ".env" + load_dotenv(dotenv_path=env_file, override=False) +except ImportError: + pass + +from mcp.server import Server +from mcp.server.models import InitializationOptions +from mcp.server.stdio import stdio_server +from mcp.types import ( + GetPromptResult, + Prompt, + PromptMessage, + PromptsCapability, + ServerCapabilities, + TextContent, + Tool, + ToolsCapability, +) + +from config import DEFAULT_MODEL, __version__ +from tools.chat import ChatTool +from tools.contentvariant import ContentVariantTool +from tools.listmodels import ListModelsTool +from tools.models import ToolOutput +from tools.version import VersionTool + +# Configure logging +log_level = os.getenv("LOG_LEVEL", "INFO").upper() + + +class LocalTimeFormatter(logging.Formatter): + def formatTime(self, record, datefmt=None): + ct = self.converter(record.created) + if datefmt: + s = time.strftime(datefmt, ct) + else: + t = time.strftime("%Y-%m-%d %H:%M:%S", ct) + s = f"{t},{record.msecs:03.0f}" + return s + + +# Configure logging +log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + +root_logger = logging.getLogger() +root_logger.handlers.clear() + +stderr_handler = logging.StreamHandler(sys.stderr) +stderr_handler.setLevel(getattr(logging, log_level, logging.INFO)) +stderr_handler.setFormatter(LocalTimeFormatter(log_format)) +root_logger.addHandler(stderr_handler) + +root_logger.setLevel(getattr(logging, log_level, logging.INFO)) + +# Add rotating file handler +try: + log_dir = Path(__file__).parent / "logs" + log_dir.mkdir(exist_ok=True) + + file_handler = RotatingFileHandler( + log_dir / "mcp_server.log", + maxBytes=20 * 1024 * 1024, + 
logger = logging.getLogger(__name__)

# Essential tools that cannot be disabled
ESSENTIAL_TOOLS = {"version", "listmodels"}


def parse_disabled_tools_env() -> set[str]:
    """Parse the DISABLED_TOOLS environment variable.

    Returns:
        Set of lowercased, whitespace-trimmed tool names; empty when the
        variable is unset, blank, or contains only separators.
    """
    disabled_tools_env = os.getenv("DISABLED_TOOLS", "").strip()
    if not disabled_tools_env:
        return set()
    return {t.strip().lower() for t in disabled_tools_env.split(",") if t.strip()}


def filter_disabled_tools(all_tools: dict[str, Any]) -> dict[str, Any]:
    """Filter tools based on the DISABLED_TOOLS environment variable.

    Essential tools are always kept; if the configuration asks to disable one,
    a warning is logged instead of silently ignoring the request.

    Args:
        all_tools: Mapping of tool name -> tool instance.

    Returns:
        Mapping containing only the enabled tools (same instances, same order).
    """
    disabled_tools = parse_disabled_tools_env()
    if not disabled_tools:
        logger.info("All tools enabled")
        return all_tools

    enabled_tools = {}
    for tool_name, tool_instance in all_tools.items():
        if tool_name in disabled_tools and tool_name not in ESSENTIAL_TOOLS:
            logger.info(f"Tool '{tool_name}' disabled via DISABLED_TOOLS")
            continue
        if tool_name in disabled_tools:
            # Requested but not honoured: make the override visible to the operator.
            logger.warning(f"Essential tool '{tool_name}' cannot be disabled; keeping it enabled")
        enabled_tools[tool_name] = tool_instance

    logger.info(f"Active tools: {sorted(enabled_tools.keys())}")
    return enabled_tools
"chat": ChatTool(), + "contentvariant": ContentVariantTool(), + "listmodels": ListModelsTool(), + "version": VersionTool(), +} +TOOLS = filter_disabled_tools(TOOLS) + +# Prompt templates +PROMPT_TEMPLATES = { + "chat": { + "name": "chat", + "description": "Chat and brainstorm marketing ideas", + "template": "Chat about marketing strategy with {model}", + }, + "contentvariant": { + "name": "variations", + "description": "Generate content variations for A/B testing", + "template": "Generate content variations for testing with {model}", + }, + "listmodels": { + "name": "listmodels", + "description": "List available AI models", + "template": "List all available models", + }, + "version": { + "name": "version", + "description": "Show server version", + "template": "Show Zen-Marketing server version", + }, +} + + +def configure_providers(): + """Configure and validate AI providers""" + logger.debug("Checking environment variables for API keys...") + + from providers import ModelProviderRegistry + from providers.gemini import GeminiModelProvider + from providers.openrouter import OpenRouterProvider + from providers.shared import ProviderType + + valid_providers = [] + + # Check for Gemini API key + gemini_key = os.getenv("GEMINI_API_KEY") + if gemini_key and gemini_key != "your_gemini_api_key_here": + valid_providers.append("Gemini") + logger.info("Gemini API key found - Gemini models available") + + # Check for OpenRouter API key + openrouter_key = os.getenv("OPENROUTER_API_KEY") + if openrouter_key and openrouter_key != "your_openrouter_api_key_here": + valid_providers.append("OpenRouter") + logger.info("OpenRouter API key found - Multiple models available") + + # Register providers + if gemini_key and gemini_key != "your_gemini_api_key_here": + ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider) + + if openrouter_key and openrouter_key != "your_openrouter_api_key_here": + ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, 
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict) -> list[TextContent]:
    """Handle tool execution requests.

    Unknown tool names and execution failures are reported back to the client
    as plain text rather than raised, so the MCP session stays alive.
    """
    logger.info(f"Tool call: {name}")

    if name not in TOOLS:
        error_msg = f"Unknown tool: {name}"
        logger.error(error_msg)
        return [TextContent(type="text", text=error_msg)]

    try:
        tool = TOOLS[name]
        result: ToolOutput = await tool.execute(arguments)

        if result.is_error:
            logger.error(f"Tool {name} failed: {result.text}")
        else:
            logger.info(f"Tool {name} completed successfully")

        return [TextContent(type="text", text=result.text)]

    except Exception as e:
        error_msg = f"Tool execution failed: {str(e)}"
        # logger.exception records the full traceback; logging only the
        # exception type name would make debugging tool failures impossible.
        logger.exception(error_msg)
        return [TextContent(type="text", text=error_msg)]
"required": False} + ], + ) + ) + return prompts + + +@server.get_prompt() +async def handle_get_prompt(name: str, arguments: dict | None = None) -> GetPromptResult: + """Get a specific prompt template""" + model = arguments.get("model", DEFAULT_MODEL) if arguments else DEFAULT_MODEL + + for template_info in PROMPT_TEMPLATES.values(): + if template_info["name"] == name: + prompt_text = template_info["template"].format(model=model) + return GetPromptResult( + description=template_info["description"], + messages=[PromptMessage(role="user", content=TextContent(type="text", text=prompt_text))], + ) + + raise ValueError(f"Unknown prompt: {name}") + + +def cleanup(): + """Cleanup function called on shutdown""" + logger.info("Zen-Marketing MCP Server shutting down") + logging.shutdown() + + +async def main(): + """Main entry point""" + logger.info("=" * 80) + logger.info(f"Zen-Marketing MCP Server v{__version__}") + logger.info(f"Python: {sys.version}") + logger.info(f"Working directory: {os.getcwd()}") + logger.info("=" * 80) + + # Register cleanup + atexit.register(cleanup) + + # Configure providers + try: + configure_providers() + except ValueError as e: + logger.error(f"Configuration error: {e}") + sys.exit(1) + + # List enabled tools + logger.info(f"Enabled tools: {sorted(TOOLS.keys())}") + + # Run server + capabilities = ServerCapabilities( + tools=ToolsCapability(), + prompts=PromptsCapability(), + ) + + options = InitializationOptions( + server_name="zen-marketing", + server_version=__version__, + capabilities=capabilities, + ) + + async with stdio_server() as (read_stream, write_stream): + logger.info("Server started, awaiting requests...") + await server.run( + read_stream, + write_stream, + options, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/systemprompts/__init__.py b/systemprompts/__init__.py new file mode 100644 index 0000000..96e7b51 --- /dev/null +++ b/systemprompts/__init__.py @@ -0,0 +1,6 @@ +"""System prompts for 
Zen-Marketing MCP Server tools""" + +from .chat_prompt import CHAT_PROMPT +from .contentvariant_prompt import CONTENTVARIANT_PROMPT + +__all__ = ["CHAT_PROMPT", "CONTENTVARIANT_PROMPT"] diff --git a/systemprompts/chat_prompt.py b/systemprompts/chat_prompt.py new file mode 100644 index 0000000..579776e --- /dev/null +++ b/systemprompts/chat_prompt.py @@ -0,0 +1,29 @@ +"""System prompt for the chat tool""" + +CHAT_PROMPT = """You are a marketing strategist and content expert specializing in technical B2B marketing. + +Your expertise includes: +- Content strategy and variation testing +- Multi-platform content adaptation (LinkedIn, Twitter/X, Instagram, newsletters, blogs) +- Email marketing and subject line optimization +- SEO and WordPress content optimization +- Technical content editing while preserving author voice +- Brand voice consistency and style enforcement +- Campaign planning across multiple touchpoints +- Fact-checking and technical verification + +COMMUNICATION STYLE: +- Professional but approachable +- Data-informed recommendations +- Concrete, actionable suggestions +- Respect for brand voice and author expertise +- Platform-aware (character limits, best practices) + +When discussing content: +- Focus on testing hypotheses and variation strategies +- Consider psychological hooks and audience resonance +- Balance creativity with conversion optimization +- Maintain technical accuracy in specialized domains + +If web search is needed for current best practices, platform updates, or fact verification, indicate this clearly. +""" diff --git a/systemprompts/contentvariant_prompt.py b/systemprompts/contentvariant_prompt.py new file mode 100644 index 0000000..f6dd727 --- /dev/null +++ b/systemprompts/contentvariant_prompt.py @@ -0,0 +1,62 @@ +"""System prompt for the contentvariant tool""" + +CONTENTVARIANT_PROMPT = """You are a marketing content strategist specializing in A/B testing and variation generation. 
+ +TASK: Generate multiple variations of marketing content for testing different approaches. + +OUTPUT FORMAT: +Return variations as a numbered list, each with: +1. The variation text +2. The testing angle (what makes it different) +3. Predicted audience response or reasoning + +CONSTRAINTS: +- Maintain core message across variations +- Respect platform character limits if specified +- Preserve brand voice characteristics provided +- Generate genuinely different approaches, not just word swaps +- Each variation should test a distinct hypothesis + +VARIATION TYPES: +- **Hook variations**: Different opening angles (question, statement, stat, story) +- **Length variations**: Short, medium, long within platform constraints +- **Tone variations**: Professional, conversational, urgent, educational, provocative +- **Structure variations**: Question format, list format, narrative, before-after +- **CTA variations**: Different calls-to-action (learn, try, join, download) +- **Psychological angles**: Curiosity, FOMO, social proof, contrarian, pain-solution + +TESTING ANGLES (use when specified): +- **Technical curiosity**: Lead with interesting technical detail +- **Contrarian/provocative**: Challenge conventional wisdom +- **Knowledge gap**: Emphasize what they don't know yet +- **Urgency/timeliness**: Time-sensitive opportunity or threat +- **Insider knowledge**: Position as exclusive expertise +- **Problem-solution**: Lead with pain point +- **Social proof**: Leverage credibility or results +- **Before-after**: Transformation narrative + +PLATFORM CONSIDERATIONS: +- Twitter/Bluesky: 280 chars, punchy hooks, visual language +- LinkedIn: 1300 chars optimal, professional tone, business value +- Instagram: 2200 chars, storytelling, visual companion +- Email subject: Under 60 chars, curiosity-driven +- Blog/article: Longer form, educational depth + +CHARACTER COUNT: +Always include character count for each variation when platform is specified. 
"""
Tool implementations for Zen-Marketing MCP Server

Only the tools that actually ship in this package are imported here.
Importing modules that do not exist (analyze, codereview, debug, ...)
would raise ModuleNotFoundError at package import time and break every
`from tools.X import ...` in server.py.
"""

from .chat import ChatTool
from .contentvariant import ContentVariantTool
from .listmodels import ListModelsTool
from .version import VersionTool

__all__ = [
    "ChatTool",
    "ContentVariantTool",
    "ListModelsTool",
    "VersionTool",
]
100644
index 0000000..87a8603
--- /dev/null
+++ b/tools/chat.py
@@ -0,0 +1,189 @@
"""
Chat tool - General development chat and collaborative thinking

This tool provides a conversational interface for general development assistance,
brainstorming, problem-solving, and collaborative thinking. It supports file context,
images, and conversation continuation for seamless multi-turn interactions.
"""

from typing import TYPE_CHECKING, Any, Optional

from pydantic import Field

if TYPE_CHECKING:
    from tools.models import ToolModelCategory

from config import TEMPERATURE_BALANCED
from systemprompts import CHAT_PROMPT
from tools.shared.base_models import COMMON_FIELD_DESCRIPTIONS, ToolRequest

from .simple.base import SimpleTool

# Field descriptions matching the original Chat tool exactly.
# These strings are user/agent-facing schema text; keep them byte-identical
# to the hand-written schema in get_input_schema() below.
CHAT_FIELD_DESCRIPTIONS = {
    "prompt": (
        "Your question or idea for collaborative thinking. Provide detailed context, including your goal, what you've tried, and any specific challenges. "
        "CRITICAL: To discuss code, use 'files' parameter instead of pasting code blocks here."
    ),
    "files": "absolute file or folder paths for code context (do NOT shorten).",
    "images": "Optional absolute image paths or base64 for visual context when helpful.",
}


class ChatRequest(ToolRequest):
    """Request model for Chat tool.

    Extends ToolRequest with the chat-specific prompt plus optional
    file/image context lists (defaulting to empty lists, not None).
    """

    prompt: str = Field(..., description=CHAT_FIELD_DESCRIPTIONS["prompt"])
    files: Optional[list[str]] = Field(default_factory=list, description=CHAT_FIELD_DESCRIPTIONS["files"])
    images: Optional[list[str]] = Field(default_factory=list, description=CHAT_FIELD_DESCRIPTIONS["images"])


class ChatTool(SimpleTool):
    """
    General development chat and collaborative thinking tool using SimpleTool architecture.

    This tool provides identical functionality to the original Chat tool but uses the new
    SimpleTool architecture for cleaner code organization and better maintainability.

    Migration note: This tool is designed to be a drop-in replacement for the original
    Chat tool with 100% behavioral compatibility.
    """

    def get_name(self) -> str:
        """Tool identifier exposed over MCP."""
        return "chat"

    def get_description(self) -> str:
        """Short agent-facing description shown in the tool listing."""
        return (
            "General chat and collaborative thinking partner for brainstorming, development discussion, "
            "getting second opinions, and exploring ideas. Use for ideas, validations, questions, and thoughtful explanations."
        )

    def get_system_prompt(self) -> str:
        """System prompt injected for every chat request."""
        return CHAT_PROMPT

    def get_default_temperature(self) -> float:
        """Balanced default temperature (see config.TEMPERATURE_BALANCED)."""
        return TEMPERATURE_BALANCED

    def get_model_category(self) -> "ToolModelCategory":
        """Chat prioritizes fast responses and cost efficiency"""
        # Imported locally to avoid a module-level import cycle with tools.models
        # (mirrors the TYPE_CHECKING-only import at the top of this file).
        from tools.models import ToolModelCategory

        return ToolModelCategory.FAST_RESPONSE

    def get_request_model(self):
        """Return the Chat-specific request model"""
        return ChatRequest

    # === Schema Generation ===
    # For maximum compatibility, we override get_input_schema() to match the original Chat tool exactly

    def get_input_schema(self) -> dict[str, Any]:
        """
        Generate input schema matching the original Chat tool exactly.

        This maintains 100% compatibility with the original Chat tool by using
        the same schema generation approach while still benefiting from SimpleTool
        convenience methods.

        NOTE(review): this hand-written schema does not expose a 'use_websearch'
        property even though get_websearch_guidance() exists below — presumably
        SimpleTool handles that elsewhere; confirm against tools/simple/base.py.
        """
        # 'model' becomes required only when the server runs in auto mode.
        required_fields = ["prompt"]
        if self.is_effective_auto_mode():
            required_fields.append("model")

        schema = {
            "type": "object",
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": CHAT_FIELD_DESCRIPTIONS["prompt"],
                },
                "files": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": CHAT_FIELD_DESCRIPTIONS["files"],
                },
                "images": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": CHAT_FIELD_DESCRIPTIONS["images"],
                },
                "model": self.get_model_field_schema(),
                "temperature": {
                    "type": "number",
                    "description": COMMON_FIELD_DESCRIPTIONS["temperature"],
                    "minimum": 0,
                    "maximum": 1,
                },
                "thinking_mode": {
                    "type": "string",
                    "enum": ["minimal", "low", "medium", "high", "max"],
                    "description": COMMON_FIELD_DESCRIPTIONS["thinking_mode"],
                },
                "continuation_id": {
                    "type": "string",
                    "description": COMMON_FIELD_DESCRIPTIONS["continuation_id"],
                },
            },
            "required": required_fields,
        }

        return schema

    # === Tool-specific field definitions (alternative approach for reference) ===
    # These aren't used since we override get_input_schema(), but they show how
    # the tool could be implemented using the automatic SimpleTool schema building

    def get_tool_fields(self) -> dict[str, dict[str, Any]]:
        """
        Tool-specific field definitions for ChatSimple.

        Note: This method isn't used since we override get_input_schema() for
        exact compatibility, but it demonstrates how ChatSimple could be
        implemented using automatic schema building.
        """
        return {
            "prompt": {
                "type": "string",
                "description": CHAT_FIELD_DESCRIPTIONS["prompt"],
            },
            "files": {
                "type": "array",
                "items": {"type": "string"},
                "description": CHAT_FIELD_DESCRIPTIONS["files"],
            },
            "images": {
                "type": "array",
                "items": {"type": "string"},
                "description": CHAT_FIELD_DESCRIPTIONS["images"],
            },
        }

    def get_required_fields(self) -> list[str]:
        """Required fields for ChatSimple tool"""
        return ["prompt"]

    # === Hook Method Implementations ===

    async def prepare_prompt(self, request: ChatRequest) -> str:
        """
        Prepare the chat prompt with optional context files.

        This implementation matches the original Chat tool exactly while using
        SimpleTool convenience methods for cleaner code.
        """
        # Use SimpleTool's Chat-style prompt preparation
        return self.prepare_chat_style_prompt(request)

    def format_response(self, response: str, request: ChatRequest, model_info: Optional[dict] = None) -> str:
        """
        Format the chat response to match the original Chat tool exactly.

        Appends a fixed "AGENT'S TURN" trailer instructing the calling agent
        to evaluate the model's answer; the trailer text is part of the tool's
        observable behavior and must not be reworded casually.
        """
        return (
            f"{response}\n\n---\n\nAGENT'S TURN: Evaluate this perspective alongside your analysis to "
            "form a comprehensive solution and continue with the user's request and task at hand."
        )

    def get_websearch_guidance(self) -> str:
        """
        Return Chat tool-style web search guidance.
        """
        return self.get_chat_style_websearch_guidance()
diff --git a/tools/contentvariant.py b/tools/contentvariant.py
new file mode 100644
index 0000000..92e5fd8
--- /dev/null
+++ b/tools/contentvariant.py
@@ -0,0 +1,180 @@
"""Content Variant Generator Tool

Generates multiple variations of marketing content for A/B testing.
Supports testing different hooks, tones, lengths, and psychological angles.
"""

from typing import Optional

from pydantic import Field

from config import TEMPERATURE_HIGHLY_CREATIVE
from systemprompts import CONTENTVARIANT_PROMPT
from tools.models import ToolModelCategory
from tools.shared.base_models import ToolRequest
from tools.simple.base import SimpleTool


class ContentVariantRequest(ToolRequest):
    """Request model for Content Variant Generator.

    All optional knobs default to None so the system prompt falls back to a
    mixed approach; variation_count is validated to the 5-25 range by pydantic.
    """

    content: str = Field(
        ...,
        description="Base content or topic to create variations from. Can be a draft post, subject line, or content concept.",
    )
    variation_count: int = Field(
        default=10,
        ge=5,
        le=25,
        description="Number of variations to generate (5-25). Default is 10.",
    )
    variation_types: Optional[list[str]] = Field(
        default=None,
        description="Types of variations to explore: 'hook', 'tone', 'length', 'structure', 'cta', 'angle'. Leave empty for mixed approach.",
    )
    platform: Optional[str] = Field(
        default=None,
        description="Target platform for character limits and formatting: 'twitter', 'bluesky', 'linkedin', 'instagram', 'facebook', 'email_subject', 'blog'",
    )
    constraints: Optional[str] = Field(
        default=None,
        description="Additional constraints: character limits, style requirements, brand voice guidelines, prohibited words/phrases",
    )
    testing_angles: Optional[list[str]] = Field(
        default=None,
        description="Specific psychological angles to test: 'curiosity', 'contrarian', 'knowledge_gap', 'urgency', 'insider', 'problem_solution', 'social_proof', 'transformation'",
    )


class ContentVariantTool(SimpleTool):
    """Generate multiple content variations for A/B testing"""

    def get_name(self) -> str:
        """Tool identifier exposed over MCP."""
        return "contentvariant"

    def get_description(self) -> str:
        """Agent-facing description shown in the tool listing."""
        return (
            "Generate 5-25 variations of marketing content for A/B testing. "
            "Tests different hooks, tones, lengths, and psychological angles. "
            "Ideal for subject lines, social posts, email copy, and ads. "
            "Each variation includes testing rationale and predicted audience response."
        )

    def get_system_prompt(self) -> str:
        """System prompt driving the variant-generation behavior."""
        return CONTENTVARIANT_PROMPT

    def get_default_temperature(self) -> float:
        # High-creativity default; the schema text below advertises 0.8 —
        # NOTE(review): confirm config.TEMPERATURE_HIGHLY_CREATIVE is 0.8 so
        # the documented default matches the actual one.
        return TEMPERATURE_HIGHLY_CREATIVE

    def get_model_category(self) -> ToolModelCategory:
        # Variant generation favors speed/cost over deep reasoning.
        return ToolModelCategory.FAST_RESPONSE

    def get_request_model(self):
        """Return the pydantic request model for validation."""
        return ContentVariantRequest

    def format_request_for_model(self, request: ContentVariantRequest) -> dict:
        """Format the content variant request for the AI model.

        Builds a single markdown-ish prompt from the request fields and
        passes the generic ToolRequest fields through unchanged.

        NOTE(review): request.files / request.images / request.use_websearch /
        request.model / request.temperature / request.thinking_mode are not
        declared on ContentVariantRequest above — presumably inherited from
        ToolRequest; verify against tools/shared/base_models.py.
        """
        prompt_parts = [f"Generate {request.variation_count} variations of this content:"]
        prompt_parts.append(f"\n**Base Content:**\n{request.content}")

        if request.platform:
            prompt_parts.append(f"\n**Target Platform:** {request.platform}")

        if request.variation_types:
            prompt_parts.append(f"\n**Variation Types:** {', '.join(request.variation_types)}")

        if request.testing_angles:
            prompt_parts.append(f"\n**Testing Angles:** {', '.join(request.testing_angles)}")

        if request.constraints:
            prompt_parts.append(f"\n**Constraints:** {request.constraints}")

        prompt_parts.append(
            "\nGenerate variations with clear labels, character counts (if platform specified), "
            "testing angles, and predicted audience responses."
        )

        formatted_request = {
            "prompt": "\n".join(prompt_parts),
            "files": request.files or [],
            "images": request.images or [],
            "continuation_id": request.continuation_id,
            "model": request.model,
            "temperature": request.temperature,
            "thinking_mode": request.thinking_mode,
            "use_websearch": request.use_websearch,
        }

        return formatted_request

    def get_input_schema(self) -> dict:
        """Return the JSON schema for this tool's input.

        Hand-written (rather than derived from the pydantic model) so the
        agent-facing descriptions can be shorter than the Field descriptions.
        Only 'content' is required.
        """
        return {
            "type": "object",
            "properties": {
                "content": {
                    "type": "string",
                    "description": "Base content or topic to create variations from",
                },
                "variation_count": {
                    "type": "integer",
                    "description": "Number of variations to generate (5-25, default 10)",
                    "minimum": 5,
                    "maximum": 25,
                    "default": 10,
                },
                "variation_types": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Types: 'hook', 'tone', 'length', 'structure', 'cta', 'angle'",
                },
                "platform": {
                    "type": "string",
                    "description": "Platform: 'twitter', 'bluesky', 'linkedin', 'instagram', 'facebook', 'email_subject', 'blog'",
                },
                "constraints": {
                    "type": "string",
                    "description": "Character limits, style requirements, brand guidelines",
                },
                "testing_angles": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Angles: 'curiosity', 'contrarian', 'knowledge_gap', 'urgency', 'insider', 'problem_solution', 'social_proof', 'transformation'",
                },
                "files": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional brand guidelines or style reference files",
                },
                "images": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional visual assets for context",
                },
                "continuation_id": {
                    "type": "string",
                    "description": "Thread ID to continue previous conversation",
                },
                "model": {
                    "type": "string",
                    "description": "AI model to use (leave empty for default fast model)",
                },
                "temperature": {
                    "type": "number",
                    "description": "Creativity level 0.0-1.0 (default 0.8 for high variation)",
                    "minimum": 0.0,
                    "maximum": 1.0,
                },
                "thinking_mode": {
                    "type": "string",
                    "description": "Thinking depth: minimal, low, medium, high, max",
                    "enum": ["minimal", "low", "medium", "high", "max"],
                },
                "use_websearch": {
                    "type": "boolean",
                    "description": "Enable web search for current platform best practices",
                    "default": False,
                },
            },
            "required": ["content"],
        }
diff --git a/tools/listmodels.py b/tools/listmodels.py
new file mode 100644
index 0000000..18e94aa
--- /dev/null
+++ b/tools/listmodels.py
@@ -0,0 +1,299 @@
"""
List Models Tool - Display all available models organized by provider

This tool provides a comprehensive view of all AI models available in the system,
organized by their provider (Gemini, OpenAI, X.AI, OpenRouter, Custom).
It shows which providers are configured and what models can be used.
"""

import logging
import os
from typing import Any, Optional

from mcp.types import TextContent

from tools.models import ToolModelCategory, ToolOutput
from tools.shared.base_models import ToolRequest
from tools.shared.base_tool import BaseTool

logger = logging.getLogger(__name__)


class ListModelsTool(BaseTool):
    """
    Tool for listing all available AI models organized by provider.

    This tool helps users understand:
    - Which providers are configured (have API keys)
    - What models are available from each provider
    - Model aliases and their full names
    - Context window sizes and capabilities
    """

    def get_name(self) -> str:
        """Tool identifier exposed over MCP."""
        return "listmodels"

    def get_description(self) -> str:
        """Agent-facing description shown in the tool listing."""
        return "Shows which AI model providers are configured, available model names, their aliases and capabilities."
    def get_input_schema(self) -> dict[str, Any]:
        """Return the JSON schema for the tool's input"""
        # 'model' is accepted for interface uniformity but never used.
        return {
            "type": "object",
            "properties": {"model": {"type": "string", "description": "Model to use (ignored by listmodels tool)"}},
            "required": [],
        }

    def get_annotations(self) -> Optional[dict[str, Any]]:
        """Return tool annotations indicating this is a read-only tool"""
        return {"readOnlyHint": True}

    def get_system_prompt(self) -> str:
        """No AI model needed for this tool"""
        return ""

    def get_request_model(self):
        """Return the Pydantic model for request validation."""
        return ToolRequest

    def requires_model(self) -> bool:
        # This tool produces its output directly; no provider call is made.
        return False

    async def prepare_prompt(self, request: ToolRequest) -> str:
        """Not used for this utility tool"""
        return ""

    def format_response(self, response: str, request: ToolRequest, model_info: Optional[dict] = None) -> str:
        """Not used for this utility tool"""
        return response

    async def execute(self, arguments: dict[str, Any]) -> list[TextContent]:
        """
        List all available models organized by provider.

        This overrides the base class execute to provide direct output without AI model calls.

        Args:
            arguments: Standard tool arguments (none required)

        Returns:
            Formatted list of models by provider
        """
        # Imported lazily so importing this module never pulls in the provider
        # registry (which reads environment configuration).
        from providers.openrouter_registry import OpenRouterModelRegistry
        from providers.registry import ModelProviderRegistry
        from providers.shared import ProviderType

        output_lines = ["# Available AI Models\n"]

        # Map provider types to friendly names and their models
        provider_info = {
            ProviderType.GOOGLE: {"name": "Google Gemini", "env_key": "GEMINI_API_KEY"},
            ProviderType.OPENAI: {"name": "OpenAI", "env_key": "OPENAI_API_KEY"},
            ProviderType.XAI: {"name": "X.AI (Grok)", "env_key": "XAI_API_KEY"},
            ProviderType.DIAL: {"name": "AI DIAL", "env_key": "DIAL_API_KEY"},
        }

        # Check each native provider type
        for provider_type, info in provider_info.items():
            # Check if provider is enabled (registry returns None when no API key)
            provider = ModelProviderRegistry.get_provider(provider_type)
            is_configured = provider is not None

            output_lines.append(f"## {info['name']} {'✅' if is_configured else '❌'}")

            if is_configured:
                output_lines.append("**Status**: Configured and available")
                output_lines.append("\n**Models**:")

                aliases = []
                for model_name, capabilities in provider.get_all_model_capabilities().items():
                    description = capabilities.description or "No description available"
                    context_window = capabilities.context_window

                    # Render the context window in M/K units for readability.
                    if context_window >= 1_000_000:
                        context_str = f"{context_window // 1_000_000}M context"
                    elif context_window >= 1_000:
                        context_str = f"{context_window // 1_000}K context"
                    else:
                        context_str = f"{context_window} context" if context_window > 0 else "unknown context"

                    output_lines.append(f"- `{model_name}` - {context_str}")
                    output_lines.append(f"  - {description}")

                    # Collect alias lines; emitted sorted after the model list.
                    for alias in capabilities.aliases or []:
                        if alias != model_name:
                            aliases.append(f"- `{alias}` → `{model_name}`")

                if aliases:
                    output_lines.append("\n**Aliases**:")
                    output_lines.extend(sorted(aliases))
            else:
                output_lines.append(f"**Status**: Not configured (set {info['env_key']})")

            output_lines.append("")

        # Check OpenRouter
        openrouter_key = os.getenv("OPENROUTER_API_KEY")
        # NOTE(review): this is the key string (or None), not a bool — truthy
        # checks below work, but it is also used for the summary count.
        is_openrouter_configured = openrouter_key and openrouter_key != "your_openrouter_api_key_here"

        output_lines.append(f"## OpenRouter {'✅' if is_openrouter_configured else '❌'}")

        if is_openrouter_configured:
            output_lines.append("**Status**: Configured and available")
            output_lines.append("**Description**: Access to multiple cloud AI providers via unified API")

            try:
                # Get OpenRouter provider from registry to properly apply restrictions
                # NOTE(review): these two imports are redundant — the same names
                # are already imported at the top of execute().
                from providers.registry import ModelProviderRegistry
                from providers.shared import ProviderType

                provider = ModelProviderRegistry.get_provider(ProviderType.OPENROUTER)
                if provider:
                    # Get models with restrictions applied
                    available_models = provider.list_models(respect_restrictions=True)
                    registry = OpenRouterModelRegistry()

                    # Group by provider for better organization
                    providers_models = {}
                    for model_name in available_models:  # Show ALL available models
                        # Try to resolve to get config details
                        config = registry.resolve(model_name)
                        if config:
                            # Extract provider from model_name (e.g. "openai/gpt-4" -> "openai")
                            provider_name = config.model_name.split("/")[0] if "/" in config.model_name else "other"
                            if provider_name not in providers_models:
                                providers_models[provider_name] = []
                            providers_models[provider_name].append((model_name, config))
                        else:
                            # Model without config - add with basic info
                            provider_name = model_name.split("/")[0] if "/" in model_name else "other"
                            if provider_name not in providers_models:
                                providers_models[provider_name] = []
                            providers_models[provider_name].append((model_name, None))

                    output_lines.append("\n**Available Models**:")
                    for provider_name, models in sorted(providers_models.items()):
                        output_lines.append(f"\n*{provider_name.title()}:*")
                        for alias, config in models:  # Show ALL models from each provider
                            if config:
                                context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
                                output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
                            else:
                                output_lines.append(f"- `{alias}`")

                    # NOTE(review): total_models is never used in this branch —
                    # candidate for removal.
                    total_models = len(available_models)
                    # Show all models - no truncation message needed

                    # Check if restrictions are applied
                    # NOTE(review): the None pre-binding is unused; the try block
                    # rebinds restriction_service before first use.
                    restriction_service = None
                    try:
                        from utils.model_restrictions import get_restriction_service

                        restriction_service = get_restriction_service()
                        if restriction_service.has_restrictions(ProviderType.OPENROUTER):
                            allowed_set = restriction_service.get_allowed_models(ProviderType.OPENROUTER)
                            output_lines.append(
                                f"\n**Note**: Restricted to models matching: {', '.join(sorted(allowed_set))}"
                            )
                    except Exception as e:
                        logger.warning(f"Error checking OpenRouter restrictions: {e}")
                else:
                    output_lines.append("**Error**: Could not load OpenRouter provider")

            except Exception as e:
                output_lines.append(f"**Error loading models**: {str(e)}")
        else:
            output_lines.append("**Status**: Not configured (set OPENROUTER_API_KEY)")
            output_lines.append("**Note**: Provides access to GPT-5, O3, Mistral, and many more")

        output_lines.append("")

        # Check Custom API
        custom_url = os.getenv("CUSTOM_API_URL")

        output_lines.append(f"## Custom/Local API {'✅' if custom_url else '❌'}")

        if custom_url:
            output_lines.append("**Status**: Configured and available")
            output_lines.append(f"**Endpoint**: {custom_url}")
            output_lines.append("**Description**: Local models via Ollama, vLLM, LM Studio, etc.")

            try:
                # Custom/local models are declared in the same registry file as
                # OpenRouter aliases, flagged with is_custom.
                registry = OpenRouterModelRegistry()
                custom_models = []

                for alias in registry.list_aliases():
                    config = registry.resolve(alias)
                    if config and config.is_custom:
                        custom_models.append((alias, config))

                if custom_models:
                    output_lines.append("\n**Custom Models**:")
                    for alias, config in custom_models:
                        context_str = f"{config.context_window // 1000}K" if config.context_window else "?"
                        output_lines.append(f"- `{alias}` → `{config.model_name}` ({context_str} context)")
                        if config.description:
                            output_lines.append(f"  - {config.description}")

            except Exception as e:
                output_lines.append(f"**Error loading custom models**: {str(e)}")
        else:
            output_lines.append("**Status**: Not configured (set CUSTOM_API_URL)")
            output_lines.append("**Example**: CUSTOM_API_URL=http://localhost:11434 (for Ollama)")

        output_lines.append("")

        # Add summary
        output_lines.append("## Summary")

        # Count configured providers
        configured_count = sum(
            [
                1
                for provider_type, info in provider_info.items()
                if ModelProviderRegistry.get_provider(provider_type) is not None
            ]
        )
        if is_openrouter_configured:
            configured_count += 1
        if custom_url:
            configured_count += 1

        output_lines.append(f"**Configured Providers**: {configured_count}")

        # Get total available models
        try:
            from providers.registry import ModelProviderRegistry

            # Get all available models respecting restrictions
            available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
            total_models = len(available_models)
            output_lines.append(f"**Total Available Models**: {total_models}")
        except Exception as e:
            logger.warning(f"Error getting total available models: {e}")

        # Add usage tips
        output_lines.append("\n**Usage Tips**:")
        output_lines.append("- Use model aliases (e.g., 'flash', 'gpt5', 'opus') for convenience")
        output_lines.append("- In auto mode, the CLI Agent will select the best model for each task")
        output_lines.append("- Custom models are only available when CUSTOM_API_URL is set")
        output_lines.append("- OpenRouter provides access to many cloud models with one API key")

        # Format output
        content = "\n".join(output_lines)

        tool_output = ToolOutput(
            status="success",
            content=content,
            content_type="text",
            metadata={
                "tool_name": self.name,
                "configured_providers": configured_count,
            },
        )

        return [TextContent(type="text", text=tool_output.model_dump_json())]

    def get_model_category(self) -> ToolModelCategory:
        """Return the model category for this tool."""
        return ToolModelCategory.FAST_RESPONSE  # Simple listing, no AI needed
diff --git a/tools/models.py b/tools/models.py
new file mode 100644
index 0000000..7b431d5
--- /dev/null
+++ b/tools/models.py
@@ -0,0 +1,373 @@
"""
Data models for tool responses and interactions
"""

from enum import Enum
from typing import Any, Literal, Optional

from pydantic import BaseModel, Field


class ToolModelCategory(Enum):
    """Categories for tool model selection based on requirements."""

    EXTENDED_REASONING = "extended_reasoning"  # Requires deep thinking capabilities
    FAST_RESPONSE = "fast_response"  # Speed and cost efficiency preferred
    BALANCED = "balanced"  # Balance of capability and performance


class ContinuationOffer(BaseModel):
    """Offer for CLI agent to continue conversation when Gemini doesn't ask follow-up"""

    continuation_id: str = Field(
        ..., description="Thread continuation ID for multi-turn conversations across different tools"
    )
    note: str = Field(..., description="Message explaining continuation opportunity to CLI agent")
    remaining_turns: int = Field(..., description="Number of conversation turns remaining")


class ToolOutput(BaseModel):
    """Standardized output format for all tools"""

    # Status values must stay in sync with SPECIAL_STATUS_MODELS at the
    # bottom of this module.
    status: Literal[
        "success",
        "error",
        "files_required_to_continue",
        "full_codereview_required",
        "focused_review_required",
        "test_sample_needed",
        "more_tests_required",
        "refactor_analysis_complete",
        "trace_complete",
        "resend_prompt",
        "code_too_large",
        "continuation_available",
        "no_bug_found",
    ] = "success"
    content: Optional[str] = Field(None, description="The main content/response from the tool")
    content_type: Literal["text", "markdown", "json"] = "text"
    metadata: Optional[dict[str, Any]] = Field(default_factory=dict)
    continuation_offer: 
Optional[ContinuationOffer] = Field(
        None, description="Optional offer for Agent to continue conversation"
    )


class FilesNeededRequest(BaseModel):
    """Request for missing files / code to continue"""

    status: Literal["files_required_to_continue"] = "files_required_to_continue"
    mandatory_instructions: str = Field(..., description="Critical instructions for Agent regarding required context")
    files_needed: Optional[list[str]] = Field(
        default_factory=list, description="Specific files that are needed for analysis"
    )
    suggested_next_action: Optional[dict[str, Any]] = Field(
        None,
        description="Suggested tool call with parameters after getting clarification",
    )


class FullCodereviewRequired(BaseModel):
    """Request for full code review when scope is too large for quick review"""

    status: Literal["full_codereview_required"] = "full_codereview_required"
    important: Optional[str] = Field(None, description="Important message about escalation")
    reason: Optional[str] = Field(None, description="Reason why full review is needed")


class FocusedReviewRequired(BaseModel):
    """Request for Agent to provide smaller, focused subsets of code for review"""

    status: Literal["focused_review_required"] = "focused_review_required"
    reason: str = Field(..., description="Why the current scope is too large for effective review")
    suggestion: str = Field(
        ..., description="Suggested approach for breaking down the review into smaller, focused parts"
    )


class TestSampleNeeded(BaseModel):
    """Request for additional test samples to determine testing framework"""

    status: Literal["test_sample_needed"] = "test_sample_needed"
    reason: str = Field(..., description="Reason why additional test samples are required")


class MoreTestsRequired(BaseModel):
    """Request for continuation to generate additional tests"""

    status: Literal["more_tests_required"] = "more_tests_required"
    pending_tests: str = Field(..., description="List of pending tests to be generated")


class RefactorOpportunity(BaseModel):
    """A single refactoring opportunity with precise targeting information"""

    id: str = Field(..., description="Unique identifier for this refactoring opportunity")
    type: Literal["decompose", "codesmells", "modernize", "organization"] = Field(
        ..., description="Type of refactoring"
    )
    severity: Literal["critical", "high", "medium", "low"] = Field(..., description="Severity level")
    file: str = Field(..., description="Absolute path to the file")
    start_line: int = Field(..., description="Starting line number")
    end_line: int = Field(..., description="Ending line number")
    # context_*_text anchors let the consumer verify line numbers still match
    # the file before applying the suggested change.
    context_start_text: str = Field(..., description="Exact text from start line for verification")
    context_end_text: str = Field(..., description="Exact text from end line for verification")
    issue: str = Field(..., description="Clear description of what needs refactoring")
    suggestion: str = Field(..., description="Specific refactoring action to take")
    rationale: str = Field(..., description="Why this improves the code")
    code_to_replace: str = Field(..., description="Original code that should be changed")
    replacement_code_snippet: str = Field(..., description="Refactored version of the code")
    new_code_snippets: Optional[list[dict]] = Field(
        default_factory=list, description="Additional code snippets to be added"
    )


class RefactorAction(BaseModel):
    """Next action for Agent to implement refactoring"""

    action_type: Literal["EXTRACT_METHOD", "SPLIT_CLASS", "MODERNIZE_SYNTAX", "REORGANIZE_CODE", "DECOMPOSE_FILE"] = (
        Field(..., description="Type of action to perform")
    )
    target_file: str = Field(..., description="Absolute path to target file")
    source_lines: str = Field(..., description="Line range (e.g., '45-67')")
    description: str = Field(..., description="Step-by-step action description for CLI Agent")


class RefactorAnalysisComplete(BaseModel):
    """Complete refactor analysis with prioritized opportunities"""

    status: Literal["refactor_analysis_complete"] = "refactor_analysis_complete"
    refactor_opportunities: list[RefactorOpportunity] = Field(..., description="List of refactoring opportunities")
    priority_sequence: list[str] = Field(..., description="Recommended order of refactoring IDs")
    next_actions: list[RefactorAction] = Field(..., description="Specific actions for the agent to implement")


class CodeTooLargeRequest(BaseModel):
    """Request to reduce file selection due to size constraints"""

    status: Literal["code_too_large"] = "code_too_large"
    content: str = Field(..., description="Message explaining the size constraint")
    content_type: Literal["text"] = "text"
    metadata: dict[str, Any] = Field(default_factory=dict)


class ResendPromptRequest(BaseModel):
    """Request to resend prompt via file due to size limits"""

    status: Literal["resend_prompt"] = "resend_prompt"
    content: str = Field(..., description="Instructions for handling large prompt")
    content_type: Literal["text"] = "text"
    metadata: dict[str, Any] = Field(default_factory=dict)


class TraceEntryPoint(BaseModel):
    """Entry point information for trace analysis"""

    file: str = Field(..., description="Absolute path to the file")
    class_or_struct: str = Field(..., description="Class or module name")
    method: str = Field(..., description="Method or function name")
    signature: str = Field(..., description="Full method signature")
    parameters: Optional[dict[str, Any]] = Field(default_factory=dict, description="Parameter values used in analysis")


class TraceTarget(BaseModel):
    """Target information for dependency analysis"""

    file: str = Field(..., description="Absolute path to the file")
    class_or_struct: str = Field(..., description="Class or module name")
    method: str = Field(..., description="Method or function name")
    signature: str = Field(..., description="Full method signature")


class CallPathStep(BaseModel):
    """A single step in the call path trace"""

    # 'from' is a Python keyword, so the attribute is from_info with a
    # serialization alias of "from" for the wire format.
    from_info: dict[str, Any] = Field(..., description="Source location information", alias="from")
    to: dict[str, Any] = Field(..., description="Target location information")
    reason: str = Field(..., description="Reason for the call or dependency")
    condition: Optional[str] = Field(None, description="Conditional logic if applicable")
    ambiguous: bool = Field(False, description="Whether this call is ambiguous")


class BranchingPoint(BaseModel):
    """A branching point in the execution flow"""

    file: str = Field(..., description="File containing the branching point")
    method: str = Field(..., description="Method containing the branching point")
    line: int = Field(..., description="Line number of the branching point")
    condition: str = Field(..., description="Branching condition")
    branches: list[str] = Field(..., description="Possible execution branches")
    ambiguous: bool = Field(False, description="Whether the branching is ambiguous")


class SideEffect(BaseModel):
    """A side effect detected in the trace"""

    type: str = Field(..., description="Type of side effect")
    description: str = Field(..., description="Description of the side effect")
    file: str = Field(..., description="File where the side effect occurs")
    method: str = Field(..., description="Method where the side effect occurs")
    line: int = Field(..., description="Line number of the side effect")


class UnresolvedDependency(BaseModel):
    """An unresolved dependency in the trace"""

    reason: str = Field(..., description="Reason why the dependency is unresolved")
    affected_file: str = Field(..., description="File affected by the unresolved dependency")
    line: int = Field(..., description="Line number of the unresolved dependency")


class IncomingDependency(BaseModel):
    """An incoming dependency (what calls this target)"""

    from_file: str = Field(..., description="Source file of the dependency")
    from_class: str = Field(..., description="Source class of the dependency")
    from_method: str = Field(..., description="Source method of the dependency")
    line: int = Field(..., description="Line number of the dependency")
    type: str = Field(..., description="Type of dependency")


class OutgoingDependency(BaseModel):
    """An outgoing dependency (what this target calls)"""

    to_file: str = Field(..., description="Target file of the dependency")
    to_class: str = Field(..., description="Target class of the dependency")
    to_method: str = Field(..., description="Target method of the dependency")
    line: int = Field(..., description="Line number of the dependency")
    type: str = Field(..., description="Type of dependency")


class TypeDependency(BaseModel):
    """A type-level dependency (inheritance, imports, etc.)"""

    dependency_type: str = Field(..., description="Type of dependency")
    source_file: str = Field(..., description="Source file of the dependency")
    source_entity: str = Field(..., description="Source entity (class, module)")
    target: str = Field(..., description="Target entity")


class StateAccess(BaseModel):
    """State access information"""

    file: str = Field(..., description="File where state is accessed")
    method: str = Field(..., description="Method accessing the state")
    access_type: str = Field(..., description="Type of access (reads, writes, etc.)")
    state_entity: str = Field(..., description="State entity being accessed")


class TraceComplete(BaseModel):
    """Complete trace analysis response.

    Either the precision-mode fields or the dependencies-mode fields are
    populated depending on trace_type; the other group stays at defaults.
    """

    status: Literal["trace_complete"] = "trace_complete"
    trace_type: Literal["precision", "dependencies"] = Field(..., description="Type of trace performed")

    # Precision mode fields
    entry_point: Optional[TraceEntryPoint] = Field(None, description="Entry point for precision trace")
    call_path: Optional[list[CallPathStep]] = Field(default_factory=list, description="Call path for precision trace")
    branching_points: Optional[list[BranchingPoint]] = Field(default_factory=list, description="Branching points")
    side_effects: Optional[list[SideEffect]] = Field(default_factory=list, description="Side effects detected")
    unresolved: Optional[list[UnresolvedDependency]] = Field(
        default_factory=list, description="Unresolved dependencies"
    )

    # Dependencies mode fields
    target: Optional[TraceTarget] = Field(None, description="Target for dependency analysis")
    incoming_dependencies: Optional[list[IncomingDependency]] = Field(
        default_factory=list, description="Incoming dependencies"
    )
    outgoing_dependencies: Optional[list[OutgoingDependency]] = Field(
        default_factory=list, description="Outgoing dependencies"
    )
    type_dependencies: Optional[list[TypeDependency]] = Field(default_factory=list, description="Type dependencies")
    state_access: Optional[list[StateAccess]] = Field(default_factory=list, description="State access information")


class DiagnosticHypothesis(BaseModel):
    """A debugging hypothesis with context and next steps"""

    rank: int = Field(..., description="Ranking of this hypothesis (1 = most likely)")
    # NOTE(review): confidence literals are lowercase here but capitalized in
    # DebugHypothesis below — confirm which casing consumers expect.
    confidence: Literal["high", "medium", "low"] = Field(..., description="Confidence level")
    hypothesis: str = Field(..., description="Description of the potential root cause")
    reasoning: str = Field(..., description="Why this hypothesis is plausible")
    next_step: str = Field(..., description="Suggested action to test/validate this hypothesis")


class StructuredDebugResponse(BaseModel):
    """Enhanced debug response with multiple hypotheses"""

    summary: str = Field(..., description="Brief summary of the issue")
    hypotheses: list[DiagnosticHypothesis] = Field(..., description="Ranked list of potential causes")
    immediate_actions: list[str] = Field(
        default_factory=list,
        description="Immediate steps to take regardless of root cause",
    )
    additional_context_needed: Optional[list[str]] = Field(
        default_factory=list,
        description="Additional files or information that would help with analysis",
    )


class DebugHypothesis(BaseModel):
    """A debugging hypothesis with detailed analysis"""

    name: str = Field(..., description="Name/title of the hypothesis")
    confidence: Literal["High", "Medium", "Low"] = Field(..., description="Confidence level")
    root_cause: str = Field(..., description="Technical explanation of the root cause")
    evidence: str = Field(..., description="Logs or code clues supporting this hypothesis")
    correlation: str = Field(..., description="How symptoms map to the cause")
    validation: str = Field(..., description="Quick test to confirm the hypothesis")
    minimal_fix: str = Field(..., description="Smallest change to resolve the issue")
    regression_check: str = Field(..., description="Why this fix is safe")
    file_references: list[str] = Field(default_factory=list, description="File:line format for exact locations")


class DebugAnalysisComplete(BaseModel):
    """Complete debugging analysis with systematic investigation tracking"""

    status: Literal["analysis_complete"] = "analysis_complete"
    investigation_id: str = Field(..., description="Auto-generated unique ID for this investigation")
    summary: str = Field(..., description="Brief description of the problem and its impact")
    investigation_steps: list[str] = Field(..., description="Steps taken during the investigation")
    hypotheses: list[DebugHypothesis] = Field(..., description="Ranked hypotheses with detailed analysis")
    key_findings: list[str] = Field(..., description="Important discoveries made during analysis")
    immediate_actions: list[str] = Field(..., description="Steps to take regardless of which hypothesis is correct")
    recommended_tools: list[str] = Field(default_factory=list, description="Additional tools recommended for analysis")
    prevention_strategy: Optional[str] = Field(
        None, description="Targeted measures to prevent this exact issue from recurring"
    )
    investigation_summary: str = Field(
        ..., description="Comprehensive summary of the complete investigation process and conclusions"
    )


class NoBugFound(BaseModel):
    """Response when thorough investigation finds no concrete evidence of a bug"""

    status: Literal["no_bug_found"] = "no_bug_found"
    summary: str = Field(..., description="Summary of what was thoroughly investigated")
    investigation_steps: list[str] = Field(..., description="Steps taken during the investigation")
    areas_examined: list[str] = Field(..., description="Code areas and potential failure points examined")
    confidence_level: Literal["High", "Medium", "Low"] = Field(
        ..., description="Confidence level in the no-bug finding"
    )
    alternative_explanations: list[str] = Field(
        ..., description="Possible alternative explanations for reported symptoms"
    )
    recommended_questions: list[str] = Field(..., description="Questions to clarify the issue with the user")
    next_steps: list[str] = Field(..., description="Suggested actions to better understand the reported issue")


# Registry mapping status strings to their corresponding Pydantic models.
# Keys must stay in sync with the ToolOutput.status Literal above
# (note: "success", "error", and "continuation_available" intentionally have
# no special model).
SPECIAL_STATUS_MODELS = {
    "files_required_to_continue": FilesNeededRequest,
    "full_codereview_required": FullCodereviewRequired,
    "focused_review_required": FocusedReviewRequired,
    "test_sample_needed": TestSampleNeeded,
    "more_tests_required": MoreTestsRequired,
    "refactor_analysis_complete": RefactorAnalysisComplete,
    "trace_complete": TraceComplete,
    "resend_prompt": ResendPromptRequest,
    "code_too_large": CodeTooLargeRequest,
    "analysis_complete": DebugAnalysisComplete,
    "no_bug_found": NoBugFound,
}
diff --git a/tools/shared/__init__.py b/tools/shared/__init__.py
new file mode 100644
index 0000000..e486150
--- /dev/null
+++ b/tools/shared/__init__.py
@@ -0,0 +1,19 @@
"""
Shared infrastructure for Zen MCP tools.

This module contains the core base classes and utilities that are shared
across all tool types. It provides the foundation for the tool architecture.
+""" + +from .base_models import BaseWorkflowRequest, ConsolidatedFindings, ToolRequest, WorkflowRequest +from .base_tool import BaseTool +from .schema_builders import SchemaBuilder + +__all__ = [ + "BaseTool", + "ToolRequest", + "BaseWorkflowRequest", + "WorkflowRequest", + "ConsolidatedFindings", + "SchemaBuilder", +] diff --git a/tools/shared/base_models.py b/tools/shared/base_models.py new file mode 100644 index 0000000..a6c7a3c --- /dev/null +++ b/tools/shared/base_models.py @@ -0,0 +1,165 @@ +""" +Base models for Zen MCP tools. + +This module contains the shared Pydantic models used across all tools, +extracted to avoid circular imports and promote code reuse. + +Key Models: +- ToolRequest: Base request model for all tools +- WorkflowRequest: Extended request model for workflow-based tools +- ConsolidatedFindings: Model for tracking workflow progress +""" + +import logging +from typing import Optional + +from pydantic import BaseModel, Field, field_validator + +logger = logging.getLogger(__name__) + + +# Shared field descriptions to avoid duplication +COMMON_FIELD_DESCRIPTIONS = { + "model": "Model to run. Supply a name if requested by the user or stay in auto mode. When in auto mode, use `listmodels` tool for model discovery.", + "temperature": "0 = deterministic · 1 = creative.", + "thinking_mode": "Reasoning depth: minimal, low, medium, high, or max.", + "continuation_id": ( + "Unique thread continuation ID for multi-turn conversations. Works across different tools. " + "ALWAYS reuse the last continuation_id you were given—this preserves full conversation context, " + "files, and findings so the agent can resume seamlessly." 
+ ), + "images": "Optional absolute image paths or base64 blobs for visual context.", + "files": "Optional absolute file or folder paths (do not shorten).", +} + +# Workflow-specific field descriptions +WORKFLOW_FIELD_DESCRIPTIONS = { + "step": "Current work step content and findings from your overall work", + "step_number": "Current step number in work sequence (starts at 1)", + "total_steps": "Estimated total steps needed to complete work", + "next_step_required": "Whether another work step is needed. When false, aim to reduce total_steps to match step_number to avoid mismatch.", + "findings": "Important findings, evidence and insights discovered in this step", + "files_checked": "List of files examined during this work step", + "relevant_files": "Files identified as relevant to issue/goal (FULL absolute paths to real files/folders - DO NOT SHORTEN)", + "relevant_context": "Methods/functions identified as involved in the issue", + "issues_found": "Issues identified with severity levels during work", + "confidence": ( + "Confidence level: exploring (just starting), low (early investigation), " + "medium (some evidence), high (strong evidence), very_high (comprehensive understanding), " + "almost_certain (near complete confidence), certain (100% confidence locally - no external validation needed)" + ), + "hypothesis": "Current theory about issue/goal based on work", + "backtrack_from_step": "Step number to backtrack from if work needs revision", + "use_assistant_model": ( + "Use assistant model for expert analysis after workflow steps. " + "False skips expert analysis, relies solely on Claude's investigation. " + "Defaults to True for comprehensive validation." + ), +} + + +class ToolRequest(BaseModel): + """ + Base request model for all Zen MCP tools. + + This model defines common fields that all tools accept, including + model selection, temperature control, and conversation threading. + Tool-specific request models should inherit from this class. 
+ """ + + # Model configuration + model: Optional[str] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["model"]) + temperature: Optional[float] = Field(None, ge=0.0, le=1.0, description=COMMON_FIELD_DESCRIPTIONS["temperature"]) + thinking_mode: Optional[str] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["thinking_mode"]) + + # Conversation support + continuation_id: Optional[str] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["continuation_id"]) + + # Visual context + images: Optional[list[str]] = Field(None, description=COMMON_FIELD_DESCRIPTIONS["images"]) + + +class BaseWorkflowRequest(ToolRequest): + """ + Minimal base request model for workflow tools. + + This provides only the essential fields that ALL workflow tools need, + allowing for maximum flexibility in tool-specific implementations. + """ + + # Core workflow fields that ALL workflow tools need + step: str = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["step"]) + step_number: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["step_number"]) + total_steps: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["total_steps"]) + next_step_required: bool = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"]) + + +class WorkflowRequest(BaseWorkflowRequest): + """ + Extended request model for workflow-based tools. + + This model extends ToolRequest with fields specific to the workflow + pattern, where tools perform multi-step work with forced pauses between steps. 
+ + Used by: debug, precommit, codereview, refactor, thinkdeep, analyze + """ + + # Required workflow fields + step: str = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["step"]) + step_number: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["step_number"]) + total_steps: int = Field(..., ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["total_steps"]) + next_step_required: bool = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"]) + + # Work tracking fields + findings: str = Field(..., description=WORKFLOW_FIELD_DESCRIPTIONS["findings"]) + files_checked: list[str] = Field(default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["files_checked"]) + relevant_files: list[str] = Field(default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"]) + relevant_context: list[str] = Field( + default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"] + ) + issues_found: list[dict] = Field(default_factory=list, description=WORKFLOW_FIELD_DESCRIPTIONS["issues_found"]) + confidence: str = Field("low", description=WORKFLOW_FIELD_DESCRIPTIONS["confidence"]) + + # Optional workflow fields + hypothesis: Optional[str] = Field(None, description=WORKFLOW_FIELD_DESCRIPTIONS["hypothesis"]) + backtrack_from_step: Optional[int] = Field( + None, ge=1, description=WORKFLOW_FIELD_DESCRIPTIONS["backtrack_from_step"] + ) + use_assistant_model: Optional[bool] = Field(True, description=WORKFLOW_FIELD_DESCRIPTIONS["use_assistant_model"]) + + @field_validator("files_checked", "relevant_files", "relevant_context", mode="before") + @classmethod + def convert_string_to_list(cls, v): + """Convert string inputs to empty lists to handle malformed inputs gracefully.""" + if isinstance(v, str): + logger.warning(f"Field received string '{v}' instead of list, converting to empty list") + return [] + return v + + +class ConsolidatedFindings(BaseModel): + """ + Model for tracking consolidated findings across workflow steps. 
+ + This model accumulates findings, files, methods, and issues + discovered during multi-step work. It's used by + BaseWorkflowMixin to track progress across workflow steps. + """ + + files_checked: set[str] = Field(default_factory=set, description="All files examined across all steps") + relevant_files: set[str] = Field( + default_factory=set, + description="Subset of files_checked identified as relevant for work at hand", + ) + relevant_context: set[str] = Field( + default_factory=set, description="All methods/functions identified during overall work" + ) + findings: list[str] = Field(default_factory=list, description="Chronological findings from each work step") + hypotheses: list[dict] = Field(default_factory=list, description="Evolution of hypotheses across steps") + issues_found: list[dict] = Field(default_factory=list, description="All issues with severity levels") + images: list[str] = Field(default_factory=list, description="Images collected during work") + confidence: str = Field("low", description="Latest confidence level from steps") + + +# Tool-specific field descriptions are now declared in each tool file +# This keeps concerns separated and makes each tool self-contained diff --git a/tools/shared/base_tool.py b/tools/shared/base_tool.py new file mode 100644 index 0000000..4872026 --- /dev/null +++ b/tools/shared/base_tool.py @@ -0,0 +1,1399 @@ +""" +Core Tool Infrastructure for Zen MCP Tools + +This module provides the fundamental base class for all tools: +- BaseTool: Abstract base class defining the tool interface + +The BaseTool class defines the core contract that tools must implement and provides +common functionality for request validation, error handling, model management, +conversation handling, file processing, and response formatting. 
+""" + +import logging +import os +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +from mcp.types import TextContent + +if TYPE_CHECKING: + from tools.models import ToolModelCategory + +from config import MCP_PROMPT_SIZE_LIMIT +from providers import ModelProvider, ModelProviderRegistry +from utils import check_token_limit +from utils.conversation_memory import ( + ConversationTurn, + get_conversation_file_list, + get_thread, +) +from utils.file_utils import read_file_content, read_files + +# Import models from tools.models for compatibility +try: + from tools.models import SPECIAL_STATUS_MODELS, ContinuationOffer, ToolOutput +except ImportError: + # Fallback in case models haven't been set up yet + SPECIAL_STATUS_MODELS = {} + ContinuationOffer = None + ToolOutput = None + +logger = logging.getLogger(__name__) + + +class BaseTool(ABC): + """ + Abstract base class for all Zen MCP tools. + + This class defines the interface that all tools must implement and provides + common functionality for request handling, model creation, and response formatting. + + CONVERSATION-AWARE FILE PROCESSING: + This base class implements the sophisticated dual prioritization strategy for + conversation-aware file handling across all tools: + + 1. FILE DEDUPLICATION WITH NEWEST-FIRST PRIORITY: + - When same file appears in multiple conversation turns, newest reference wins + - Prevents redundant file embedding while preserving most recent file state + - Cross-tool file tracking ensures consistent behavior across analyze → codereview → debug + + 2. CONVERSATION CONTEXT INTEGRATION: + - All tools receive enhanced prompts with conversation history via reconstruct_thread_context() + - File references from previous turns are preserved and accessible + - Cross-tool knowledge transfer maintains full context without manual file re-specification + + 3. 
TOKEN-AWARE FILE EMBEDDING: + - Respects model-specific token allocation budgets from ModelContext + - Prioritizes conversation history, then newest files, then remaining content + - Graceful degradation when token limits are approached + + 4. STATELESS-TO-STATEFUL BRIDGING: + - Tools operate on stateless MCP requests but access full conversation state + - Conversation memory automatically injected via continuation_id parameter + - Enables natural AI-to-AI collaboration across tool boundaries + + To create a new tool: + 1. Create a new class that inherits from BaseTool + 2. Implement all abstract methods + 3. Define a request model that inherits from ToolRequest + 4. Register the tool in server.py's TOOLS dictionary + """ + + # Class-level cache for OpenRouter registry to avoid multiple loads + _openrouter_registry_cache = None + + @classmethod + def _get_openrouter_registry(cls): + """Get cached OpenRouter registry instance, creating if needed.""" + # Use BaseTool class directly to ensure cache is shared across all subclasses + if BaseTool._openrouter_registry_cache is None: + from providers.openrouter_registry import OpenRouterModelRegistry + + BaseTool._openrouter_registry_cache = OpenRouterModelRegistry() + logger.debug("Created cached OpenRouter registry instance") + return BaseTool._openrouter_registry_cache + + def __init__(self): + # Cache tool metadata at initialization to avoid repeated calls + self.name = self.get_name() + self.description = self.get_description() + self.default_temperature = self.get_default_temperature() + # Tool initialization complete + + @abstractmethod + def get_name(self) -> str: + """ + Return the unique name identifier for this tool. + + This name is used by MCP clients to invoke the tool and must be + unique across all registered tools. 
+ + Returns: + str: The tool's unique name (e.g., "review_code", "analyze") + """ + pass + + @abstractmethod + def get_description(self) -> str: + """ + Return a detailed description of what this tool does. + + This description is shown to MCP clients (like Claude) to help them + understand when and how to use the tool. It should be comprehensive + and include trigger phrases. + + Returns: + str: Detailed tool description with usage examples + """ + pass + + @abstractmethod + def get_input_schema(self) -> dict[str, Any]: + """ + Return the JSON Schema that defines this tool's parameters. + + This schema is used by MCP clients to validate inputs before + sending requests. It should match the tool's request model. + + Returns: + Dict[str, Any]: JSON Schema object defining required and optional parameters + """ + pass + + @abstractmethod + def get_system_prompt(self) -> str: + """ + Return the system prompt that configures the AI model's behavior. + + This prompt sets the context and instructions for how the model + should approach the task. It's prepended to the user's request. + + Returns: + str: System prompt with role definition and instructions + """ + pass + + def get_annotations(self) -> Optional[dict[str, Any]]: + """ + Return optional annotations for this tool. + + Annotations provide hints about tool behavior without being security-critical. + They help MCP clients make better decisions about tool usage. + + Returns: + Optional[dict]: Dictionary with annotation fields like readOnlyHint, destructiveHint, etc. + Returns None if no annotations are needed. + """ + return None + + def requires_model(self) -> bool: + """ + Return whether this tool requires AI model access. + + Tools that override execute() to do pure data processing (like planner) + should return False to skip model resolution at the MCP boundary. 
+ + Returns: + bool: True if tool needs AI model access (default), False for data-only tools + """ + return True + + def is_effective_auto_mode(self) -> bool: + """ + Check if we're in effective auto mode for schema generation. + + This determines whether the model parameter should be required in the tool schema. + Used at initialization time when schemas are generated. + + Returns: + bool: True if model parameter should be required in the schema + """ + from config import DEFAULT_MODEL + from providers.registry import ModelProviderRegistry + + # Case 1: Explicit auto mode + if DEFAULT_MODEL.lower() == "auto": + return True + + # Case 2: Model not available (fallback to auto mode) + if DEFAULT_MODEL.lower() != "auto": + provider = ModelProviderRegistry.get_provider_for_model(DEFAULT_MODEL) + if not provider: + return True + + return False + + def _should_require_model_selection(self, model_name: str) -> bool: + """ + Check if we should require the CLI to select a model at runtime. + + This is called during request execution to determine if we need + to return an error asking the CLI to provide a model parameter. + + Args: + model_name: The model name from the request or DEFAULT_MODEL + + Returns: + bool: True if we should require model selection + """ + # Case 1: Model is explicitly "auto" + if model_name.lower() == "auto": + return True + + # Case 2: Requested model is not available + from providers.registry import ModelProviderRegistry + + provider = ModelProviderRegistry.get_provider_for_model(model_name) + if not provider: + logger = logging.getLogger(f"tools.{self.name}") + logger.warning(f"Model '{model_name}' is not available with current API keys. Requiring model selection.") + return True + + return False + + def _get_available_models(self) -> list[str]: + """ + Get list of models available from enabled providers. + + Only returns models from providers that have valid API keys configured. 
+ This fixes the namespace collision bug where models from disabled providers + were shown to the CLI, causing routing conflicts. + + Returns: + List of model names from enabled providers only + """ + from providers.registry import ModelProviderRegistry + + # Get models from enabled providers only (those with valid API keys) + all_models = ModelProviderRegistry.get_available_model_names() + + # Add OpenRouter models if OpenRouter is configured + openrouter_key = os.getenv("OPENROUTER_API_KEY") + if openrouter_key and openrouter_key != "your_openrouter_api_key_here": + try: + registry = self._get_openrouter_registry() + # Add all aliases from the registry (includes OpenRouter cloud models) + for alias in registry.list_aliases(): + if alias not in all_models: + all_models.append(alias) + except Exception as e: + import logging + + logging.debug(f"Failed to add OpenRouter models to enum: {e}") + + # Add custom models if custom API is configured + custom_url = os.getenv("CUSTOM_API_URL") + if custom_url: + try: + registry = self._get_openrouter_registry() + # Find all custom models (is_custom=true) + for alias in registry.list_aliases(): + config = registry.resolve(alias) + # Check if this is a custom model that requires custom endpoints + if config and config.is_custom: + if alias not in all_models: + all_models.append(alias) + except Exception as e: + import logging + + logging.debug(f"Failed to add custom models to enum: {e}") + + # Remove duplicates while preserving order + seen = set() + unique_models = [] + for model in all_models: + if model not in seen: + seen.add(model) + unique_models.append(model) + + return unique_models + + def _format_available_models_list(self) -> str: + """Return a human-friendly list of available models or guidance when none found.""" + + available_models = self._get_available_models() + if not available_models: + return ( + "No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option. 
" + "If the user requested a specific model, respond with this notice instead of substituting another model." + ) + return ", ".join(available_models) + + def _build_model_unavailable_message(self, model_name: str) -> str: + """Compose a consistent error message for unavailable model scenarios.""" + + tool_category = self.get_model_category() + suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category) + available_models_text = self._format_available_models_list() + + return ( + f"Model '{model_name}' is not available with current API keys. " + f"Available models: {available_models_text}. " + f"Suggested model for {self.get_name()}: '{suggested_model}' " + f"(category: {tool_category.value}). If the user explicitly requested a model, you MUST use that exact name or report this error back—do not substitute another model." + ) + + def _build_auto_mode_required_message(self) -> str: + """Compose the auto-mode prompt when an explicit model selection is required.""" + + tool_category = self.get_model_category() + suggested_model = ModelProviderRegistry.get_preferred_fallback_model(tool_category) + available_models_text = self._format_available_models_list() + + return ( + "Model parameter is required in auto mode. " + f"Available models: {available_models_text}. " + f"Suggested model for {self.get_name()}: '{suggested_model}' " + f"(category: {tool_category.value}). When the user names a model, relay that exact name—never swap in another option." + ) + + def get_model_field_schema(self) -> dict[str, Any]: + """ + Generate the model field schema based on auto mode configuration. + + When auto mode is enabled, the model parameter becomes required + and includes detailed descriptions of each model's capabilities. 
+ + Returns: + Dict containing the model field JSON schema + """ + + from config import DEFAULT_MODEL + + # Use the centralized effective auto mode check + if self.is_effective_auto_mode(): + description = ( + "Currently in auto model selection mode. CRITICAL: When the user names a model, you MUST use that exact name unless the server rejects it. " + "If no model is provided, you may call the `listmodels` tool to review options and select an appropriate match." + ) + return { + "type": "string", + "description": description, + } + + description = ( + f"The default model is '{DEFAULT_MODEL}'. Override only when the user explicitly requests a different model, and use that exact name. " + "If the requested model fails validation, surface the server error instead of substituting another model. When unsure, call the `listmodels` tool for details." + ) + + return { + "type": "string", + "description": description, + } + + def get_default_temperature(self) -> float: + """ + Return the default temperature setting for this tool. + + Override this method to set tool-specific temperature defaults. + Lower values (0.0-0.3) for analytical tasks, higher (0.7-1.0) for creative tasks. + + Returns: + float: Default temperature between 0.0 and 1.0 + """ + return 0.5 + + def wants_line_numbers_by_default(self) -> bool: + """ + Return whether this tool wants line numbers added to code files by default. + + By default, ALL tools get line numbers for precise code references. + Line numbers are essential for accurate communication about code locations. + + Returns: + bool: True if line numbers should be added by default for this tool + """ + return True # All tools get line numbers by default for consistency + + def get_default_thinking_mode(self) -> str: + """ + Return the default thinking mode for this tool. + + Thinking mode controls computational budget for reasoning. + Override for tools that need more or less reasoning depth. 
+ + Returns: + str: One of "minimal", "low", "medium", "high", "max" + """ + return "medium" # Default to medium thinking for better reasoning + + def get_model_category(self) -> "ToolModelCategory": + """ + Return the model category for this tool. + + Model category influences which model is selected in auto mode. + Override to specify whether your tool needs extended reasoning, + fast response, or balanced capabilities. + + Returns: + ToolModelCategory: Category that influences model selection + """ + from tools.models import ToolModelCategory + + return ToolModelCategory.BALANCED + + @abstractmethod + def get_request_model(self): + """ + Return the Pydantic model class used for validating requests. + + This model should inherit from ToolRequest and define all + parameters specific to this tool. + + Returns: + Type[ToolRequest]: The request model class + """ + pass + + def validate_file_paths(self, request) -> Optional[str]: + """ + Validate that all file paths in the request are absolute. + + This is a critical security function that prevents path traversal attacks + and ensures all file access is properly controlled. All file paths must + be absolute to avoid ambiguity and security issues. + + Args: + request: The validated request object + + Returns: + Optional[str]: Error message if validation fails, None if all paths are valid + """ + # Only validate files/paths if they exist in the request + file_fields = [ + "files", + "file", + "path", + "directory", + "notebooks", + "test_examples", + "style_guide_examples", + "files_checked", + "relevant_files", + ] + + for field_name in file_fields: + if hasattr(request, field_name): + field_value = getattr(request, field_name) + if field_value is None: + continue + + # Handle both single paths and lists of paths + paths_to_check = field_value if isinstance(field_value, list) else [field_value] + + for path in paths_to_check: + if path and not os.path.isabs(path): + return f"All file paths must be FULL absolute paths. 
Invalid path: '{path}'" + + return None + + def _validate_token_limit(self, content: str, content_type: str = "Content") -> None: + """ + Validate that content doesn't exceed the MCP prompt size limit. + + Args: + content: The content to validate + content_type: Description of the content type for error messages + + Raises: + ValueError: If content exceeds size limit + """ + is_valid, token_count = check_token_limit(content, MCP_PROMPT_SIZE_LIMIT) + if not is_valid: + error_msg = f"~{token_count:,} tokens. Maximum is {MCP_PROMPT_SIZE_LIMIT:,} tokens." + logger.error(f"{self.name} tool {content_type.lower()} validation failed: {error_msg}") + raise ValueError(f"{content_type} too large: {error_msg}") + + logger.debug(f"{self.name} tool {content_type.lower()} token validation passed: {token_count:,} tokens") + + def get_model_provider(self, model_name: str) -> ModelProvider: + """ + Get the appropriate model provider for the given model name. + + This method performs runtime validation to ensure the requested model + is actually available with the current API key configuration. + + Args: + model_name: Name of the model to get provider for + + Returns: + ModelProvider: The provider instance for the model + + Raises: + ValueError: If the model is not available or provider not found + """ + try: + provider = ModelProviderRegistry.get_provider_for_model(model_name) + if not provider: + logger.error(f"No provider found for model '{model_name}' in {self.name} tool") + raise ValueError(self._build_model_unavailable_message(model_name)) + + return provider + except Exception as e: + logger.error(f"Failed to get provider for model '{model_name}' in {self.name} tool: {e}") + raise + + # === CONVERSATION AND FILE HANDLING METHODS === + + def get_conversation_embedded_files(self, continuation_id: Optional[str]) -> list[str]: + """ + Get list of files already embedded in conversation history. 
    def get_conversation_embedded_files(self, continuation_id: Optional[str]) -> list[str]:
        """
        Get the list of files already embedded in conversation history.

        This method returns the list of files that have already been embedded
        in the conversation history for a given continuation thread. Tools can
        use this to avoid re-embedding files that are already available in the
        conversation context.

        Args:
            continuation_id: Thread continuation ID, or None for new conversations

        Returns:
            list[str]: List of file paths already embedded in conversation history
        """
        if not continuation_id:
            # New conversation, no files embedded yet
            return []

        thread_context = get_thread(continuation_id)
        if not thread_context:
            # Thread not found, no files embedded
            return []

        embedded_files = get_conversation_file_list(thread_context)
        logger.debug(f"[FILES] {self.name}: Found {len(embedded_files)} embedded files")
        return embedded_files

    def filter_new_files(self, requested_files: list[str], continuation_id: Optional[str]) -> list[str]:
        """
        Filter out files that are already embedded in conversation history.

        This method prevents duplicate file embeddings by filtering out files that have
        already been embedded in the conversation history. This optimizes token usage
        while ensuring tools still have logical access to all requested files through
        conversation history references.

        Args:
            requested_files: List of files requested for current tool execution
            continuation_id: Thread continuation ID, or None for new conversations

        Returns:
            list[str]: List of files that need to be embedded (not already in history)
        """
        logger.debug(f"[FILES] {self.name}: Filtering {len(requested_files)} requested files")

        if not continuation_id:
            # New conversation, all files are new
            logger.debug(f"[FILES] {self.name}: New conversation, all {len(requested_files)} files are new")
            return requested_files

        try:
            embedded_files = set(self.get_conversation_embedded_files(continuation_id))
            logger.debug(f"[FILES] {self.name}: Found {len(embedded_files)} embedded files in conversation")

            # Safety check: If no files are marked as embedded but we have a continuation_id,
            # this might indicate an issue with conversation history. Be conservative.
            if not embedded_files:
                logger.debug(f"{self.name} tool: No files found in conversation history for thread {continuation_id}")
                logger.debug(
                    f"[FILES] {self.name}: No embedded files found, returning all {len(requested_files)} requested files"
                )
                return requested_files

            # Return only files that haven't been embedded yet
            new_files = [f for f in requested_files if f not in embedded_files]
            logger.debug(
                f"[FILES] {self.name}: After filtering: {len(new_files)} new files, {len(requested_files) - len(new_files)} already embedded"
            )
            logger.debug(f"[FILES] {self.name}: New files to embed: {new_files}")

            # Log filtering results for debugging
            if len(new_files) < len(requested_files):
                skipped = [f for f in requested_files if f in embedded_files]
                logger.debug(
                    f"{self.name} tool: Filtering {len(skipped)} files already in conversation history: {', '.join(skipped)}"
                )
                logger.debug(f"[FILES] {self.name}: Skipped (already embedded): {skipped}")

            return new_files

        except Exception as e:
            # If there's any issue with conversation history lookup, be conservative
            # and include all files rather than risk losing access to needed files
            logger.warning(f"{self.name} tool: Error checking conversation history for {continuation_id}: {e}")
            logger.warning(f"{self.name} tool: Including all requested files as fallback")
            logger.debug(
                f"[FILES] {self.name}: Exception in filter_new_files, returning all {len(requested_files)} files as fallback"
            )
            return requested_files

    def format_conversation_turn(self, turn: ConversationTurn) -> list[str]:
        """
        Format a conversation turn for display in conversation history.

        Tools can override this to provide custom formatting for their responses
        while maintaining the standard structure for cross-tool compatibility.

        This method is called by build_conversation_history when reconstructing
        conversation context, allowing each tool to control how its responses
        appear in subsequent conversation turns.

        Args:
            turn: The conversation turn to format (from utils.conversation_memory)

        Returns:
            list[str]: Lines of formatted content for this turn

        Example:
            Default implementation returns:
            ["Files used in this turn: file1.py, file2.py", "", "Response content..."]

        Tools can override to add custom sections, formatting, or metadata display.
        """
        parts = []

        # Add files context if present
        if turn.files:
            parts.append(f"Files used in this turn: {', '.join(turn.files)}")
            parts.append("")  # Empty line for readability

        # Add the actual content
        parts.append(turn.content)

        return parts
    def handle_prompt_file(self, files: Optional[list[str]]) -> tuple[Optional[str], Optional[list[str]]]:
        """
        Check for and handle prompt.txt in the files list.

        If prompt.txt is found, reads its content and removes it from the files list.
        This file is treated specially as the main prompt, not as an embedded file.

        This mechanism allows us to work around MCP's ~25K token limit by having
        the CLI save large prompts to a file, effectively using the file transfer
        mechanism to bypass token constraints while preserving response capacity.

        Args:
            files: List of file paths (will be translated for current environment)

        Returns:
            tuple: (prompt_content, updated_files_list)
        """
        if not files:
            return None, files

        prompt_content = None
        updated_files = []

        for file_path in files:

            # Check if the filename is exactly "prompt.txt"
            # This ensures we don't match files like "myprompt.txt" or "prompt.txt.bak"
            if os.path.basename(file_path) == "prompt.txt":
                try:
                    # Read prompt.txt content and extract just the text
                    content, _ = read_file_content(file_path)
                    # Extract the content between the file markers
                    if "--- BEGIN FILE:" in content and "--- END FILE:" in content:
                        lines = content.split("\n")
                        in_content = False
                        content_lines = []
                        for line in lines:
                            if line.startswith("--- BEGIN FILE:"):
                                in_content = True
                                continue
                            elif line.startswith("--- END FILE:"):
                                break
                            elif in_content:
                                content_lines.append(line)
                        prompt_content = "\n".join(content_lines)
                    else:
                        # Fallback: if it's already raw content (from tests or direct input)
                        # and doesn't have error markers, use it directly
                        if not content.startswith("\n--- ERROR"):
                            prompt_content = content
                        else:
                            prompt_content = None
                except Exception:
                    # If we can't read the file, we'll just skip it
                    # The error will be handled elsewhere
                    pass
            else:
                # Keep the original path in the files list (will be translated later by read_files)
                updated_files.append(file_path)

        return prompt_content, updated_files if updated_files else None

    def get_prompt_content_for_size_validation(self, user_content: str) -> str:
        """
        Get the content that should be validated for MCP prompt size limits.

        This hook method allows tools to specify what content should be checked
        against the MCP transport size limit. By default, it returns the user content,
        but can be overridden to exclude conversation history when needed.

        Args:
            user_content: The user content that would normally be validated

        Returns:
            The content that should actually be validated for size limits
        """
        # Default implementation: validate the full user content
        return user_content
    def check_prompt_size(self, text: str) -> Optional[dict[str, Any]]:
        """
        Check if USER INPUT text is too large for MCP transport boundary.

        IMPORTANT: This method should ONLY be used to validate user input that crosses
        the CLI ↔ MCP Server transport boundary. It should NOT be used to limit
        internal MCP Server operations.

        Args:
            text: The user input text to check (NOT internal prompt content)

        Returns:
            Optional[Dict[str, Any]]: Response asking for file handling if too large, None otherwise
        """
        if text and len(text) > MCP_PROMPT_SIZE_LIMIT:
            return {
                "status": "resend_prompt",
                "content": (
                    f"MANDATORY ACTION REQUIRED: The prompt is too large for MCP's token limits (>{MCP_PROMPT_SIZE_LIMIT:,} characters). "
                    "YOU MUST IMMEDIATELY save the prompt text to a temporary file named 'prompt.txt' in the working directory. "
                    "DO NOT attempt to shorten or modify the prompt. SAVE IT AS-IS to 'prompt.txt'. "
                    "Then resend the request with the absolute file path to 'prompt.txt' in the files parameter (must be FULL absolute path - DO NOT SHORTEN), "
                    "along with any other files you wish to share as context. Leave the prompt text itself empty or very brief in the new request. "
                    "This is the ONLY way to handle large prompts - you MUST follow these exact steps."
                ),
                "content_type": "text",
                "metadata": {
                    "prompt_size": len(text),
                    "limit": MCP_PROMPT_SIZE_LIMIT,
                    "instructions": "MANDATORY: Save prompt to 'prompt.txt' in current folder and include absolute path in files parameter. DO NOT modify or shorten the prompt.",
                },
            }
        return None
"Code", "New files") + max_tokens: Maximum tokens to use (defaults to remaining budget or model-specific content allocation) + reserve_tokens: Tokens to reserve for additional prompt content (default 1K) + remaining_budget: Remaining token budget after conversation history (from server.py) + arguments: Original tool arguments (used to extract _remaining_tokens if available) + model_context: Model context object with all model information including token allocation + + Returns: + tuple[str, list[str]]: (formatted_file_content, actually_processed_files) + - formatted_file_content: Formatted file content string ready for prompt inclusion + - actually_processed_files: List of individual file paths that were actually read and embedded + (directories are expanded to individual files) + """ + if not request_files: + return "", [] + + # Extract remaining budget from arguments if available + if remaining_budget is None: + # Use provided arguments or fall back to stored arguments from execute() + args_to_use = arguments or getattr(self, "_current_arguments", {}) + remaining_budget = args_to_use.get("_remaining_tokens") + + # Use remaining budget if provided, otherwise fall back to max_tokens or model-specific default + if remaining_budget is not None: + effective_max_tokens = remaining_budget - reserve_tokens + elif max_tokens is not None: + effective_max_tokens = max_tokens - reserve_tokens + else: + # Use model_context for token allocation + if not model_context: + # Try to get from stored attributes as fallback + model_context = getattr(self, "_model_context", None) + if not model_context: + logger.error( + f"[FILES] {self.name}: _prepare_file_content_for_prompt called without model_context. " + "This indicates an incorrect call sequence in the tool's implementation." + ) + raise RuntimeError("Model context not provided for file preparation.") + + # This is now the single source of truth for token allocation. 
+ try: + token_allocation = model_context.calculate_token_allocation() + # Standardize on `file_tokens` for consistency and correctness. + effective_max_tokens = token_allocation.file_tokens - reserve_tokens + logger.debug( + f"[FILES] {self.name}: Using model context for {model_context.model_name}: " + f"{token_allocation.file_tokens:,} file tokens from {token_allocation.total_tokens:,} total" + ) + except Exception as e: + logger.error( + f"[FILES] {self.name}: Failed to calculate token allocation from model context: {e}", exc_info=True + ) + # If the context exists but calculation fails, we still need to prevent a crash. + # A loud error is logged, and we fall back to a safe default. + effective_max_tokens = 100_000 - reserve_tokens + + # Ensure we have a reasonable minimum budget + effective_max_tokens = max(1000, effective_max_tokens) + + files_to_embed = self.filter_new_files(request_files, continuation_id) + logger.debug(f"[FILES] {self.name}: Will embed {len(files_to_embed)} files after filtering") + + # Log the specific files for debugging/testing + if files_to_embed: + logger.info( + f"[FILE_PROCESSING] {self.name} tool will embed new files: {', '.join([os.path.basename(f) for f in files_to_embed])}" + ) + else: + logger.info( + f"[FILE_PROCESSING] {self.name} tool: No new files to embed (all files already in conversation history)" + ) + + content_parts = [] + actually_processed_files = [] + + # Read content of new files only + if files_to_embed: + logger.debug(f"{self.name} tool embedding {len(files_to_embed)} new files: {', '.join(files_to_embed)}") + logger.debug( + f"[FILES] {self.name}: Starting file embedding with token budget {effective_max_tokens + reserve_tokens:,}" + ) + try: + # Before calling read_files, expand directories to get individual file paths + from utils.file_utils import expand_paths + + expanded_files = expand_paths(files_to_embed) + logger.debug( + f"[FILES] {self.name}: Expanded {len(files_to_embed)} paths to 
{len(expanded_files)} individual files" + ) + + file_content = read_files( + files_to_embed, + max_tokens=effective_max_tokens + reserve_tokens, + reserve_tokens=reserve_tokens, + include_line_numbers=self.wants_line_numbers_by_default(), + ) + # Note: No need to validate against MCP_PROMPT_SIZE_LIMIT here + # read_files already handles token-aware truncation based on model's capabilities + content_parts.append(file_content) + + # Track the expanded files as actually processed + actually_processed_files.extend(expanded_files) + + # Estimate tokens for debug logging + from utils.token_utils import estimate_tokens + + content_tokens = estimate_tokens(file_content) + logger.debug( + f"{self.name} tool successfully embedded {len(files_to_embed)} files ({content_tokens:,} tokens)" + ) + logger.debug(f"[FILES] {self.name}: Successfully embedded files - {content_tokens:,} tokens used") + logger.debug( + f"[FILES] {self.name}: Actually processed {len(actually_processed_files)} individual files" + ) + except Exception as e: + logger.error(f"{self.name} tool failed to embed files {files_to_embed}: {type(e).__name__}: {e}") + logger.debug(f"[FILES] {self.name}: File embedding failed - {type(e).__name__}: {e}") + raise + else: + logger.debug(f"[FILES] {self.name}: No files to embed after filtering") + + # Generate note about files already in conversation history + if continuation_id and len(files_to_embed) < len(request_files): + embedded_files = self.get_conversation_embedded_files(continuation_id) + skipped_files = [f for f in request_files if f in embedded_files] + if skipped_files: + logger.debug( + f"{self.name} tool skipping {len(skipped_files)} files already in conversation history: {', '.join(skipped_files)}" + ) + logger.debug(f"[FILES] {self.name}: Adding note about {len(skipped_files)} skipped files") + if content_parts: + content_parts.append("\n\n") + note_lines = [ + "--- NOTE: Additional files referenced in conversation history ---", + "The following files are 
    def get_websearch_instruction(self, use_websearch: bool, tool_specific: Optional[str] = None) -> str:
        """
        Generate standardized web search instruction based on the use_websearch parameter.

        Args:
            use_websearch: Whether web search is enabled
            tool_specific: Optional tool-specific search guidance

        Returns:
            str: Web search instruction to append to prompt, or empty string
        """
        if not use_websearch:
            return ""

        base_instruction = """

WEB SEARCH CAPABILITY: You can request the calling agent to perform web searches to enhance your analysis with current information!

IMPORTANT: When you identify areas where web searches would significantly improve your response (such as checking current documentation, finding recent solutions, verifying best practices, or gathering community insights), you MUST explicitly instruct the agent to perform specific web searches and then respond back using the continuation_id from this response to continue the analysis.

Use clear, direct language based on the value of the search:

For valuable supplementary information: "Please perform a web search on '[specific topic/query]' and then continue this analysis using the continuation_id from this response if you find relevant information."

For important missing information: "Please search for '[specific topic/query]' and respond back with the findings using the continuation_id from this response - this information is needed to provide a complete analysis."

For critical/essential information: "SEARCH REQUIRED: Please immediately perform a web search on '[specific topic/query]' and respond back with the results using the continuation_id from this response. Cannot provide accurate analysis without this current information."

This ensures you get the most current and comprehensive information while maintaining conversation context through the continuation_id."""

        if tool_specific:
            return f"""{base_instruction}

{tool_specific}

When recommending searches, be specific about what information you need and why it would improve your analysis."""

        # Default instruction for all tools
        return f"""{base_instruction}

Consider requesting searches for:
- Current documentation and API references
- Recent best practices and patterns
- Known issues and community solutions
- Framework updates and compatibility
- Security advisories and patches
- Performance benchmarks and optimizations

When recommending searches, be specific about what information you need and why it would improve your analysis. Always remember to instruct agent to use the continuation_id from this response when providing search results."""

    def get_language_instruction(self) -> str:
        """
        Generate language instruction based on LOCALE configuration.

        Returns:
            str: Language instruction to prepend to prompt, or empty string if
            no locale set
        """
        # Read LOCALE directly from environment to support dynamic changes
        # This allows tests to modify os.environ["LOCALE"] and see the changes
        import os

        locale = os.getenv("LOCALE", "").strip()

        if not locale:
            return ""

        # Simple language instruction
        return f"Always respond in {locale}.\n\n"

    # === ABSTRACT METHODS FOR SIMPLE TOOLS ===
+ + This method should construct the full prompt by combining: + - System prompt from get_system_prompt() + - File content from _prepare_file_content_for_prompt() + - Conversation history from reconstruct_thread_context() + - User's request and any tool-specific context + + Args: + request: The validated request object + + Returns: + str: Complete prompt ready for the AI model + """ + pass + + def format_response(self, response: str, request, model_info: dict = None) -> str: + """ + Format the AI model's response for the user. + + This method allows tools to post-process the model's response, + adding structure, validation, or additional context. + + The default implementation returns the response unchanged. + Tools can override this method to add custom formatting. + + Args: + response: Raw response from the AI model + request: The original request object + model_info: Optional model information and metadata + + Returns: + str: Formatted response ready for the user + """ + return response + + # === IMPLEMENTATION METHODS === + # These will be provided in a full implementation but are inherited from current base.py + # for now to maintain compatibility. + + async def execute(self, arguments: dict[str, Any]) -> list[TextContent]: + """Execute the tool - will be inherited from existing base.py for now.""" + # This will be implemented by importing from the current base.py + # for backward compatibility during the migration + raise NotImplementedError("Subclasses must implement execute method") + + def _should_require_model_selection(self, model_name: str) -> bool: + """ + Check if we should require the CLI to select a model at runtime. + + This is called during request execution to determine if we need + to return an error asking the CLI to provide a model parameter. 
+ + Args: + model_name: The model name from the request or DEFAULT_MODEL + + Returns: + bool: True if we should require model selection + """ + # Case 1: Model is explicitly "auto" + if model_name.lower() == "auto": + return True + + # Case 2: Requested model is not available + from providers.registry import ModelProviderRegistry + + provider = ModelProviderRegistry.get_provider_for_model(model_name) + if not provider: + logger.warning(f"Model '{model_name}' is not available with current API keys. Requiring model selection.") + return True + + return False + + def _get_available_models(self) -> list[str]: + """ + Get list of models available from enabled providers. + + Only returns models from providers that have valid API keys configured. + This fixes the namespace collision bug where models from disabled providers + were shown to the CLI, causing routing conflicts. + + Returns: + List of model names from enabled providers only + """ + from providers.registry import ModelProviderRegistry + + # Get models from enabled providers only (those with valid API keys) + all_models = ModelProviderRegistry.get_available_model_names() + + # Add OpenRouter models and their aliases when OpenRouter is configured + openrouter_key = os.getenv("OPENROUTER_API_KEY") + if openrouter_key and openrouter_key != "your_openrouter_api_key_here": + try: + registry = self._get_openrouter_registry() + + # Include every known alias so MCP enum matches registry capabilities + for alias in registry.list_aliases(): + config = registry.resolve(alias) + if config and config.is_custom: + # Custom-only models require CUSTOM_API_URL; defer to custom block + continue + if alias not in all_models: + all_models.append(alias) + except Exception as exc: # pragma: no cover - logged for observability + import logging + + logging.debug(f"Failed to add OpenRouter models to enum: {exc}") + + # Add custom models (and their aliases) when a custom endpoint is available + custom_url = os.getenv("CUSTOM_API_URL") + 
if custom_url: + try: + registry = self._get_openrouter_registry() + for alias in registry.list_aliases(): + config = registry.resolve(alias) + if config and config.is_custom and alias not in all_models: + all_models.append(alias) + except Exception as exc: # pragma: no cover - logged for observability + import logging + + logging.debug(f"Failed to add custom models to enum: {exc}") + + # Remove duplicates while preserving insertion order + seen: set[str] = set() + unique_models: list[str] = [] + for model in all_models: + if model not in seen: + seen.add(model) + unique_models.append(model) + + return unique_models + + def _resolve_model_context(self, arguments: dict, request) -> tuple[str, Any]: + """ + Resolve model context and name using centralized logic. + + This method extracts the model resolution logic from execute() so it can be + reused by tools that override execute() (like debug tool) without duplicating code. + + Args: + arguments: Dictionary of arguments from the MCP client + request: The validated request object + + Returns: + tuple[str, ModelContext]: (resolved_model_name, model_context) + + Raises: + ValueError: If model resolution fails or model selection is required + """ + # MODEL RESOLUTION NOW HAPPENS AT MCP BOUNDARY + # Extract pre-resolved model context from server.py + model_context = arguments.get("_model_context") + resolved_model_name = arguments.get("_resolved_model_name") + + if model_context and resolved_model_name: + # Model was already resolved at MCP boundary + model_name = resolved_model_name + logger.debug(f"Using pre-resolved model '{model_name}' from MCP boundary") + else: + # Fallback for direct execute calls + model_name = getattr(request, "model", None) + if not model_name: + from config import DEFAULT_MODEL + + model_name = DEFAULT_MODEL + logger.debug(f"Using fallback model resolution for '{model_name}' (test mode)") + + # For tests: Check if we should require model selection (auto mode) + if 
self._should_require_model_selection(model_name): + # Build error message based on why selection is required + if model_name.lower() == "auto": + error_message = self._build_auto_mode_required_message() + else: + error_message = self._build_model_unavailable_message(model_name) + raise ValueError(error_message) + + # Create model context for tests + from utils.model_context import ModelContext + + model_context = ModelContext(model_name) + + return model_name, model_context + + def validate_and_correct_temperature(self, temperature: float, model_context: Any) -> tuple[float, list[str]]: + """ + Validate and correct temperature for the specified model. + + This method ensures that the temperature value is within the valid range + for the specific model being used. Different models have different temperature + constraints (e.g., o1 models require temperature=1.0, GPT models support 0-2). + + Args: + temperature: Temperature value to validate + model_context: Model context object containing model name, provider, and capabilities + + Returns: + Tuple of (corrected_temperature, warning_messages) + """ + try: + # Use model context capabilities directly - clean OOP approach + capabilities = model_context.capabilities + constraint = capabilities.temperature_constraint + + warnings = [] + if not constraint.validate(temperature): + corrected = constraint.get_corrected_value(temperature) + warning = ( + f"Temperature {temperature} invalid for {model_context.model_name}. " + f"{constraint.get_description()}. Using {corrected} instead." 
+ ) + warnings.append(warning) + return corrected, warnings + + return temperature, warnings + + except Exception as e: + # If validation fails for any reason, use the original temperature + # and log a warning (but don't fail the request) + logger.warning(f"Temperature validation failed for {model_context.model_name}: {e}") + return temperature, [f"Temperature validation failed: {e}"] + + def _validate_image_limits( + self, images: Optional[list[str]], model_context: Optional[Any] = None, continuation_id: Optional[str] = None + ) -> Optional[dict]: + """ + Validate image size and count against model capabilities. + + This performs strict validation to ensure we don't exceed model-specific + image limits. Uses capability-based validation with actual model + configuration rather than hard-coded limits. + + Args: + images: List of image paths/data URLs to validate + model_context: Model context object containing model name, provider, and capabilities + continuation_id: Optional continuation ID for conversation context + + Returns: + Optional[dict]: Error response if validation fails, None if valid + """ + if not images: + return None + + # Import here to avoid circular imports + import base64 + from pathlib import Path + + # Handle legacy calls (positional model_name string) + if isinstance(model_context, str): + # Legacy call: _validate_image_limits(images, "model-name") + logger.warning( + "Legacy _validate_image_limits call with model_name string. Use model_context object instead." 
    def _validate_image_limits(
        self, images: Optional[list[str]], model_context: Optional[Any] = None, continuation_id: Optional[str] = None
    ) -> Optional[dict]:
        """
        Validate image size and count against model capabilities.

        This performs strict validation to ensure we don't exceed model-specific
        image limits. Uses capability-based validation with actual model
        configuration rather than hard-coded limits.

        Args:
            images: List of image paths/data URLs to validate
            model_context: Model context object containing model name, provider, and capabilities
            continuation_id: Optional continuation ID for conversation context

        Returns:
            Optional[dict]: Error response if validation fails, None if valid
        """
        if not images:
            return None

        # Import here to avoid circular imports
        import base64
        from pathlib import Path

        # Handle legacy calls (positional model_name string)
        if isinstance(model_context, str):
            # Legacy call: _validate_image_limits(images, "model-name")
            logger.warning(
                "Legacy _validate_image_limits call with model_name string. Use model_context object instead."
            )
            try:
                from utils.model_context import ModelContext

                model_context = ModelContext(model_context)
            except Exception as e:
                logger.warning(f"Failed to create model context from legacy model_name: {e}")
                # Generic error response for any unavailable model
                return {
                    "status": "error",
                    "content": self._build_model_unavailable_message(str(model_context)),
                    "content_type": "text",
                    "metadata": {
                        "error_type": "validation_error",
                        "model_name": model_context,
                        "supports_images": None,  # Unknown since model doesn't exist
                        "image_count": len(images) if images else 0,
                    },
                }

        if not model_context:
            # Get from tool's stored context as fallback
            model_context = getattr(self, "_model_context", None)
            if not model_context:
                logger.warning("No model context available for image validation")
                return None

        try:
            # Use model context capabilities directly - clean OOP approach
            capabilities = model_context.capabilities
            model_name = model_context.model_name
        except Exception as e:
            logger.warning(f"Failed to get capabilities from model_context for image validation: {e}")
            # Generic error response when capabilities cannot be accessed
            model_name = getattr(model_context, "model_name", "unknown")
            return {
                "status": "error",
                "content": self._build_model_unavailable_message(model_name),
                "content_type": "text",
                "metadata": {
                    "error_type": "validation_error",
                    "model_name": model_name,
                    "supports_images": None,  # Unknown since model capabilities unavailable
                    "image_count": len(images) if images else 0,
                },
            }

        # Check if model supports images
        if not capabilities.supports_images:
            return {
                "status": "error",
                "content": (
                    f"Image support not available: Model '{model_name}' does not support image processing. "
                    f"Please use a vision-capable model such as 'gemini-2.5-flash', 'o3', "
                    f"or 'claude-opus-4.1' for image analysis tasks."
                ),
                "content_type": "text",
                "metadata": {
                    "error_type": "validation_error",
                    "model_name": model_name,
                    "supports_images": False,
                    "image_count": len(images),
                },
            }

        # Get model image limits from capabilities
        max_images = 5  # Default max number of images
        max_size_mb = capabilities.max_image_size_mb

        # Check image count
        if len(images) > max_images:
            return {
                "status": "error",
                "content": (
                    f"Too many images: Model '{model_name}' supports a maximum of {max_images} images, "
                    f"but {len(images)} were provided. Please reduce the number of images."
                ),
                "content_type": "text",
                "metadata": {
                    "error_type": "validation_error",
                    "model_name": model_name,
                    "image_count": len(images),
                    "max_images": max_images,
                },
            }

        # Calculate total size of all images
        total_size_mb = 0.0
        for image_path in images:
            try:
                if image_path.startswith("data:image/"):
                    # Handle data URL: data:image/png;base64,iVBORw0...
                    _, data = image_path.split(",", 1)
                    # Base64 encoding increases size by ~33%, so decode to get actual size
                    actual_size = len(base64.b64decode(data))
                    total_size_mb += actual_size / (1024 * 1024)
                else:
                    # Handle file path
                    path = Path(image_path)
                    if path.exists():
                        file_size = path.stat().st_size
                        total_size_mb += file_size / (1024 * 1024)
                    else:
                        logger.warning(f"Image file not found: {image_path}")
                        # Assume a reasonable size for missing files to avoid breaking validation
                        total_size_mb += 1.0  # 1MB assumption
            except Exception as e:
                logger.warning(f"Failed to get size for image {image_path}: {e}")
                # Assume a reasonable size for problematic files
                total_size_mb += 1.0  # 1MB assumption

        # Apply 40MB cap for custom models if needed
        effective_limit_mb = max_size_mb
        try:
            from providers.shared import ProviderType

            # ModelCapabilities dataclass has provider field defined
            if capabilities.provider == ProviderType.CUSTOM:
                effective_limit_mb = min(max_size_mb, 40.0)
        except Exception:
            pass

        # Validate against size limit
        if total_size_mb > effective_limit_mb:
            return {
                "status": "error",
                "content": (
                    f"Image size limit exceeded: Model '{model_name}' supports maximum {effective_limit_mb:.1f}MB "
                    f"for all images combined, but {total_size_mb:.1f}MB was provided. "
                    f"Please reduce image sizes or count and try again."
                ),
                "content_type": "text",
                "metadata": {
                    "error_type": "validation_error",
                    "model_name": model_name,
                    "total_size_mb": round(total_size_mb, 2),
                    "limit_mb": round(effective_limit_mb, 2),
                    "image_count": len(images),
                    "supports_images": True,
                },
            }

        # All validations passed
        logger.debug(f"Image validation passed: {len(images)} images, {total_size_mb:.1f}MB total")
        return None

    def _parse_response(self, raw_text: str, request, model_info: Optional[dict] = None):
        """Parse response - will be inherited for now."""
        # Implementation inherited from current base.py
        raise NotImplementedError("Subclasses must implement _parse_response method")
class SchemaBuilder:
    """
    Base schema builder for simple MCP tools.

    This class provides static methods to build consistent schemas for simple tools.
    Workflow tools use WorkflowSchemaBuilder in workflow/schema_builders.py.
    """

    # Sentinel distinguishing "no default supplied" from legitimate falsy
    # defaults (0, False, "") in create_field_schema.
    _UNSET = object()

    # Common field schemas that can be reused across all tool types
    COMMON_FIELD_SCHEMAS = {
        "temperature": {
            "type": "number",
            "description": COMMON_FIELD_DESCRIPTIONS["temperature"],
            "minimum": 0.0,
            "maximum": 1.0,
        },
        "thinking_mode": {
            "type": "string",
            "enum": ["minimal", "low", "medium", "high", "max"],
            "description": COMMON_FIELD_DESCRIPTIONS["thinking_mode"],
        },
        "continuation_id": {
            "type": "string",
            "description": COMMON_FIELD_DESCRIPTIONS["continuation_id"],
        },
        "images": {
            "type": "array",
            "items": {"type": "string"},
            "description": COMMON_FIELD_DESCRIPTIONS["images"],
        },
    }

    # Simple tool-specific field schemas (workflow tools use relevant_files instead)
    SIMPLE_FIELD_SCHEMAS = {
        "files": {
            "type": "array",
            "items": {"type": "string"},
            "description": COMMON_FIELD_DESCRIPTIONS["files"],
        },
    }

    @staticmethod
    def build_schema(
        tool_specific_fields: dict[str, dict[str, Any]] = None,
        required_fields: list[str] = None,
        model_field_schema: dict[str, Any] = None,
        auto_mode: bool = False,
        require_model: bool = False,
    ) -> dict[str, Any]:
        """
        Build complete schema for simple tools.

        Args:
            tool_specific_fields: Additional fields specific to the tool
            required_fields: List of required field names
            model_field_schema: Schema for the model field
            auto_mode: Whether the tool is in auto mode (affects model requirement)
            require_model: Force the "model" field to be required even when
                auto_mode is off

        Returns:
            Complete JSON schema for the tool
        """
        properties = {}

        # Add common fields (temperature, thinking_mode, etc.)
        properties.update(SchemaBuilder.COMMON_FIELD_SCHEMAS)

        # Add simple tool-specific fields (files field for simple tools)
        properties.update(SchemaBuilder.SIMPLE_FIELD_SCHEMAS)

        # Add model field if provided
        if model_field_schema:
            properties["model"] = model_field_schema

        # Add tool-specific fields if provided
        if tool_specific_fields:
            properties.update(tool_specific_fields)

        # Build required fields list
        required = list(required_fields) if required_fields else []
        if (auto_mode or require_model) and "model" not in required:
            required.append("model")

        # Build the complete schema
        schema = {
            "$schema": "http://json-schema.org/draft-07/schema#",
            "type": "object",
            "properties": properties,
            "additionalProperties": False,
        }

        if required:
            schema["required"] = required

        return schema

    @staticmethod
    def get_common_fields() -> dict[str, dict[str, Any]]:
        """Get the standard field schemas for simple tools."""
        return SchemaBuilder.COMMON_FIELD_SCHEMAS.copy()

    @staticmethod
    def create_field_schema(
        field_type: str,
        description: str,
        enum_values: list[str] = None,
        minimum: float = None,
        maximum: float = None,
        items_type: str = None,
        default: Any = _UNSET,
    ) -> dict[str, Any]:
        """
        Helper method to create field schemas with common patterns.

        Args:
            field_type: JSON schema type ("string", "number", "array", etc.)
            description: Human-readable description of the field
            enum_values: For enum fields, list of allowed values
            minimum: For numeric fields, minimum value
            maximum: For numeric fields, maximum value
            items_type: For array fields, type of array items
            default: Default value for the field (omit or pass None for no default;
                falsy values like 0 or False are honored)

        Returns:
            JSON schema object for the field
        """
        schema = {
            "type": field_type,
            "description": description,
        }

        if enum_values:
            schema["enum"] = enum_values

        if minimum is not None:
            schema["minimum"] = minimum

        if maximum is not None:
            schema["maximum"] = maximum

        if items_type and field_type == "array":
            schema["items"] = {"type": items_type}

        # BUGFIX: the previous `if default is not None` check made it impossible
        # to declare falsy defaults such as 0, False, or "". A sentinel now marks
        # "not supplied"; None is still treated as "no default" for backward
        # compatibility with existing callers.
        if default is not SchemaBuilder._UNSET and default is not None:
            schema["default"] = default

        return schema
+""" + +from abc import abstractmethod +from typing import Any, Optional + +from tools.shared.base_models import ToolRequest +from tools.shared.base_tool import BaseTool +from tools.shared.schema_builders import SchemaBuilder + + +class SimpleTool(BaseTool): + """ + Base class for simple (non-workflow) tools. + + Simple tools are request/response tools that don't require multi-step workflows. + They benefit from: + - Automatic schema generation using SchemaBuilder + - Inherited conversation handling and file processing + - Standardized model integration + - Consistent error handling and response formatting + + To create a simple tool: + 1. Inherit from SimpleTool + 2. Implement get_tool_fields() to define tool-specific fields + 3. Implement prepare_prompt() for prompt preparation + 4. Optionally override format_response() for custom formatting + 5. Optionally override get_required_fields() for custom requirements + + Example: + class ChatTool(SimpleTool): + def get_name(self) -> str: + return "chat" + + def get_tool_fields(self) -> Dict[str, Dict[str, Any]]: + return { + "prompt": { + "type": "string", + "description": "Your question or idea...", + }, + "files": SimpleTool.FILES_FIELD, + } + + def get_required_fields(self) -> List[str]: + return ["prompt"] + """ + + # Common field definitions that simple tools can reuse + FILES_FIELD = SchemaBuilder.SIMPLE_FIELD_SCHEMAS["files"] + IMAGES_FIELD = SchemaBuilder.COMMON_FIELD_SCHEMAS["images"] + + @abstractmethod + def get_tool_fields(self) -> dict[str, dict[str, Any]]: + """ + Return tool-specific field definitions. + + This method should return a dictionary mapping field names to their + JSON schema definitions. Common fields (model, temperature, etc.) + are added automatically by the base class. 
+ + Returns: + Dict mapping field names to JSON schema objects + + Example: + return { + "prompt": { + "type": "string", + "description": "The user's question or request", + }, + "files": SimpleTool.FILES_FIELD, # Reuse common field + "max_tokens": { + "type": "integer", + "minimum": 1, + "description": "Maximum tokens for response", + } + } + """ + pass + + def get_required_fields(self) -> list[str]: + """ + Return list of required field names. + + Override this to specify which fields are required for your tool. + The model field is automatically added if in auto mode. + + Returns: + List of required field names + """ + return [] + + def get_annotations(self) -> Optional[dict[str, Any]]: + """ + Return tool annotations. Simple tools are read-only by default. + + All simple tools perform operations without modifying the environment. + They may call external AI models for analysis or conversation, but they + don't write files or make system changes. + + Override this method if your simple tool needs different annotations. + + Returns: + Dictionary with readOnlyHint set to True + """ + return {"readOnlyHint": True} + + def format_response(self, response: str, request, model_info: Optional[dict] = None) -> str: + """ + Format the AI response before returning to the client. + + This is a hook method that subclasses can override to customize + response formatting. The default implementation returns the response as-is. + + Args: + response: The raw response from the AI model + request: The validated request object + model_info: Optional model information dictionary + + Returns: + Formatted response string + """ + return response + + def get_input_schema(self) -> dict[str, Any]: + """ + Generate the complete input schema using SchemaBuilder. + + This method automatically combines: + - Tool-specific fields from get_tool_fields() + - Common fields (temperature, thinking_mode, etc.) 
+ - Model field with proper auto-mode handling + - Required fields from get_required_fields() + + Tools can override this method for custom schema generation while + still benefiting from SimpleTool's convenience methods. + + Returns: + Complete JSON schema for the tool + """ + required_fields = list(self.get_required_fields()) + return SchemaBuilder.build_schema( + tool_specific_fields=self.get_tool_fields(), + required_fields=required_fields, + model_field_schema=self.get_model_field_schema(), + auto_mode=self.is_effective_auto_mode(), + ) + + def get_request_model(self): + """ + Return the request model class. + + Simple tools use the base ToolRequest by default. + Override this if your tool needs a custom request model. + """ + return ToolRequest + + # Hook methods for safe attribute access without hasattr/getattr + + def get_request_model_name(self, request) -> Optional[str]: + """Get model name from request. Override for custom model name handling.""" + try: + return request.model + except AttributeError: + return None + + def get_request_images(self, request) -> list: + """Get images from request. Override for custom image handling.""" + try: + return request.images if request.images is not None else [] + except AttributeError: + return [] + + def get_request_continuation_id(self, request) -> Optional[str]: + """Get continuation_id from request. Override for custom continuation handling.""" + try: + return request.continuation_id + except AttributeError: + return None + + def get_request_prompt(self, request) -> str: + """Get prompt from request. Override for custom prompt handling.""" + try: + return request.prompt + except AttributeError: + return "" + + def get_request_temperature(self, request) -> Optional[float]: + """Get temperature from request. 
Override for custom temperature handling.""" + try: + return request.temperature + except AttributeError: + return None + + def get_validated_temperature(self, request, model_context: Any) -> tuple[float, list[str]]: + """ + Get temperature from request and validate it against model constraints. + + This is a convenience method that combines temperature extraction and validation + for simple tools. It ensures temperature is within valid range for the model. + + Args: + request: The request object containing temperature + model_context: Model context object containing model info + + Returns: + Tuple of (validated_temperature, warning_messages) + """ + temperature = self.get_request_temperature(request) + if temperature is None: + temperature = self.get_default_temperature() + return self.validate_and_correct_temperature(temperature, model_context) + + def get_request_thinking_mode(self, request) -> Optional[str]: + """Get thinking_mode from request. Override for custom thinking mode handling.""" + try: + return request.thinking_mode + except AttributeError: + return None + + def get_request_files(self, request) -> list: + """Get files from request. Override for custom file handling.""" + try: + return request.files if request.files is not None else [] + except AttributeError: + return [] + + def get_request_as_dict(self, request) -> dict: + """Convert request to dictionary. Override for custom serialization.""" + try: + # Try Pydantic v2 method first + return request.model_dump() + except AttributeError: + try: + # Fall back to Pydantic v1 method + return request.dict() + except AttributeError: + # Last resort - convert to dict manually + return {"prompt": self.get_request_prompt(request)} + + def set_request_files(self, request, files: list) -> None: + """Set files on request. 
Override for custom file setting.""" + try: + request.files = files + except AttributeError: + # If request doesn't support file setting, ignore silently + pass + + def get_actually_processed_files(self) -> list: + """Get actually processed files. Override for custom file tracking.""" + try: + return self._actually_processed_files + except AttributeError: + return [] + + async def execute(self, arguments: dict[str, Any]) -> list: + """ + Execute the simple tool using the comprehensive flow from old base.py. + + This method replicates the proven execution pattern while using SimpleTool hooks. + """ + import json + import logging + + from mcp.types import TextContent + + from tools.models import ToolOutput + + logger = logging.getLogger(f"tools.{self.get_name()}") + + try: + # Store arguments for access by helper methods + self._current_arguments = arguments + + logger.info(f"🔧 {self.get_name()} tool called with arguments: {list(arguments.keys())}") + + # Validate request using the tool's Pydantic model + request_model = self.get_request_model() + request = request_model(**arguments) + logger.debug(f"Request validation successful for {self.get_name()}") + + # Validate file paths for security + # This prevents path traversal attacks and ensures proper access control + path_error = self._validate_file_paths(request) + if path_error: + error_output = ToolOutput( + status="error", + content=path_error, + content_type="text", + ) + return [TextContent(type="text", text=error_output.model_dump_json())] + + # Handle model resolution like old base.py + model_name = self.get_request_model_name(request) + if not model_name: + from config import DEFAULT_MODEL + + model_name = DEFAULT_MODEL + + # Store the current model name for later use + self._current_model_name = model_name + + # Handle model context from arguments (for in-process testing) + if "_model_context" in arguments: + self._model_context = arguments["_model_context"] + logger.debug(f"{self.get_name()}: Using model 
context from arguments") + else: + # Create model context if not provided + from utils.model_context import ModelContext + + self._model_context = ModelContext(model_name) + logger.debug(f"{self.get_name()}: Created model context for {model_name}") + + # Get images if present + images = self.get_request_images(request) + continuation_id = self.get_request_continuation_id(request) + + # Handle conversation history and prompt preparation + if continuation_id: + # Check if conversation history is already embedded + field_value = self.get_request_prompt(request) + if "=== CONVERSATION HISTORY ===" in field_value: + # Use pre-embedded history + prompt = field_value + logger.debug(f"{self.get_name()}: Using pre-embedded conversation history") + else: + # No embedded history - reconstruct it (for in-process calls) + logger.debug(f"{self.get_name()}: No embedded history found, reconstructing conversation") + + # Get thread context + from utils.conversation_memory import add_turn, build_conversation_history, get_thread + + thread_context = get_thread(continuation_id) + + if thread_context: + # Add user's new input to conversation + user_prompt = self.get_request_prompt(request) + user_files = self.get_request_files(request) + if user_prompt: + add_turn(continuation_id, "user", user_prompt, files=user_files) + + # Get updated thread context after adding the turn + thread_context = get_thread(continuation_id) + logger.debug( + f"{self.get_name()}: Retrieved updated thread with {len(thread_context.turns)} turns" + ) + + # Build conversation history with updated thread context + conversation_history, conversation_tokens = build_conversation_history( + thread_context, self._model_context + ) + + # Get the base prompt from the tool + base_prompt = await self.prepare_prompt(request) + + # Combine with conversation history + if conversation_history: + prompt = f"{conversation_history}\n\n=== NEW USER INPUT ===\n{base_prompt}" + else: + prompt = base_prompt + else: + # Thread not 
found, prepare normally + logger.warning(f"Thread {continuation_id} not found, preparing prompt normally") + prompt = await self.prepare_prompt(request) + else: + # New conversation, prepare prompt normally + prompt = await self.prepare_prompt(request) + + # Add follow-up instructions for new conversations + from server import get_follow_up_instructions + + follow_up_instructions = get_follow_up_instructions(0) + prompt = f"{prompt}\n\n{follow_up_instructions}" + logger.debug( + f"Added follow-up instructions for new {self.get_name()} conversation" + ) # Validate images if any were provided + if images: + image_validation_error = self._validate_image_limits( + images, model_context=self._model_context, continuation_id=continuation_id + ) + if image_validation_error: + return [TextContent(type="text", text=json.dumps(image_validation_error, ensure_ascii=False))] + + # Get and validate temperature against model constraints + temperature, temp_warnings = self.get_validated_temperature(request, self._model_context) + + # Log any temperature corrections + for warning in temp_warnings: + # Get thinking mode with defaults + logger.warning(warning) + thinking_mode = self.get_request_thinking_mode(request) + if thinking_mode is None: + thinking_mode = self.get_default_thinking_mode() + + # Get the provider from model context (clean OOP - no re-fetching) + provider = self._model_context.provider + + # Get system prompt for this tool + base_system_prompt = self.get_system_prompt() + language_instruction = self.get_language_instruction() + system_prompt = language_instruction + base_system_prompt + + # Generate AI response using the provider + logger.info(f"Sending request to {provider.get_provider_type().value} API for {self.get_name()}") + logger.info( + f"Using model: {self._model_context.model_name} via {provider.get_provider_type().value} provider" + ) + + # Estimate tokens for logging + from utils.token_utils import estimate_tokens + + estimated_tokens = 
estimate_tokens(prompt) + logger.debug(f"Prompt length: {len(prompt)} characters (~{estimated_tokens:,} tokens)") + + # Resolve model capabilities for feature gating + capabilities = self._model_context.capabilities + supports_thinking = capabilities.supports_extended_thinking + + # Generate content with provider abstraction + model_response = provider.generate_content( + prompt=prompt, + model_name=self._current_model_name, + system_prompt=system_prompt, + temperature=temperature, + thinking_mode=thinking_mode if supports_thinking else None, + images=images if images else None, + ) + + logger.info(f"Received response from {provider.get_provider_type().value} API for {self.get_name()}") + + # Process the model's response + if model_response.content: + raw_text = model_response.content + + # Create model info for conversation tracking + model_info = { + "provider": provider, + "model_name": self._current_model_name, + "model_response": model_response, + } + + # Parse response using the same logic as old base.py + tool_output = self._parse_response(raw_text, request, model_info) + logger.info(f"✅ {self.get_name()} tool completed successfully") + + else: + # Handle cases where the model couldn't generate a response + metadata = model_response.metadata or {} + finish_reason = metadata.get("finish_reason", "Unknown") + + if metadata.get("is_blocked_by_safety"): + # Specific handling for content safety blocks + safety_details = metadata.get("safety_feedback") or "details not provided" + logger.warning( + f"Response blocked by content safety policy for {self.get_name()}. " + f"Reason: {finish_reason}, Details: {safety_details}" + ) + tool_output = ToolOutput( + status="error", + content="Your request was blocked by the content safety policy. 
" + "Please try modifying your prompt.", + content_type="text", + ) + else: + # Handle other empty responses - could be legitimate completion or unclear blocking + if finish_reason == "STOP": + # Model completed normally but returned empty content - retry with clarification + logger.info( + f"Model completed with empty response for {self.get_name()}, retrying with clarification" + ) + + # Retry the same request with modified prompt asking for explicit response + original_prompt = prompt + retry_prompt = f"{original_prompt}\n\nIMPORTANT: Please provide a substantive response. If you cannot respond to the above request, please explain why and suggest alternatives." + + try: + retry_response = provider.generate_content( + prompt=retry_prompt, + model_name=self._current_model_name, + system_prompt=system_prompt, + temperature=temperature, + thinking_mode=thinking_mode if supports_thinking else None, + images=images if images else None, + ) + + if retry_response.content: + # Successful retry - use the retry response + logger.info(f"Retry successful for {self.get_name()}") + raw_text = retry_response.content + + # Update model info for the successful retry + model_info = { + "provider": provider, + "model_name": self._current_model_name, + "model_response": retry_response, + } + + # Parse the retry response + tool_output = self._parse_response(raw_text, request, model_info) + logger.info(f"✅ {self.get_name()} tool completed successfully after retry") + else: + # Retry also failed - inspect metadata to find out why + retry_metadata = retry_response.metadata or {} + if retry_metadata.get("is_blocked_by_safety"): + # The retry was blocked by safety filters + safety_details = retry_metadata.get("safety_feedback") or "details not provided" + logger.warning( + f"Retry for {self.get_name()} was blocked by content safety policy. 
" + f"Details: {safety_details}" + ) + tool_output = ToolOutput( + status="error", + content="Your request was also blocked by the content safety policy after a retry. " + "Please try rephrasing your prompt significantly.", + content_type="text", + ) + else: + # Retry failed for other reasons (e.g., another STOP) + tool_output = ToolOutput( + status="error", + content="The model repeatedly returned empty responses. This may indicate content filtering or a model issue.", + content_type="text", + ) + except Exception as retry_error: + logger.warning(f"Retry failed for {self.get_name()}: {retry_error}") + tool_output = ToolOutput( + status="error", + content=f"Model returned empty response and retry failed: {str(retry_error)}", + content_type="text", + ) + else: + # Non-STOP finish reasons are likely actual errors + logger.warning( + f"Response blocked or incomplete for {self.get_name()}. Finish reason: {finish_reason}" + ) + tool_output = ToolOutput( + status="error", + content=f"Response blocked or incomplete. Finish reason: {finish_reason}", + content_type="text", + ) + + # Return the tool output as TextContent + return [TextContent(type="text", text=tool_output.model_dump_json())] + + except Exception as e: + # Special handling for MCP size check errors + if str(e).startswith("MCP_SIZE_CHECK:"): + # Extract the JSON content after the prefix + json_content = str(e)[len("MCP_SIZE_CHECK:") :] + return [TextContent(type="text", text=json_content)] + + logger.error(f"Error in {self.get_name()}: {str(e)}") + error_output = ToolOutput( + status="error", + content=f"Error in {self.get_name()}: {str(e)}", + content_type="text", + ) + return [TextContent(type="text", text=error_output.model_dump_json())] + + def _parse_response(self, raw_text: str, request, model_info: Optional[dict] = None): + """ + Parse the raw response and format it using the hook method. 
+ + This simplified version focuses on the SimpleTool pattern: format the response + using the format_response hook, then handle conversation continuation. + """ + from tools.models import ToolOutput + + # Format the response using the hook method + formatted_response = self.format_response(raw_text, request, model_info) + + # Handle conversation continuation like old base.py + continuation_id = self.get_request_continuation_id(request) + if continuation_id: + self._record_assistant_turn(continuation_id, raw_text, request, model_info) + + # Create continuation offer like old base.py + continuation_data = self._create_continuation_offer(request, model_info) + if continuation_data: + return self._create_continuation_offer_response(formatted_response, continuation_data, request, model_info) + else: + # Build metadata with model and provider info for success response + metadata = {} + if model_info: + model_name = model_info.get("model_name") + if model_name: + metadata["model_used"] = model_name + provider = model_info.get("provider") + if provider: + # Handle both provider objects and string values + if isinstance(provider, str): + metadata["provider_used"] = provider + else: + try: + metadata["provider_used"] = provider.get_provider_type().value + except AttributeError: + # Fallback if provider doesn't have get_provider_type method + metadata["provider_used"] = str(provider) + + return ToolOutput( + status="success", + content=formatted_response, + content_type="text", + metadata=metadata if metadata else None, + ) + + def _create_continuation_offer(self, request, model_info: Optional[dict] = None): + """Create continuation offer following old base.py pattern""" + continuation_id = self.get_request_continuation_id(request) + + try: + from utils.conversation_memory import create_thread, get_thread + + if continuation_id: + # Existing conversation + thread_context = get_thread(continuation_id) + if thread_context and thread_context.turns: + turn_count = 
len(thread_context.turns) + from utils.conversation_memory import MAX_CONVERSATION_TURNS + + if turn_count >= MAX_CONVERSATION_TURNS - 1: + return None # No more turns allowed + + remaining_turns = MAX_CONVERSATION_TURNS - turn_count - 1 + return { + "continuation_id": continuation_id, + "remaining_turns": remaining_turns, + "note": f"Claude can continue this conversation for {remaining_turns} more exchanges.", + } + else: + # New conversation - create thread and offer continuation + # Convert request to dict for initial_context + initial_request_dict = self.get_request_as_dict(request) + + new_thread_id = create_thread(tool_name=self.get_name(), initial_request=initial_request_dict) + + # Add the initial user turn to the new thread + from utils.conversation_memory import MAX_CONVERSATION_TURNS, add_turn + + user_prompt = self.get_request_prompt(request) + user_files = self.get_request_files(request) + user_images = self.get_request_images(request) + + # Add user's initial turn + add_turn( + new_thread_id, "user", user_prompt, files=user_files, images=user_images, tool_name=self.get_name() + ) + + return { + "continuation_id": new_thread_id, + "remaining_turns": MAX_CONVERSATION_TURNS - 1, + "note": f"Claude can continue this conversation for {MAX_CONVERSATION_TURNS - 1} more exchanges.", + } + except Exception: + return None + + def _create_continuation_offer_response( + self, content: str, continuation_data: dict, request, model_info: Optional[dict] = None + ): + """Create response with continuation offer following old base.py pattern""" + from tools.models import ContinuationOffer, ToolOutput + + try: + if not self.get_request_continuation_id(request): + self._record_assistant_turn( + continuation_data["continuation_id"], + content, + request, + model_info, + ) + + continuation_offer = ContinuationOffer( + continuation_id=continuation_data["continuation_id"], + note=continuation_data["note"], + remaining_turns=continuation_data["remaining_turns"], + ) + + # 
Build metadata with model and provider info + metadata = {"tool_name": self.get_name(), "conversation_ready": True} + if model_info: + model_name = model_info.get("model_name") + if model_name: + metadata["model_used"] = model_name + provider = model_info.get("provider") + if provider: + # Handle both provider objects and string values + if isinstance(provider, str): + metadata["provider_used"] = provider + else: + try: + metadata["provider_used"] = provider.get_provider_type().value + except AttributeError: + # Fallback if provider doesn't have get_provider_type method + metadata["provider_used"] = str(provider) + + return ToolOutput( + status="continuation_available", + content=content, + content_type="text", + continuation_offer=continuation_offer, + metadata=metadata, + ) + except Exception: + # Fallback to simple success if continuation offer fails + return ToolOutput(status="success", content=content, content_type="text") + + def _record_assistant_turn( + self, continuation_id: str, response_text: str, request, model_info: Optional[dict] + ) -> None: + """Persist an assistant response in conversation memory.""" + + if not continuation_id: + return + + from utils.conversation_memory import add_turn + + model_provider = None + model_name = None + model_metadata = None + + if model_info: + provider = model_info.get("provider") + if provider: + if isinstance(provider, str): + model_provider = provider + else: + try: + model_provider = provider.get_provider_type().value + except AttributeError: + model_provider = str(provider) + model_name = model_info.get("model_name") + model_response = model_info.get("model_response") + if model_response: + model_metadata = {"usage": model_response.usage, "metadata": model_response.metadata} + + add_turn( + continuation_id, + "assistant", + response_text, + files=self.get_request_files(request), + images=self.get_request_images(request), + tool_name=self.get_name(), + model_provider=model_provider, + model_name=model_name, + 
model_metadata=model_metadata, + ) + + # Convenience methods for common tool patterns + + def build_standard_prompt( + self, system_prompt: str, user_content: str, request, file_context_title: str = "CONTEXT FILES" + ) -> str: + """ + Build a standard prompt with system prompt, user content, and optional files. + + This is a convenience method that handles the common pattern of: + 1. Adding file content if present + 2. Checking token limits + 3. Adding web search instructions + 4. Combining everything into a well-formatted prompt + + Args: + system_prompt: The system prompt for the tool + user_content: The main user request/content + request: The validated request object + file_context_title: Title for the file context section + + Returns: + Complete formatted prompt ready for the AI model + """ + # Add context files if provided + files = self.get_request_files(request) + if files: + file_content, processed_files = self._prepare_file_content_for_prompt( + files, + self.get_request_continuation_id(request), + "Context files", + model_context=getattr(self, "_model_context", None), + ) + self._actually_processed_files = processed_files + if file_content: + user_content = f"{user_content}\n\n=== {file_context_title} ===\n{file_content}\n=== END CONTEXT ====" + + # Check token limits - only validate original user prompt, not conversation history + content_to_validate = self.get_prompt_content_for_size_validation(user_content) + self._validate_token_limit(content_to_validate, "Content") + + # Add standardized web search guidance + websearch_instruction = self.get_websearch_instruction(True, self.get_websearch_guidance()) + + # Combine system prompt with user content + full_prompt = f"""{system_prompt}{websearch_instruction} + +=== USER REQUEST === +{user_content} +=== END REQUEST === + +Please provide a thoughtful, comprehensive response:""" + + return full_prompt + + def get_prompt_content_for_size_validation(self, user_content: str) -> str: + """ + Override to use 
        original user prompt for size validation when conversation history is embedded.

        When server.py embeds conversation history into the prompt field, it also stores
        the original user prompt in _original_user_prompt. We use that for size validation
        to avoid incorrectly triggering size limits due to conversation history.

        Args:
            user_content: The user content (may include conversation history)

        Returns:
            The original user prompt if available, otherwise the full user content
        """
        # Check if we have the current arguments from execute() method
        current_args = getattr(self, "_current_arguments", None)
        if current_args:
            # If server.py embedded conversation history, it stores original prompt separately
            original_user_prompt = current_args.get("_original_user_prompt")
            if original_user_prompt is not None:
                # Use original user prompt for size validation (excludes conversation history)
                return original_user_prompt

        # Fallback to default behavior (validate full user content)
        return user_content

    def get_websearch_guidance(self) -> Optional[str]:
        """
        Return tool-specific web search guidance.

        Override this to provide tool-specific guidance for when web searches
        would be helpful. Return None to use the default guidance.

        Returns:
            Tool-specific web search guidance or None for default
        """
        return None

    def handle_prompt_file_with_fallback(self, request) -> str:
        """
        Handle prompt.txt files with fallback to request field.

        This is a convenience method for tools that accept prompts either
        as a field or as a prompt.txt file. It handles the extraction
        and validation automatically.

        Args:
            request: The validated request object

        Returns:
            The effective prompt content

        Raises:
            ValueError: If prompt is too large for MCP transport
        """
        # Check for prompt.txt in files
        files = self.get_request_files(request)
        if files:
            prompt_content, updated_files = self.handle_prompt_file(files)

            # Update request files list if needed
            if updated_files is not None:
                self.set_request_files(request, updated_files)
        else:
            prompt_content = None

        # Use prompt.txt content if available, otherwise use the prompt field
        # (truthiness check: an EMPTY prompt.txt also falls through to the prompt field)
        user_content = prompt_content if prompt_content else self.get_request_prompt(request)

        # Check user input size at MCP transport boundary (excluding conversation history)
        validation_content = self.get_prompt_content_for_size_validation(user_content)
        size_check = self.check_prompt_size(validation_content)
        if size_check:
            from tools.models import ToolOutput

            # NOTE(review): the "MCP_SIZE_CHECK:" prefix appears to be a sentinel parsed
            # by the caller to recover the embedded ToolOutput JSON — confirm the caller
            # strips this prefix before surfacing the error.
            raise ValueError(f"MCP_SIZE_CHECK:{ToolOutput(**size_check).model_dump_json()}")

        return user_content

    def get_chat_style_websearch_guidance(self) -> str:
        """
        Get Chat tool-style web search guidance.

        Returns web search guidance that matches the original Chat tool pattern.
        This is useful for tools that want to maintain the same search behavior.

        Returns:
            Web search guidance text
        """
        return """When discussing topics, consider if searches for these would help:
- Documentation for any technologies or concepts mentioned
- Current best practices and patterns
- Recent developments or updates
- Community discussions and solutions"""

    def supports_custom_request_model(self) -> bool:
        """
        Indicate whether this tool supports custom request models.

        Simple tools support custom request models by default. Tools that override
        get_request_model() to return something other than ToolRequest should
        return True here.

        Returns:
            True if the tool uses a custom request model
        """
        return self.get_request_model() != ToolRequest

    def _validate_file_paths(self, request) -> Optional[str]:
        """
        Validate that all file paths in the request are absolute paths.

        This is a security measure to prevent path traversal attacks and ensure
        proper access control. All file paths must be absolute (starting with '/').

        Args:
            request: The validated request object

        Returns:
            Optional[str]: Error message if validation fails, None if all paths are valid
        """
        import os

        # Check if request has 'files' attribute (used by most tools)
        files = self.get_request_files(request)
        if files:
            for file_path in files:
                # NOTE(review): os.path.isabs accepts platform-specific absolutes
                # (e.g. "C:\\..." on Windows) even though the message says "/" —
                # confirm whether Windows paths are in scope.
                if not os.path.isabs(file_path):
                    return (
                        f"Error: All file paths must be FULL absolute paths to real files / folders - DO NOT SHORTEN. "
                        f"Received relative path: {file_path}\n"
                        f"Please provide the full absolute path starting with '/' (must be FULL absolute paths to real files / folders - DO NOT SHORTEN)"
                    )

        return None

    def prepare_chat_style_prompt(self, request, system_prompt: str = None) -> str:
        """
        Prepare a prompt using Chat tool-style patterns.

        This convenience method replicates the Chat tool's prompt preparation logic:
        1. Handle prompt.txt file if present
        2. Add file context with specific formatting
        3. Add web search guidance
        4. Format with system prompt

        Args:
            request: The validated request object
            system_prompt: System prompt to use (uses get_system_prompt() if None)

        Returns:
            Complete formatted prompt
        """
        # Use provided system prompt or get from tool
        if system_prompt is None:
            system_prompt = self.get_system_prompt()

        # Get user content (handles prompt.txt files)
        user_content = self.handle_prompt_file_with_fallback(request)

        # Build standard prompt with Chat-style web search guidance
        websearch_guidance = self.get_chat_style_websearch_guidance()

        # Override the websearch guidance temporarily
        # NOTE(review): assigning to self.get_websearch_guidance creates an instance
        # attribute that shadows the class method; the finally-block restore below
        # leaves a bound-method attribute behind (behaviorally equivalent, but not
        # thread-safe across concurrent calls on the same instance — confirm tools
        # are not shared between concurrent requests).
        original_guidance = self.get_websearch_guidance
        self.get_websearch_guidance = lambda: websearch_guidance

        try:
            full_prompt = self.build_standard_prompt(system_prompt, user_content, request, "CONTEXT FILES")
        finally:
            # Restore original guidance method
            self.get_websearch_guidance = original_guidance

        return full_prompt
diff --git a/tools/version.py b/tools/version.py
new file mode 100644
index 0000000..3acaf7b
--- /dev/null
+++ b/tools/version.py
@@ -0,0 +1,368 @@
"""
Version Tool - Display Zen MCP Server version and system information

This tool provides version information about the Zen MCP Server including
version number, last update date, author, and basic system information.
It also checks for updates from the GitHub repository.
+""" + +import logging +import platform +import re +import sys +from pathlib import Path +from typing import Any, Optional + +try: + from urllib.error import HTTPError, URLError + from urllib.request import urlopen + + HAS_URLLIB = True +except ImportError: + HAS_URLLIB = False + +from mcp.types import TextContent + +from config import __author__, __updated__, __version__ +from tools.models import ToolModelCategory, ToolOutput +from tools.shared.base_models import ToolRequest +from tools.shared.base_tool import BaseTool + +logger = logging.getLogger(__name__) + + +def parse_version(version_str: str) -> tuple[int, int, int]: + """ + Parse version string to tuple of integers for comparison. + + Args: + version_str: Version string like "5.5.5" + + Returns: + Tuple of (major, minor, patch) as integers + """ + try: + parts = version_str.strip().split(".") + if len(parts) >= 3: + return (int(parts[0]), int(parts[1]), int(parts[2])) + elif len(parts) == 2: + return (int(parts[0]), int(parts[1]), 0) + elif len(parts) == 1: + return (int(parts[0]), 0, 0) + else: + return (0, 0, 0) + except (ValueError, IndexError): + return (0, 0, 0) + + +def compare_versions(current: str, remote: str) -> int: + """ + Compare two version strings. + + Args: + current: Current version string + remote: Remote version string + + Returns: + -1 if current < remote (update available) + 0 if current == remote (up to date) + 1 if current > remote (ahead of remote) + """ + current_tuple = parse_version(current) + remote_tuple = parse_version(remote) + + if current_tuple < remote_tuple: + return -1 + elif current_tuple > remote_tuple: + return 1 + else: + return 0 + + +def fetch_github_version() -> Optional[tuple[str, str]]: + """ + Fetch the latest version information from GitHub repository. 
+ + Returns: + Tuple of (version, last_updated) if successful, None if failed + """ + if not HAS_URLLIB: + logger.warning("urllib not available, cannot check for updates") + return None + + github_url = "https://raw.githubusercontent.com/BeehiveInnovations/zen-mcp-server/main/config.py" + + try: + # Set a 10-second timeout + with urlopen(github_url, timeout=10) as response: + if response.status != 200: + logger.warning(f"HTTP error while checking GitHub: {response.status}") + return None + + content = response.read().decode("utf-8") + + # Extract version using regex + version_match = re.search(r'__version__\s*=\s*["\']([^"\']+)["\']', content) + updated_match = re.search(r'__updated__\s*=\s*["\']([^"\']+)["\']', content) + + if version_match: + remote_version = version_match.group(1) + remote_updated = updated_match.group(1) if updated_match else "Unknown" + return (remote_version, remote_updated) + else: + logger.warning("Could not parse version from GitHub config.py") + return None + + except HTTPError as e: + logger.warning(f"HTTP error while checking GitHub: {e.code}") + return None + except URLError as e: + logger.warning(f"URL error while checking GitHub: {e.reason}") + return None + except Exception as e: + logger.warning(f"Error checking GitHub for updates: {e}") + return None + + +class VersionTool(BaseTool): + """ + Tool for displaying Zen MCP Server version and system information. + + This tool provides: + - Current server version + - Last update date + - Author information + - Python version + - Platform information + """ + + def get_name(self) -> str: + return "version" + + def get_description(self) -> str: + return "Get server version, configuration details, and list of available tools." 

    def get_input_schema(self) -> dict[str, Any]:
        """Return the JSON schema for the tool's input"""
        return {
            "type": "object",
            "properties": {"model": {"type": "string", "description": "Model to use (ignored by version tool)"}},
            "required": [],
        }

    def get_annotations(self) -> Optional[dict[str, Any]]:
        """Return tool annotations indicating this is a read-only tool"""
        return {"readOnlyHint": True}

    def get_system_prompt(self) -> str:
        """No AI model needed for this tool"""
        return ""

    def get_request_model(self):
        """Return the Pydantic model for request validation."""
        return ToolRequest

    def requires_model(self) -> bool:
        """Version info is generated locally; no provider model is required."""
        return False

    async def prepare_prompt(self, request: ToolRequest) -> str:
        """Not used for this utility tool"""
        return ""

    def format_response(self, response: str, request: ToolRequest, model_info: dict = None) -> str:
        """Not used for this utility tool"""
        return response

    async def execute(self, arguments: dict[str, Any]) -> list[TextContent]:
        """
        Display Zen MCP Server version and system information.

        This overrides the base class execute to provide direct output without AI model calls.

        Args:
            arguments: Standard tool arguments (none required)

        Returns:
            Formatted version and system information
        """
        output_lines = ["# Zen MCP Server Version\n"]

        # Server version information
        output_lines.append("## Server Information")
        output_lines.append(f"**Current Version**: {__version__}")
        output_lines.append(f"**Last Updated**: {__updated__}")
        output_lines.append(f"**Author**: {__author__}")

        model_selection_metadata = {"mode": "unknown", "default_model": None}
        model_selection_display = "Model selection status unavailable"

        # Model selection configuration
        try:
            from config import DEFAULT_MODEL
            from tools.shared.base_tool import BaseTool

            # NOTE(review): calls the base implementation explicitly rather than
            # self.is_effective_auto_mode() — presumably to bypass subclass
            # overrides; confirm this is intentional.
            auto_mode = BaseTool.is_effective_auto_mode(self)
            if auto_mode:
                output_lines.append(
                    "**Model Selection**: Auto model selection mode (call `listmodels` to inspect options)"
                )
                model_selection_metadata = {"mode": "auto", "default_model": DEFAULT_MODEL}
                model_selection_display = "Auto model selection (use `listmodels` for options)"
            else:
                output_lines.append(f"**Model Selection**: Default model set to `{DEFAULT_MODEL}`")
                model_selection_metadata = {"mode": "default", "default_model": DEFAULT_MODEL}
                model_selection_display = f"Default model: `{DEFAULT_MODEL}`"
        except Exception as exc:
            logger.debug(f"Could not determine model selection mode: {exc}")

        output_lines.append("")
        output_lines.append("## Quick Summary — relay everything below")
        output_lines.append(f"- Version `{__version__}` (updated {__updated__})")
        output_lines.append(f"- {model_selection_display}")
        output_lines.append("- Run `listmodels` for the complete model catalog and capabilities")
        output_lines.append("")

        # Try to get client information
        try:
            # We need access to the server instance
            # This is a bit hacky but works for now
            import server as server_module
            from utils.client_info import format_client_info, get_client_info_from_context

            client_info = get_client_info_from_context(server_module.server)
            if client_info:
                formatted = format_client_info(client_info)
                output_lines.append(f"**Connected Client**: {formatted}")
        except Exception as e:
            logger.debug(f"Could not get client info: {e}")

        # Get the current working directory (MCP server location)
        # NOTE(review): Path.cwd() is the process working directory, which is the
        # install path only when the server is launched from its own directory —
        # confirm run-server.sh guarantees that.
        current_path = Path.cwd()
        output_lines.append(f"**Installation Path**: `{current_path}`")
        output_lines.append("")
        output_lines.append("## Agent Reporting Guidance")
        output_lines.append(
            "Agents MUST report: version, model-selection status, configured providers, and available-model count."
        )
        output_lines.append("Repeat the quick-summary bullets verbatim in your reply.")
        output_lines.append("Reference `listmodels` when users ask about model availability or capabilities.")
        output_lines.append("")

        # Check for updates from GitHub
        output_lines.append("## Update Status")

        try:
            github_info = fetch_github_version()

            if github_info:
                remote_version, remote_updated = github_info
                comparison = compare_versions(__version__, remote_version)

                output_lines.append(f"**Latest Version (GitHub)**: {remote_version}")
                output_lines.append(f"**Latest Updated**: {remote_updated}")

                if comparison < 0:
                    # Update available
                    output_lines.append("")
                    output_lines.append("🚀 **UPDATE AVAILABLE!**")
                    output_lines.append(
                        f"Your version `{__version__}` is older than the latest version `{remote_version}`"
                    )
                    output_lines.append("")
                    output_lines.append("**To update:**")
                    output_lines.append("```bash")
                    output_lines.append(f"cd {current_path}")
                    output_lines.append("git pull")
                    output_lines.append("```")
                    output_lines.append("")
                    output_lines.append("*Note: Restart your session after updating to use the new version.*")
                elif comparison == 0:
                    # Up to date
                    output_lines.append("")
                    output_lines.append("✅ **UP TO DATE**")
                    output_lines.append("You are running the latest version.")
                else:
                    # Ahead of remote (development version)
                    output_lines.append("")
                    output_lines.append("🔬 **DEVELOPMENT VERSION**")
                    output_lines.append(
                        f"Your version `{__version__}` is ahead of the published version `{remote_version}`"
                    )
                    output_lines.append("You may be running a development or custom build.")
            else:
                output_lines.append("❌ **Could not check for updates**")
                output_lines.append("Unable to connect to GitHub or parse version information.")
                output_lines.append("Check your internet connection or try again later.")

        except Exception as e:
            logger.error(f"Error during version check: {e}")
            output_lines.append("❌ **Error checking for updates**")
            output_lines.append(f"Error: {str(e)}")

        output_lines.append("")

        # Configuration information
        output_lines.append("## Configuration")

        # Check for configured providers
        try:
            from providers.registry import ModelProviderRegistry
            from providers.shared import ProviderType

            provider_status = []

            # Check each provider type
            provider_types = [
                ProviderType.GOOGLE,
                ProviderType.OPENAI,
                ProviderType.XAI,
                ProviderType.DIAL,
                ProviderType.OPENROUTER,
                ProviderType.CUSTOM,
            ]
            provider_names = ["Google Gemini", "OpenAI", "X.AI", "DIAL", "OpenRouter", "Custom/Local"]

            for provider_type, provider_name in zip(provider_types, provider_names):
                provider = ModelProviderRegistry.get_provider(provider_type)
                status = "✅ Configured" if provider is not None else "❌ Not configured"
                provider_status.append(f"- **{provider_name}**: {status}")

            output_lines.append("**Providers**:")
            output_lines.extend(provider_status)

            # Get total available models
            try:
                available_models = ModelProviderRegistry.get_available_models(respect_restrictions=True)
                output_lines.append(f"\n\n**Available Models**: {len(available_models)}")
            except Exception:
                output_lines.append("\n\n**Available Models**: Unknown")

        except Exception as e:
            logger.warning(f"Error checking provider configuration: {e}")
            output_lines.append("\n\n**Providers**: Error checking configuration")

        output_lines.append("")

        # Format output
        content = "\n".join(output_lines)

        tool_output = ToolOutput(
            status="success",
            content=content,
            content_type="text",
            metadata={
                # NOTE(review): assumes BaseTool exposes a `name` attribute/property
                # alongside get_name() — confirm.
                "tool_name": self.name,
                "server_version": __version__,
                "last_updated": __updated__,
                "python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
                "platform": f"{platform.system()} {platform.release()}",
                "model_selection_mode": model_selection_metadata["mode"],
                "default_model": model_selection_metadata["default_model"],
            },
        )

        return [TextContent(type="text", text=tool_output.model_dump_json())]

    def get_model_category(self) -> ToolModelCategory:
        """Return the model category for this tool."""
        return ToolModelCategory.FAST_RESPONSE  # Simple version info, no AI needed
diff --git a/tools/workflow/__init__.py b/tools/workflow/__init__.py
new file mode 100644
index 0000000..9603937
--- /dev/null
+++ b/tools/workflow/__init__.py
@@ -0,0 +1,22 @@
"""
Workflow tools for Zen MCP.

Workflow tools follow a multi-step pattern with forced pauses between steps
to encourage thorough investigation and analysis. They inherit from WorkflowTool
which combines BaseTool with BaseWorkflowMixin.

Available workflow tools:
- debug: Systematic investigation and root cause analysis
- planner: Sequential planning (special case - no AI calls)
- analyze: Code analysis workflow
- codereview: Code review workflow
- precommit: Pre-commit validation workflow
- refactor: Refactoring analysis workflow
- thinkdeep: Deep thinking workflow
"""

from .base import WorkflowTool
from .schema_builders import WorkflowSchemaBuilder
from .workflow_mixin import BaseWorkflowMixin

__all__ = ["WorkflowTool", "WorkflowSchemaBuilder", "BaseWorkflowMixin"]
diff --git a/tools/workflow/base.py b/tools/workflow/base.py
new file mode 100644
index 0000000..fb085d4
--- /dev/null
+++ b/tools/workflow/base.py
@@ -0,0 +1,444 @@
"""
Base class for workflow MCP tools.

Workflow tools follow a multi-step pattern:
1. Claude calls tool with work step data
2. Tool tracks findings and progress
3. Tool forces Claude to pause and investigate between steps
4. Once work is complete, tool calls external AI model for expert analysis
5. Tool returns structured response combining investigation + expert analysis

They combine BaseTool's capabilities with BaseWorkflowMixin's workflow functionality
and use SchemaBuilder for consistent schema generation.
"""

from abc import abstractmethod
from typing import Any, Optional

from tools.shared.base_models import WorkflowRequest
from tools.shared.base_tool import BaseTool

from .schema_builders import WorkflowSchemaBuilder
from .workflow_mixin import BaseWorkflowMixin


class WorkflowTool(BaseTool, BaseWorkflowMixin):
    """
    Base class for workflow (multi-step) tools.

    Workflow tools perform systematic multi-step work with expert analysis.
    They benefit from:
    - Automatic workflow orchestration from BaseWorkflowMixin
    - Automatic schema generation using SchemaBuilder
    - Inherited conversation handling and file processing from BaseTool
    - Progress tracking with ConsolidatedFindings
    - Expert analysis integration

    To create a workflow tool:
    1. Inherit from WorkflowTool
    2. Tool name is automatically provided by get_name() method
    3. Implement get_required_actions() for step guidance
    4. Implement should_call_expert_analysis() for completion criteria
    5. Implement prepare_expert_analysis_context() for expert prompts
    6. Optionally implement get_tool_fields() for additional fields
    7. Optionally override workflow behavior methods

    Example:
        class DebugTool(WorkflowTool):
            # get_name() is inherited from BaseTool

            def get_tool_fields(self) -> Dict[str, Dict[str, Any]]:
                return {
                    "hypothesis": {
                        "type": "string",
                        "description": "Current theory about the issue",
                    }
                }

            def get_required_actions(
                self, step_number: int, confidence: str, findings: str, total_steps: int
            ) -> List[str]:
                return ["Examine relevant code files", "Trace execution flow", "Check error logs"]

            def should_call_expert_analysis(self, consolidated_findings) -> bool:
                return len(consolidated_findings.relevant_files) > 0
    """

    def __init__(self):
        """Initialize WorkflowTool with proper multiple inheritance."""
        # Both bases are initialized explicitly rather than via super() — the
        # mixin presumably has its own state to set up; confirm MRO expectations.
        BaseTool.__init__(self)
        BaseWorkflowMixin.__init__(self)

    def get_tool_fields(self) -> dict[str, dict[str, Any]]:
        """
        Return tool-specific field definitions beyond the standard workflow fields.

        Workflow tools automatically get all standard workflow fields:
        - step, step_number, total_steps, next_step_required
        - findings, files_checked, relevant_files, relevant_context
        - issues_found, confidence, hypothesis, backtrack_from_step
        - plus common fields (model, temperature, etc.)

        Override this method to add additional tool-specific fields.

        Returns:
            Dict mapping field names to JSON schema objects

        Example:
            return {
                "severity_filter": {
                    "type": "string",
                    "enum": ["low", "medium", "high"],
                    "description": "Minimum severity level to report",
                }
            }
        """
        return {}

    def get_required_fields(self) -> list[str]:
        """
        Return additional required fields beyond the standard workflow requirements.

        Workflow tools automatically require:
        - step, step_number, total_steps, next_step_required, findings
        - model (if in auto mode)

        Override this to add additional required fields.

        Returns:
            List of additional required field names
        """
        return []

    def get_annotations(self) -> Optional[dict[str, Any]]:
        """
        Return tool annotations. Workflow tools are read-only by default.

        All workflow tools perform analysis and investigation without modifying
        the environment. They may call external AI models for expert analysis,
        but they don't write files or make system changes.

        Override this method if your workflow tool needs different annotations.

        Returns:
            Dictionary with readOnlyHint set to True
        """
        return {"readOnlyHint": True}

    def get_input_schema(self) -> dict[str, Any]:
        """
        Generate the complete input schema using SchemaBuilder.

        This method automatically combines:
        - Standard workflow fields (step, findings, etc.)
        - Common fields (temperature, thinking_mode, etc.)
        - Model field with proper auto-mode handling
        - Tool-specific fields from get_tool_fields()
        - Required fields from get_required_fields()

        Returns:
            Complete JSON schema for the workflow tool
        """
        return WorkflowSchemaBuilder.build_schema(
            tool_specific_fields=self.get_tool_fields(),
            required_fields=self.get_required_fields(),
            model_field_schema=self.get_model_field_schema(),
            auto_mode=self.is_effective_auto_mode(),
            tool_name=self.get_name(),
        )

    def get_workflow_request_model(self):
        """
        Return the workflow request model class.

        Workflow tools use WorkflowRequest by default, which includes
        all the standard workflow fields. Override this if your tool
        needs a custom request model.
        """
        return WorkflowRequest

    # Implement the abstract method from BaseWorkflowMixin
    def get_work_steps(self, request) -> list[str]:
        """
        Default implementation - workflow tools typically don't need predefined steps.

        The workflow is driven by Claude's investigation process rather than
        predefined steps. Override this if your tool needs specific step guidance.
        """
        return []

    # Default implementations for common workflow patterns

    def get_standard_required_actions(self, step_number: int, confidence: str, base_actions: list[str]) -> list[str]:
        """
        Helper method to generate standard required actions based on confidence and step.

        This provides common patterns that most workflow tools can use:
        - Early steps: broad exploration
        - Low confidence: deeper investigation
        - Medium/high confidence: verification and confirmation

        Args:
            step_number: Current step number
            confidence: Current confidence level
            base_actions: Tool-specific base actions

        Returns:
            List of required actions appropriate for the current state
        """
        if step_number == 1:
            # Initial investigation (note: base_actions are NOT included here —
            # step 1 always returns the generic exploration list)
            return [
                "Search for code related to the reported issue or symptoms",
                "Examine relevant files and understand the current implementation",
                "Understand the project structure and locate relevant modules",
                "Identify how the affected functionality is supposed to work",
            ]
        elif confidence in ["exploring", "low"]:
            # Need deeper investigation
            return base_actions + [
                "Trace method calls and data flow through the system",
                "Check for edge cases, boundary conditions, and assumptions in the code",
                "Look for related configuration, dependencies, or external factors",
            ]
        elif confidence in ["medium", "high"]:
            # Close to solution - need confirmation
            return base_actions + [
                "Examine the exact code sections where you believe the issue occurs",
                "Trace the execution path that leads to the failure",
                "Verify your hypothesis with concrete code evidence",
                "Check for any similar patterns elsewhere in the codebase",
            ]
        else:
            # General continued investigation
            return base_actions + [
                "Continue examining the code paths identified in your hypothesis",
                "Gather more evidence using appropriate investigation tools",
                "Test edge cases and boundary conditions",
                "Look for patterns that confirm or refute your theory",
            ]

    def should_call_expert_analysis_default(self, consolidated_findings) -> bool:
        """
        Default implementation for expert analysis decision.

        This provides a reasonable default that most workflow tools can use:
        - Call expert analysis if we have relevant files or significant findings
        - Skip if confidence is "certain" (handled by the workflow mixin)

        Override this for tool-specific logic.

        Args:
            consolidated_findings: The consolidated findings from all work steps

        Returns:
            True if expert analysis should be called
        """
        # Call expert analysis if we have relevant files or substantial findings
        return (
            len(consolidated_findings.relevant_files) > 0
            or len(consolidated_findings.findings) >= 2
            or len(consolidated_findings.issues_found) > 0
        )

    def prepare_standard_expert_context(
        self, consolidated_findings, initial_description: str, context_sections: dict[str, str] = None
    ) -> str:
        """
        Helper method to prepare standard expert analysis context.

        This provides a common structure that most workflow tools can use,
        with the ability to add tool-specific sections.

        Args:
            consolidated_findings: The consolidated findings from all work steps
            initial_description: Description of the initial request/issue
            context_sections: Optional additional sections to include

        Returns:
            Formatted context string for expert analysis
        """
        context_parts = [f"=== ISSUE DESCRIPTION ===\n{initial_description}\n=== END DESCRIPTION ==="]

        # Add work progression
        if consolidated_findings.findings:
            findings_text = "\n".join(consolidated_findings.findings)
            context_parts.append(f"\n=== INVESTIGATION FINDINGS ===\n{findings_text}\n=== END FINDINGS ===")

        # Add relevant methods if available
        if consolidated_findings.relevant_context:
            methods_text = "\n".join(f"- {method}" for method in consolidated_findings.relevant_context)
            context_parts.append(f"\n=== RELEVANT METHODS/FUNCTIONS ===\n{methods_text}\n=== END METHODS ===")

        # Add hypothesis evolution if available
        if consolidated_findings.hypotheses:
            hypotheses_text = "\n".join(
                f"Step {h['step']} ({h['confidence']} confidence): {h['hypothesis']}"
                for h in consolidated_findings.hypotheses
            )
            context_parts.append(f"\n=== HYPOTHESIS EVOLUTION ===\n{hypotheses_text}\n=== END HYPOTHESES ===")

        # Add issues found if available
        if consolidated_findings.issues_found:
            issues_text = "\n".join(
                f"[{issue.get('severity', 'unknown').upper()}] {issue.get('description', 'No description')}"
                for issue in consolidated_findings.issues_found
            )
            context_parts.append(f"\n=== ISSUES IDENTIFIED ===\n{issues_text}\n=== END ISSUES ===")

        # Add tool-specific sections
        if context_sections:
            for section_title, section_content in context_sections.items():
                context_parts.append(
                    f"\n=== {section_title.upper()} ===\n{section_content}\n=== END {section_title.upper()} ==="
                )

        return "\n".join(context_parts)

    def handle_completion_without_expert_analysis(
        self, request, consolidated_findings, initial_description: str = None
    ) -> dict[str, Any]:
        """
        Generic handler for completion when expert analysis is not needed.

        This provides a standard response format for when the tool determines
        that external expert analysis is not required. All workflow tools
        can use this generic implementation or override for custom behavior.

        Args:
            request: The workflow request object
            consolidated_findings: The consolidated findings from all work steps
            initial_description: Optional initial description (defaults to request.step)

        Returns:
            Dictionary with completion response data
        """
        # Prepare work summary using inheritance hook
        work_summary = self.prepare_work_summary()

        return {
            "status": self.get_completion_status(),
            self.get_completion_data_key(): {
                "initial_request": initial_description or request.step,
                "steps_taken": len(consolidated_findings.findings),
                "files_examined": list(consolidated_findings.files_checked),
                "relevant_files": list(consolidated_findings.relevant_files),
                "relevant_context": list(consolidated_findings.relevant_context),
                "work_summary": work_summary,
                "final_analysis": self.get_final_analysis_from_request(request),
                "confidence_level": self.get_confidence_level(request),
            },
            "next_steps": self.get_completion_message(),
            "skip_expert_analysis": True,
            "expert_analysis": {
                "status": self.get_skip_expert_analysis_status(),
                "reason": self.get_skip_reason(),
            },
        }

    # Inheritance hooks for customization

    def prepare_work_summary(self) -> str:
        """
        Prepare a summary of the work performed. Override for custom summaries.
        Default implementation provides a basic summary.
        """
        # NOTE(review): _prepare_work_summary and work_history are presumably
        # provided by BaseWorkflowMixin; the nested AttributeError fallbacks
        # cover subclasses that lack them — confirm against the mixin.
        try:
            return self._prepare_work_summary()
        except AttributeError:
            try:
                return f"Completed {len(self.work_history)} work steps"
            except AttributeError:
                return "Completed 0 work steps"

    def get_completion_status(self) -> str:
        """Get the status to use when completing without expert analysis."""
        return "high_confidence_completion"

    def get_completion_data_key(self) -> str:
        """Get the key name for completion data in the response."""
        return f"complete_{self.get_name()}"

    def get_final_analysis_from_request(self, request) -> Optional[str]:
        """Extract final analysis from request. Override for tool-specific extraction."""
        try:
            return request.hypothesis
        except AttributeError:
            return None

    def get_confidence_level(self, request) -> str:
        """Get confidence level from request. Override for tool-specific logic."""
        try:
            return request.confidence or "high"
        except AttributeError:
            return "high"

    def get_completion_message(self) -> str:
        """Get completion message. Override for tool-specific messaging."""
        return (
            f"{self.get_name().capitalize()} complete with high confidence. You have identified the exact "
            "analysis and solution. MANDATORY: Present the user with the results "
            "and proceed with implementing the solution without requiring further "
            "consultation. Focus on the precise, actionable steps needed."
        )

    def get_skip_reason(self) -> str:
        """Get reason for skipping expert analysis. Override for tool-specific reasons."""
        return f"{self.get_name()} completed with sufficient confidence"

    def get_skip_expert_analysis_status(self) -> str:
        """Get status for skipped expert analysis. Override for tool-specific status."""
        return "skipped_by_tool_design"

    def is_continuation_workflow(self, request) -> bool:
        """
        Check if this is a continuation workflow that should skip multi-step investigation.

        When continuation_id is provided, the workflow typically continues from a previous
        conversation and should go directly to expert analysis rather than starting a new
        multi-step investigation.

        Args:
            request: The workflow request object

        Returns:
            True if this is a continuation that should skip multi-step workflow
        """
        continuation_id = self.get_request_continuation_id(request)
        return bool(continuation_id)

    # Abstract methods that must be implemented by specific workflow tools
    # (These are inherited from BaseWorkflowMixin and must be implemented)

    @abstractmethod
    def get_required_actions(
        self, step_number: int, confidence: str, findings: str, total_steps: int, request=None
    ) -> list[str]:
        """Define required actions for each work phase.

        Args:
            step_number: Current step number
            confidence: Current confidence level
            findings: Current findings text
            total_steps: Total estimated steps
            request: Optional request object for continuation-aware decisions

        Returns:
            List of required actions for the current step
        """
        pass

    @abstractmethod
    def should_call_expert_analysis(self, consolidated_findings) -> bool:
        """Decide when to call external model based on tool-specific criteria"""
        pass

    @abstractmethod
    def prepare_expert_analysis_context(self, consolidated_findings) -> str:
        """Prepare context for external model call"""
        pass

    # Default execute method - delegates to workflow
    async def execute(self, arguments: dict[str, Any]) -> list:
        """Execute the workflow tool - delegates to BaseWorkflowMixin."""
        return await self.execute_workflow(arguments)
diff --git a/tools/workflow/schema_builders.py b/tools/workflow/schema_builders.py
new file mode 100644
index 0000000..4ae1e27
--- /dev/null
+++ b/tools/workflow/schema_builders.py
@@ -0,0 +1,174 @@
"""
Schema builders for workflow MCP tools.

This module provides workflow-specific schema generation functionality,
keeping workflow concerns separated from simple tool concerns.
"""

from typing import Any

from ..shared.base_models import WORKFLOW_FIELD_DESCRIPTIONS
from ..shared.schema_builders import SchemaBuilder


class WorkflowSchemaBuilder:
    """
    Schema builder for workflow MCP tools.

    This class extends the base SchemaBuilder with workflow-specific fields
    and schema generation logic, maintaining separation of concerns.
    """

    # Workflow-specific field schemas
    WORKFLOW_FIELD_SCHEMAS = {
        "step": {
            "type": "string",
            "description": WORKFLOW_FIELD_DESCRIPTIONS["step"],
        },
        "step_number": {
            "type": "integer",
            "minimum": 1,
            "description": WORKFLOW_FIELD_DESCRIPTIONS["step_number"],
        },
        "total_steps": {
            "type": "integer",
            "minimum": 1,
            "description": WORKFLOW_FIELD_DESCRIPTIONS["total_steps"],
        },
        "next_step_required": {
            "type": "boolean",
            "description": WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"],
        },
        "findings": {
            "type": "string",
            "description": WORKFLOW_FIELD_DESCRIPTIONS["findings"],
        },
        "files_checked": {
            "type": "array",
            "items": {"type": "string"},
            "description": WORKFLOW_FIELD_DESCRIPTIONS["files_checked"],
        },
        "relevant_files": {
            "type": "array",
            "items": {"type": "string"},
            "description": WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"],
        },
        "relevant_context": {
            "type": "array",
            "items": {"type": "string"},
            "description": WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"],
        },
        "issues_found": {
            "type": "array",
            "items": {"type": "object"},
            "description": WORKFLOW_FIELD_DESCRIPTIONS["issues_found"],
        },
        "confidence": {
            "type": "string",
            "enum": ["exploring", "low", "medium", "high", "very_high", "almost_certain", "certain"],
            "description": WORKFLOW_FIELD_DESCRIPTIONS["confidence"],
        },
        "hypothesis": {
            "type": "string",
            "description": WORKFLOW_FIELD_DESCRIPTIONS["hypothesis"],
        },
        "backtrack_from_step": {
            "type": "integer",
            "minimum": 1,
            "description": WORKFLOW_FIELD_DESCRIPTIONS["backtrack_from_step"],
        },
        "use_assistant_model": {
            "type": "boolean",
            "default": True,
            "description": WORKFLOW_FIELD_DESCRIPTIONS["use_assistant_model"],
        },
    }

    @staticmethod
    def build_schema(
        tool_specific_fields: dict[str, dict[str, Any]] = None,
        required_fields: list[str] = None,
        model_field_schema: dict[str, Any] = None,
        auto_mode: bool = False,
        tool_name: str = None,
        excluded_workflow_fields: list[str] = None,
        excluded_common_fields: list[str] = None,
        require_model: bool = False,
    ) -> dict[str, Any]:
        """
        Build complete schema for workflow tools.

        Args:
            tool_specific_fields: Additional fields specific to the tool
            required_fields: List of required field names (beyond workflow defaults)
            model_field_schema: Schema for the model field
            auto_mode: Whether the tool is in auto mode (affects model requirement)
            tool_name: Name of the tool (for schema title)
            excluded_workflow_fields: Workflow fields to exclude from schema (e.g., for planning tools)
            excluded_common_fields: Common fields to exclude from schema
            require_model: When True, force "model" into the required list even if auto_mode is False

        Returns:
            Complete JSON schema for the workflow tool
        """
        properties = {}

        # Add workflow fields first, excluding any specified fields
        workflow_fields = WorkflowSchemaBuilder.WORKFLOW_FIELD_SCHEMAS.copy()
        if excluded_workflow_fields:
            for field in excluded_workflow_fields:
                workflow_fields.pop(field, None)
        properties.update(workflow_fields)

        # Add common fields (temperature, thinking_mode, etc.) from base builder, excluding any specified fields
        common_fields = SchemaBuilder.COMMON_FIELD_SCHEMAS.copy()
        if excluded_common_fields:
            for field in excluded_common_fields:
                common_fields.pop(field, None)
        properties.update(common_fields)

        # Add model field if provided
        if model_field_schema:
            properties["model"] = model_field_schema

        # Add tool-specific fields if provided
        # (added last, so a tool-specific field can override a standard one)
        if tool_specific_fields:
            properties.update(tool_specific_fields)

        # Build required fields list - workflow tools have standard required fields
        standard_required = ["step", "step_number", "total_steps", "next_step_required", "findings"]

        # Filter out excluded fields from required fields
        if excluded_workflow_fields:
            standard_required = [field for field in standard_required if field not in excluded_workflow_fields]

        required = standard_required + (required_fields or [])

        if (auto_mode or require_model) and "model" not in required:
            required.append("model")

        # Build the complete schema
        schema = {
            "$schema": "http://json-schema.org/draft-07/schema#",
            "type": "object",
            "properties": properties,
            "required": required,
            "additionalProperties": False,
        }

        if tool_name:
            schema["title"] = f"{tool_name.capitalize()}Request"

        return schema

    @staticmethod
    def get_workflow_fields() -> dict[str, dict[str, Any]]:
        """Get the standard field schemas for workflow tools."""
        combined = {}
        combined.update(WorkflowSchemaBuilder.WORKFLOW_FIELD_SCHEMAS)
        combined.update(SchemaBuilder.COMMON_FIELD_SCHEMAS)
        return combined

    @staticmethod
    def get_workflow_only_fields() -> dict[str, dict[str, Any]]:
        """Get only the workflow-specific field schemas."""
        return WorkflowSchemaBuilder.WORKFLOW_FIELD_SCHEMAS.copy()
diff --git a/tools/workflow/workflow_mixin.py b/tools/workflow/workflow_mixin.py
new file mode 100644
index 0000000..80ebe30
--- /dev/null
+++ b/tools/workflow/workflow_mixin.py
@@ -0,0 +1,1619 @@
"""
Workflow Mixin for Zen MCP Tools

+This module provides a sophisticated workflow-based pattern that enables tools to +perform multi-step work with structured findings and expert analysis. + +Key Components: +- BaseWorkflowMixin: Abstract base class providing comprehensive workflow functionality + +The workflow pattern enables tools like debug, precommit, and codereview to perform +systematic multi-step work with pause/resume capabilities, context-aware file embedding, +and seamless integration with external AI models for expert analysis. + +Features: +- Multi-step workflow orchestration with pause/resume +- Context-aware file embedding optimization +- Expert analysis integration with token budgeting +- Conversation memory and threading support +- Proper inheritance-based architecture (no hasattr/getattr) +- Comprehensive type annotations for IDE support +""" + +import json +import logging +import os +import re +from abc import ABC, abstractmethod +from typing import Any, Optional + +from mcp.types import TextContent + +from config import MCP_PROMPT_SIZE_LIMIT +from utils.conversation_memory import add_turn, create_thread + +from ..shared.base_models import ConsolidatedFindings + +logger = logging.getLogger(__name__) + + +class BaseWorkflowMixin(ABC): + """ + Abstract base class providing guided workflow functionality for tools. + + This class implements a sophisticated workflow pattern where Claude performs + systematic local work before calling external models for expert analysis. + Tools can inherit from this class to gain comprehensive workflow capabilities. 
+ + Architecture: + - Uses proper inheritance patterns instead of hasattr/getattr + - Provides hook methods with default implementations + - Requires abstract methods to be implemented by subclasses + - Fully type-annotated for excellent IDE support + + Context-Aware File Embedding: + - Intermediate steps: Only reference file names (saves Claude's context) + - Final steps: Embed full file content for expert analysis + - Integrates with existing token budgeting infrastructure + + Requirements: + This class expects to be used with BaseTool and requires implementation of: + - get_model_provider(model_name) + - _resolve_model_context(arguments, request) + - get_system_prompt() + - get_default_temperature() + - _prepare_file_content_for_prompt() + """ + + def __init__(self) -> None: + super().__init__() + self.work_history: list[dict[str, Any]] = [] + self.consolidated_findings: ConsolidatedFindings = ConsolidatedFindings() + self.initial_request: Optional[str] = None + + # ================================================================================ + # Abstract Methods - Required Implementation by BaseTool or Subclasses + # ================================================================================ + + @abstractmethod + def get_name(self) -> str: + """Return the name of this tool. Usually provided by BaseTool.""" + pass + + @abstractmethod + def get_workflow_request_model(self) -> type: + """Return the request model class for this workflow tool.""" + pass + + @abstractmethod + def get_system_prompt(self) -> str: + """Return the system prompt for this tool. Usually provided by BaseTool.""" + pass + + @abstractmethod + def get_language_instruction(self) -> str: + """Return the language instruction for localization. Usually provided by BaseTool.""" + pass + + @abstractmethod + def get_default_temperature(self) -> float: + """Return the default temperature for this tool. 
Usually provided by BaseTool.""" + pass + + @abstractmethod + def get_model_provider(self, model_name: str) -> Any: + """Get model provider for the given model. Usually provided by BaseTool.""" + pass + + @abstractmethod + def _resolve_model_context(self, arguments: dict[str, Any], request: Any) -> tuple[str, Any]: + """Resolve model context from arguments. Usually provided by BaseTool.""" + pass + + @abstractmethod + def _prepare_file_content_for_prompt( + self, + request_files: list[str], + continuation_id: Optional[str], + context_description: str = "New files", + max_tokens: Optional[int] = None, + reserve_tokens: int = 1_000, + remaining_budget: Optional[int] = None, + arguments: Optional[dict[str, Any]] = None, + model_context: Optional[Any] = None, + ) -> tuple[str, list[str]]: + """Prepare file content for prompts. Usually provided by BaseTool.""" + pass + + # ================================================================================ + # Abstract Methods - Tool-Specific Implementation Required + # ================================================================================ + + @abstractmethod + def get_work_steps(self, request: Any) -> list[str]: + """Define tool-specific work steps and criteria""" + pass + + @abstractmethod + def get_required_actions( + self, step_number: int, confidence: str, findings: str, total_steps: int, request=None + ) -> list[str]: + """Define required actions for each work phase. 
+ + Args: + step_number: Current step (1-based) + confidence: Current confidence level (exploring, low, medium, high, certain) + findings: Current findings text + total_steps: Total estimated steps for this work + request: Optional request object for continuation-aware decisions + + Returns: + List of specific actions Claude should take before calling tool again + """ + pass + + # ================================================================================ + # Hook Methods - Default Implementations with Override Capability + # ================================================================================ + + def should_call_expert_analysis(self, consolidated_findings: ConsolidatedFindings, request=None) -> bool: + """ + Decide when to call external model based on tool-specific criteria. + + Default implementation for tools that don't use expert analysis. + Override this for tools that do use expert analysis. + + Args: + consolidated_findings: Findings from workflow steps + request: Current request object (optional for backwards compatibility) + """ + if not self.requires_expert_analysis(): + return False + + # Check if user requested to skip assistant model + if request and not self.get_request_use_assistant_model(request): + return False + + # Default logic for tools that support expert analysis + return ( + len(consolidated_findings.relevant_files) > 0 + or len(consolidated_findings.findings) >= 2 + or len(consolidated_findings.issues_found) > 0 + ) + + def prepare_expert_analysis_context(self, consolidated_findings: ConsolidatedFindings) -> str: + """ + Prepare context for external model call. + + Default implementation for tools that don't use expert analysis. + Override this for tools that do use expert analysis. 
+ """ + if not self.requires_expert_analysis(): + return "" + + # Default context preparation + context_parts = [ + f"=== {self.get_name().upper()} WORK SUMMARY ===", + f"Total steps: {len(consolidated_findings.findings)}", + f"Files examined: {len(consolidated_findings.files_checked)}", + f"Relevant files: {len(consolidated_findings.relevant_files)}", + "", + "=== WORK PROGRESSION ===", + ] + + for finding in consolidated_findings.findings: + context_parts.append(finding) + + return "\n".join(context_parts) + + def requires_expert_analysis(self) -> bool: + """ + Override this to completely disable expert analysis for the tool. + + Returns True if the tool supports expert analysis (default). + Returns False if the tool is self-contained (like planner). + """ + return True + + def should_include_files_in_expert_prompt(self) -> bool: + """ + Whether to include file content in the expert analysis prompt. + Override this to return True if your tool needs files in the prompt. + """ + return False + + def should_embed_system_prompt(self) -> bool: + """ + Whether to embed the system prompt in the main prompt. + Override this to return True if your tool needs the system prompt embedded. + """ + return False + + def get_expert_thinking_mode(self) -> str: + """ + Get the thinking mode for expert analysis. + Override this to customize the thinking mode. + """ + return "high" + + def get_request_temperature(self, request) -> float: + """Get temperature from request. Override for custom temperature handling.""" + try: + return request.temperature if request.temperature is not None else self.get_default_temperature() + except AttributeError: + return self.get_default_temperature() + + def get_validated_temperature(self, request, model_context: Any) -> tuple[float, list[str]]: + """ + Get temperature from request and validate it against model constraints. + + This is a convenience method that combines temperature extraction and validation + for workflow tools. 
It ensures temperature is within valid range for the model. + + Args: + request: The request object containing temperature + model_context: Model context object containing model info + + Returns: + Tuple of (validated_temperature, warning_messages) + """ + temperature = self.get_request_temperature(request) + return self.validate_and_correct_temperature(temperature, model_context) + + def get_request_thinking_mode(self, request) -> str: + """Get thinking mode from request. Override for custom thinking mode handling.""" + try: + return request.thinking_mode if request.thinking_mode is not None else self.get_expert_thinking_mode() + except AttributeError: + return self.get_expert_thinking_mode() + + def get_expert_analysis_instruction(self) -> str: + """ + Get the instruction to append after the expert context. + Override this to provide tool-specific instructions. + """ + return "Please provide expert analysis based on the investigation findings." + + def get_request_use_assistant_model(self, request) -> bool: + """ + Get use_assistant_model from request. Override for custom assistant model handling. + + Args: + request: Current request object + + Returns: + True if assistant model should be used, False otherwise + """ + try: + return request.use_assistant_model if request.use_assistant_model is not None else True + except AttributeError: + return True + + def get_step_guidance_message(self, request) -> str: + """ + Get step guidance message. Override for tool-specific guidance. + Default implementation uses required actions. + """ + required_actions = self.get_required_actions( + request.step_number, self.get_request_confidence(request), request.findings, request.total_steps, request + ) + + next_step_number = request.step_number + 1 + return ( + f"MANDATORY: DO NOT call the {self.get_name()} tool again immediately. " + f"You MUST first work using appropriate tools. 
" + f"REQUIRED ACTIONS before calling {self.get_name()} step {next_step_number}:\n" + + "\n".join(f"{i + 1}. {action}" for i, action in enumerate(required_actions)) + + f"\n\nOnly call {self.get_name()} again with step_number: {next_step_number} " + f"AFTER completing this work." + ) + + def _prepare_files_for_expert_analysis(self) -> str: + """ + Prepare file content for expert analysis. + + EXPERT ANALYSIS REQUIRES ACTUAL FILE CONTENT: + Expert analysis needs actual file content of all unique files marked as relevant + throughout the workflow, regardless of conversation history optimization. + + SIMPLIFIED LOGIC: + Expert analysis gets all unique files from relevant_files across the entire workflow. + This includes: + - Current step's relevant_files (consolidated_findings.relevant_files) + - Plus any additional relevant_files from conversation history (if continued workflow) + + This ensures expert analysis has complete context without including irrelevant files. + """ + all_relevant_files = set() + + # 1. Get files from current consolidated relevant_files + all_relevant_files.update(self.consolidated_findings.relevant_files) + + # 2. 
Get additional relevant_files from conversation history (if continued workflow) + try: + current_arguments = self.get_current_arguments() + if current_arguments: + continuation_id = current_arguments.get("continuation_id") + + if continuation_id: + from utils.conversation_memory import get_conversation_file_list, get_thread + + thread_context = get_thread(continuation_id) + if thread_context: + # Get all files from conversation (these were relevant_files in previous steps) + conversation_files = get_conversation_file_list(thread_context) + all_relevant_files.update(conversation_files) + logger.debug( + f"[WORKFLOW_FILES] {self.get_name()}: Added {len(conversation_files)} files from conversation history" + ) + except Exception as e: + logger.warning(f"[WORKFLOW_FILES] {self.get_name()}: Could not get conversation files: {e}") + + # Convert to list and remove any empty/None values + files_for_expert = [f for f in all_relevant_files if f and f.strip()] + + if not files_for_expert: + logger.debug(f"[WORKFLOW_FILES] {self.get_name()}: No relevant files found for expert analysis") + return "" + + # Expert analysis needs actual file content, bypassing conversation optimization + try: + file_content, processed_files = self._force_embed_files_for_expert_analysis(files_for_expert) + + logger.info( + f"[WORKFLOW_FILES] {self.get_name()}: Prepared {len(processed_files)} unique relevant files for expert analysis " + f"(from {len(self.consolidated_findings.relevant_files)} current relevant files)" + ) + + return file_content + + except Exception as e: + logger.error(f"[WORKFLOW_FILES] {self.get_name()}: Failed to prepare files for expert analysis: {e}") + return "" + + def _force_embed_files_for_expert_analysis(self, files: list[str]) -> tuple[str, list[str]]: + """ + Force embed files for expert analysis, bypassing conversation history filtering. 
+ + Expert analysis has different requirements than normal workflow steps: + - Normal steps: Optimize tokens by skipping files in conversation history + - Expert analysis: Needs actual file content regardless of conversation history + + Args: + files: List of file paths to embed + + Returns: + tuple[str, list[str]]: (file_content, processed_files) + """ + # Use read_files directly with token budgeting, bypassing filter_new_files + from utils.file_utils import expand_paths, read_files + + # Get token budget for files + current_model_context = self.get_current_model_context() + if current_model_context: + try: + token_allocation = current_model_context.calculate_token_allocation() + max_tokens = token_allocation.file_tokens + logger.debug( + f"[WORKFLOW_FILES] {self.get_name()}: Using {max_tokens:,} tokens for expert analysis files" + ) + except Exception as e: + logger.warning(f"[WORKFLOW_FILES] {self.get_name()}: Failed to get token allocation: {e}") + max_tokens = 100_000 # Fallback + else: + max_tokens = 100_000 # Fallback + + # Read files directly without conversation history filtering + logger.debug(f"[WORKFLOW_FILES] {self.get_name()}: Force embedding {len(files)} files for expert analysis") + file_content = read_files( + files, + max_tokens=max_tokens, + reserve_tokens=1000, + include_line_numbers=self.wants_line_numbers_by_default(), + ) + + # Expand paths to get individual files for tracking + processed_files = expand_paths(files) + + logger.debug( + f"[WORKFLOW_FILES] {self.get_name()}: Expert analysis embedding: {len(processed_files)} files, " + f"{len(file_content):,} characters" + ) + + return file_content, processed_files + + def wants_line_numbers_by_default(self) -> bool: + """ + Whether this tool wants line numbers in file content by default. + Override this to customize line number behavior. 
+ """ + return True # Most workflow tools benefit from line numbers for analysis + + def _add_files_to_expert_context(self, expert_context: str, file_content: str) -> str: + """ + Add file content to the expert context. + Override this to customize how files are added to the context. + """ + return f"{expert_context}\n\n=== ESSENTIAL FILES ===\n{file_content}\n=== END ESSENTIAL FILES ===" + + # ================================================================================ + # Context-Aware File Embedding - Core Implementation + # ================================================================================ + + def _handle_workflow_file_context(self, request: Any, arguments: dict[str, Any]) -> None: + """ + Handle file context appropriately based on workflow phase. + + CONTEXT-AWARE FILE EMBEDDING STRATEGY: + 1. Intermediate steps + continuation: Only reference file names (save Claude's context) + 2. Final step: Embed full file content for expert analysis + 3. Expert analysis: Always embed relevant files with token budgeting + + This prevents wasting Claude's limited context on intermediate steps while ensuring + the final expert analysis has complete file context. 
+ """ + continuation_id = self.get_request_continuation_id(request) + is_final_step = not self.get_request_next_step_required(request) + step_number = self.get_request_step_number(request) + + # Extract model context for token budgeting + model_context = arguments.get("_model_context") + self._model_context = model_context + + # Clear any previous file context to ensure clean state + self._embedded_file_content = "" + self._file_reference_note = "" + self._actually_processed_files = [] + + # Determine if we should embed files or just reference them + should_embed_files = self._should_embed_files_in_workflow_step(step_number, continuation_id, is_final_step) + + if should_embed_files: + # Final step or expert analysis - embed full file content + logger.debug(f"[WORKFLOW_FILES] {self.get_name()}: Embedding files for final step/expert analysis") + self._embed_workflow_files(request, arguments) + else: + # Intermediate step with continuation - only reference file names + logger.debug(f"[WORKFLOW_FILES] {self.get_name()}: Only referencing file names for intermediate step") + self._reference_workflow_files(request) + + def _should_embed_files_in_workflow_step( + self, step_number: int, continuation_id: Optional[str], is_final_step: bool + ) -> bool: + """ + Determine whether to embed file content based on workflow context. 
+ + CORRECT LOGIC: + - NEVER embed files when Claude is getting the next step (next_step_required=True) + - ONLY embed files when sending to external model (next_step_required=False) + + Args: + step_number: Current step number + continuation_id: Thread continuation ID (None for new conversations) + is_final_step: Whether this is the final step (next_step_required == False) + + Returns: + bool: True if files should be embedded, False if only referenced + """ + # RULE 1: Final steps (no more steps needed) - embed files for expert analysis + if is_final_step: + logger.debug("[WORKFLOW_FILES] Final step - will embed files for expert analysis") + return True + + # RULE 2: Any intermediate step (more steps needed) - NEVER embed files + # This includes: + # - New conversations with next_step_required=True + # - Steps with continuation_id and next_step_required=True + logger.debug("[WORKFLOW_FILES] Intermediate step (more work needed) - will only reference files") + return False + + def _embed_workflow_files(self, request: Any, arguments: dict[str, Any]) -> None: + """ + Embed full file content for final steps and expert analysis. + Uses proper token budgeting like existing debug.py. 
+ """ + # Use relevant_files as the standard field for workflow tools + request_files = self.get_request_relevant_files(request) + if not request_files: + logger.debug(f"[WORKFLOW_FILES] {self.get_name()}: No relevant_files to embed") + return + + try: + # Model context should be available from early validation, but might be deferred for tests + current_model_context = self.get_current_model_context() + if not current_model_context: + # Try to resolve model context now (deferred from early validation) + try: + model_name, model_context = self._resolve_model_context(arguments, request) + self._model_context = model_context + self._current_model_name = model_name + except Exception as e: + logger.error(f"[WORKFLOW_FILES] {self.get_name()}: Failed to resolve model context: {e}") + # Create fallback model context (preserves existing test behavior) + from utils.model_context import ModelContext + + model_name = self.get_request_model_name(request) + self._model_context = ModelContext(model_name) + self._current_model_name = model_name + + # Use the same file preparation logic as BaseTool with token budgeting + continuation_id = self.get_request_continuation_id(request) + remaining_tokens = arguments.get("_remaining_tokens") + + file_content, processed_files = self._prepare_file_content_for_prompt( + request_files, + continuation_id, + "Workflow files for analysis", + remaining_budget=remaining_tokens, + arguments=arguments, + model_context=self._model_context, + ) + + # Store for use in expert analysis + self._embedded_file_content = file_content + self._actually_processed_files = processed_files + + logger.info( + f"[WORKFLOW_FILES] {self.get_name()}: Embedded {len(processed_files)} relevant_files for final analysis" + ) + + except Exception as e: + logger.error(f"[WORKFLOW_FILES] {self.get_name()}: Failed to embed files: {e}") + # Continue without file embedding rather than failing + self._embedded_file_content = "" + self._actually_processed_files = [] + + def 
_reference_workflow_files(self, request: Any) -> None: + """ + Reference file names without embedding content for intermediate steps. + Saves Claude's context while still providing file awareness. + """ + # Workflow tools use relevant_files, not files + request_files = self.get_request_relevant_files(request) + logger.debug( + f"[WORKFLOW_FILES] {self.get_name()}: _reference_workflow_files called with {len(request_files)} relevant_files" + ) + + if not request_files: + logger.debug(f"[WORKFLOW_FILES] {self.get_name()}: No files to reference, skipping") + return + + # Store file references for conversation context + self._referenced_files = request_files + + # Create a simple reference note + file_names = [os.path.basename(f) for f in request_files] + reference_note = f"Files referenced in this step: {', '.join(file_names)}\n" + + self._file_reference_note = reference_note + logger.debug(f"[WORKFLOW_FILES] {self.get_name()}: Set _file_reference_note: {self._file_reference_note}") + + logger.info( + f"[WORKFLOW_FILES] {self.get_name()}: Referenced {len(request_files)} files without embedding content" + ) + + # ================================================================================ + # Main Workflow Orchestration + # ================================================================================ + + async def execute_workflow(self, arguments: dict[str, Any]) -> list[TextContent]: + """ + Main workflow orchestration following debug tool pattern. + + Comprehensive workflow implementation that handles all common patterns: + 1. Request validation and step management + 2. Continuation and backtracking support + 3. Step data processing and consolidation + 4. Tool-specific field mapping and customization + 5. Completion logic with optional expert analysis + 6. Generic "certain confidence" handling + 7. Step guidance and required actions + 8. 
Conversation memory integration + """ + from mcp.types import TextContent + + try: + # Store arguments for access by helper methods + self._current_arguments = arguments + + # Validate request using tool-specific model + request = self.get_workflow_request_model()(**arguments) + + # Validate step field size (basic validation for workflow instructions) + # If step is too large, user should use shorter instructions and put details in files + step_content = request.step + if step_content and len(step_content) > MCP_PROMPT_SIZE_LIMIT: + from tools.models import ToolOutput + + error_output = ToolOutput( + status="resend_prompt", + content="Step instructions are too long. Please use shorter instructions and provide detailed context via file paths instead.", + content_type="text", + metadata={"prompt_size": len(step_content), "limit": MCP_PROMPT_SIZE_LIMIT}, + ) + raise ValueError(f"MCP_SIZE_CHECK:{error_output.model_dump_json()}") + + # Validate file paths for security (same as base tool) + # Use try/except instead of hasattr as per coding standards + try: + path_error = self.validate_file_paths(request) + if path_error: + from tools.models import ToolOutput + + error_output = ToolOutput( + status="error", + content=path_error, + content_type="text", + ) + return [TextContent(type="text", text=error_output.model_dump_json())] + except AttributeError: + # validate_file_paths method not available - skip validation + pass + + # Try to validate model availability early for production scenarios + # For tests, defer model validation to later to allow mocks to work + try: + model_name, model_context = self._resolve_model_context(arguments, request) + # Store for later use + self._current_model_name = model_name + self._model_context = model_context + except ValueError as e: + # Model resolution failed - in production this would be an error, + # but for tests we defer to allow mocks to handle model resolution + logger.debug(f"Early model validation failed, deferring to later: 
{e}") + self._current_model_name = None + self._model_context = None + + # Handle continuation + continuation_id = request.continuation_id + + # Restore workflow state on continuation + if continuation_id: + from utils.conversation_memory import get_thread + + thread = get_thread(continuation_id) + if thread and thread.turns: + # Find the most recent assistant turn from this tool with workflow state + for turn in reversed(thread.turns): + if turn.role == "assistant" and turn.tool_name == self.get_name() and turn.model_metadata: + state = turn.model_metadata + if isinstance(state, dict) and "work_history" in state: + self.work_history = state.get("work_history", []) + self.initial_request = state.get("initial_request") + # Rebuild consolidated findings from restored history + self._reprocess_consolidated_findings() + logger.debug( + f"[{self.get_name()}] Restored workflow state with {len(self.work_history)} history items" + ) + break # State restored, exit loop + + # Adjust total steps if needed + if request.step_number > request.total_steps: + request.total_steps = request.step_number + + # Create thread for first step + if not continuation_id and request.step_number == 1: + clean_args = {k: v for k, v in arguments.items() if k not in ["_model_context", "_resolved_model_name"]} + continuation_id = create_thread(self.get_name(), clean_args) + self.initial_request = request.step + # Allow tools to store initial description for expert analysis + self.store_initial_issue(request.step) + + # Handle backtracking if requested + backtrack_step = self.get_backtrack_step(request) + if backtrack_step: + self._handle_backtracking(backtrack_step) + + # Process work step - allow tools to customize field mapping + step_data = self.prepare_step_data(request) + + # Store in history + self.work_history.append(step_data) + + # Update consolidated findings + self._update_consolidated_findings(step_data) + + # Handle file context appropriately based on workflow phase + 
self._handle_workflow_file_context(request, arguments) + + # Build response with tool-specific customization + response_data = self.build_base_response(request, continuation_id) + + # If work is complete, handle completion logic + if not request.next_step_required: + response_data = await self.handle_work_completion(response_data, request, arguments) + else: + # Force Claude to work before calling tool again + response_data = self.handle_work_continuation(response_data, request) + + # Allow tools to customize the final response + response_data = self.customize_workflow_response(response_data, request) + + # Add metadata (provider_used and model_used) to workflow response + self._add_workflow_metadata(response_data, arguments) + + # Store in conversation memory + if continuation_id: + self.store_conversation_turn(continuation_id, response_data, request) + + return [TextContent(type="text", text=json.dumps(response_data, indent=2, ensure_ascii=False))] + + except Exception as e: + logger.error(f"Error in {self.get_name()} work: {e}", exc_info=True) + error_data = { + "status": f"{self.get_name()}_failed", + "error": str(e), + "step_number": arguments.get("step_number", 0), + } + + # Add metadata to error responses too + self._add_workflow_metadata(error_data, arguments) + + return [TextContent(type="text", text=json.dumps(error_data, indent=2, ensure_ascii=False))] + + # Hook methods for tool customization + + def prepare_step_data(self, request) -> dict: + """ + Prepare step data from request. Tools can override to customize field mapping. 
+ """ + step_data = { + "step": request.step, + "step_number": request.step_number, + "findings": request.findings, + "files_checked": self.get_request_files_checked(request), + "relevant_files": self.get_request_relevant_files(request), + "relevant_context": self.get_request_relevant_context(request), + "issues_found": self.get_request_issues_found(request), + "confidence": self.get_request_confidence(request), + "hypothesis": self.get_request_hypothesis(request), + "images": self.get_request_images(request), + } + return step_data + + def build_base_response(self, request, continuation_id: str = None) -> dict: + """ + Build the base response structure. Tools can override for custom response fields. + """ + response_data = { + "status": f"{self.get_name()}_in_progress", + "step_number": request.step_number, + "total_steps": request.total_steps, + "next_step_required": request.next_step_required, + f"{self.get_name()}_status": { + "files_checked": len(self.consolidated_findings.files_checked), + "relevant_files": len(self.consolidated_findings.relevant_files), + "relevant_context": len(self.consolidated_findings.relevant_context), + "issues_found": len(self.consolidated_findings.issues_found), + "images_collected": len(self.consolidated_findings.images), + "current_confidence": self.get_request_confidence(request), + }, + } + + if continuation_id: + response_data["continuation_id"] = continuation_id + + # Add file context information based on workflow phase + embedded_content = self.get_embedded_file_content() + reference_note = self.get_file_reference_note() + processed_files = self.get_actually_processed_files() + + logger.debug( + f"[WORKFLOW_FILES] {self.get_name()}: Building response - has embedded_content: {bool(embedded_content)}, has reference_note: {bool(reference_note)}" + ) + + # Prioritize embedded content over references for final steps + if embedded_content: + # Final step - include embedded file information + logger.debug(f"[WORKFLOW_FILES] 
{self.get_name()}: Adding fully_embedded file context") + response_data["file_context"] = { + "type": "fully_embedded", + "files_embedded": len(processed_files), + "context_optimization": "Full file content embedded for expert analysis", + } + elif reference_note: + # Intermediate step - include file reference note + logger.debug(f"[WORKFLOW_FILES] {self.get_name()}: Adding reference_only file context") + response_data["file_context"] = { + "type": "reference_only", + "note": reference_note, + "context_optimization": "Files referenced but not embedded to preserve Claude's context window", + } + + return response_data + + def should_skip_expert_analysis(self, request, consolidated_findings) -> bool: + """ + Determine if expert analysis should be skipped due to high certainty. + + Default: False (always call expert analysis) + Override in tools like debug to check for "certain" confidence. + """ + return False + + def handle_completion_without_expert_analysis(self, request, consolidated_findings) -> dict: + """ + Handle completion when skipping expert analysis. + + Tools can override this for custom high-confidence completion handling. + Default implementation provides generic response. 
+ """ + work_summary = self.prepare_work_summary() + continuation_id = self.get_request_continuation_id(request) + + response_data = { + "status": self.get_completion_status(), + f"complete_{self.get_name()}": { + "initial_request": self.get_initial_request(request.step), + "steps_taken": len(consolidated_findings.findings), + "files_examined": list(consolidated_findings.files_checked), + "relevant_files": list(consolidated_findings.relevant_files), + "relevant_context": list(consolidated_findings.relevant_context), + "work_summary": work_summary, + "final_analysis": self.get_final_analysis_from_request(request), + "confidence_level": self.get_confidence_level(request), + }, + "next_steps": self.get_completion_message(), + "skip_expert_analysis": True, + "expert_analysis": { + "status": self.get_skip_expert_analysis_status(), + "reason": self.get_skip_reason(), + }, + } + + if continuation_id: + response_data["continuation_id"] = continuation_id + + return response_data + + # ================================================================================ + # Inheritance Hook Methods - Replace hasattr/getattr Anti-patterns + # ================================================================================ + + def get_request_confidence(self, request: Any) -> str: + """Get confidence from request. Override for custom confidence handling.""" + try: + return request.confidence or "low" + except AttributeError: + return "low" + + def get_request_relevant_context(self, request: Any) -> list[str]: + """Get relevant context from request. Override for custom field mapping.""" + try: + return request.relevant_context or [] + except AttributeError: + return [] + + def get_request_issues_found(self, request: Any) -> list[str]: + """Get issues found from request. 
Override for custom field mapping.""" + try: + return request.issues_found or [] + except AttributeError: + return [] + + def get_request_hypothesis(self, request: Any) -> Optional[str]: + """Get hypothesis from request. Override for custom field mapping.""" + try: + return request.hypothesis + except AttributeError: + return None + + def get_request_images(self, request: Any) -> list[str]: + """Get images from request. Override for custom field mapping.""" + try: + return request.images or [] + except AttributeError: + return [] + + # File Context Access Methods + + def get_embedded_file_content(self) -> str: + """Get embedded file content. Returns empty string if not available.""" + try: + return self._embedded_file_content or "" + except AttributeError: + return "" + + def get_file_reference_note(self) -> str: + """Get file reference note. Returns empty string if not available.""" + try: + return self._file_reference_note or "" + except AttributeError: + return "" + + def get_actually_processed_files(self) -> list[str]: + """Get list of actually processed files. Returns empty list if not available.""" + try: + return self._actually_processed_files or [] + except AttributeError: + return [] + + def get_current_model_context(self): + """Get current model context. Returns None if not available.""" + try: + return self._model_context + except AttributeError: + return None + + def get_request_model_name(self, request: Any) -> str: + """Get model name from request. Override for custom model handling.""" + try: + return request.model or "flash" + except AttributeError: + return "flash" + + def get_request_continuation_id(self, request: Any) -> Optional[str]: + """Get continuation ID from request. Override for custom continuation handling.""" + try: + return request.continuation_id + except AttributeError: + return None + + def get_request_next_step_required(self, request: Any) -> bool: + """Get next step required from request. 
Override for custom step handling.""" + try: + return request.next_step_required + except AttributeError: + return True + + def get_request_step_number(self, request: Any) -> int: + """Get step number from request. Override for custom step handling.""" + try: + return request.step_number or 1 + except AttributeError: + return 1 + + def get_request_relevant_files(self, request: Any) -> list[str]: + """Get relevant files from request. Override for custom file handling.""" + try: + return request.relevant_files or [] + except AttributeError: + return [] + + def get_request_files_checked(self, request: Any) -> list[str]: + """Get files checked from request. Override for custom file handling.""" + try: + return request.files_checked or [] + except AttributeError: + return [] + + def get_current_arguments(self) -> dict[str, Any]: + """Get current arguments. Returns empty dict if not available.""" + try: + return self._current_arguments or {} + except AttributeError: + return {} + + def get_backtrack_step(self, request) -> Optional[int]: + """Get backtrack step from request. Override for custom backtrack handling.""" + try: + return request.backtrack_from_step + except AttributeError: + return None + + def store_initial_issue(self, step_description: str): + """Store initial issue description. Override for custom storage.""" + # Default implementation - tools can override to store differently + self.initial_issue = step_description + + def get_initial_request(self, fallback_step: str) -> str: + """Get initial request description. Override for custom retrieval.""" + try: + return self.initial_request or fallback_step + except AttributeError: + return fallback_step + + # Default implementations for inheritance hooks + + def prepare_work_summary(self) -> str: + """Prepare work summary. Override for custom implementation.""" + return f"Completed {len(self.consolidated_findings.findings)} work steps" + + def get_completion_status(self) -> str: + """Get completion status. 
Override for tool-specific status.""" + return "high_confidence_completion" + + def get_final_analysis_from_request(self, request): + """Extract final analysis from request. Override for tool-specific fields.""" + return self.get_request_hypothesis(request) + + def get_confidence_level(self, request) -> str: + """Get confidence level. Override for tool-specific confidence handling.""" + return self.get_request_confidence(request) or "high" + + def get_completion_message(self) -> str: + """Get completion message. Override for tool-specific messaging.""" + return ( + f"{self.get_name().capitalize()} complete with high confidence. Present results " + "and proceed with implementation without requiring further consultation." + ) + + def get_skip_reason(self) -> str: + """Get reason for skipping expert analysis. Override for tool-specific reasons.""" + return f"{self.get_name()} completed with sufficient confidence" + + def get_skip_expert_analysis_status(self) -> str: + """Get status for skipped expert analysis. Override for tool-specific status.""" + return "skipped_by_tool_design" + + def get_completion_next_steps_message(self, expert_analysis_used: bool = False) -> str: + """ + Get the message to show when work is complete. + Tools can override for custom messaging. + + Args: + expert_analysis_used: True if expert analysis was successfully executed + """ + base_message = ( + f"{self.get_name().upper()} IS COMPLETE. You MUST now summarize and present ALL key findings, confirmed " + "hypotheses, and exact recommended solutions. Clearly identify the most likely root cause and " + "provide concrete, actionable implementation guidance. Highlight affected code paths and display " + "reasoning that led to this conclusion—make it easy for a developer to understand exactly where " + "the problem lies." 
+ ) + + # Add expert analysis guidance only when expert analysis was actually used + if expert_analysis_used: + expert_guidance = self.get_expert_analysis_guidance() + if expert_guidance: + return f"{base_message}\n\n{expert_guidance}" + + return base_message + + def get_expert_analysis_guidance(self) -> str: + """ + Get additional guidance for handling expert analysis results. + + Subclasses can override this to provide specific instructions about how + to validate and use expert analysis findings. Returns empty string by default. + + When expert analysis is called, this guidance will be: + 1. Appended to the completion next steps message + 2. Added as "important_considerations" field in the response data + + Example implementation: + ```python + def get_expert_analysis_guidance(self) -> str: + return ( + "IMPORTANT: Expert analysis provided above. You MUST validate " + "the expert findings rather than accepting them blindly. " + "Cross-reference with your own investigation and ensure " + "recommendations align with the codebase context." + ) + ``` + + Returns: + Additional guidance text or empty string if no guidance needed + """ + return "" + + def customize_workflow_response(self, response_data: dict, request) -> dict: + """ + Allow tools to customize the workflow response before returning. + + Tools can override this to add tool-specific fields, modify status names, + customize field mapping, etc. Default implementation returns unchanged. 
+ """ + # Ensure file context information is preserved in all response paths + if not response_data.get("file_context"): + embedded_content = self.get_embedded_file_content() + reference_note = self.get_file_reference_note() + processed_files = self.get_actually_processed_files() + + # Prioritize embedded content over references for final steps + if embedded_content: + response_data["file_context"] = { + "type": "fully_embedded", + "files_embedded": len(processed_files), + "context_optimization": "Full file content embedded for expert analysis", + } + elif reference_note: + response_data["file_context"] = { + "type": "reference_only", + "note": reference_note, + "context_optimization": "Files referenced but not embedded to preserve Claude's context window", + } + + return response_data + + def store_conversation_turn(self, continuation_id: str, response_data: dict, request): + """ + Store the conversation turn. Tools can override for custom memory storage. + """ + # CRITICAL: Extract clean content for conversation history (exclude internal workflow metadata) + clean_content = self._extract_clean_workflow_content_for_history(response_data) + + # Serialize workflow state for persistence across stateless tool calls + workflow_state = {"work_history": self.work_history, "initial_request": getattr(self, "initial_request", None)} + + add_turn( + thread_id=continuation_id, + role="assistant", + content=clean_content, # Use cleaned content instead of full response_data + tool_name=self.get_name(), + files=self.get_request_relevant_files(request), + images=self.get_request_images(request), + model_metadata=workflow_state, # Persist the state + ) + + def _add_workflow_metadata(self, response_data: dict, arguments: dict[str, Any]) -> None: + """ + Add metadata (provider_used and model_used) to workflow response. 
+ + This ensures workflow tools have the same metadata as regular tools, + making it consistent across all tool types for tracking which provider + and model were used for the response. + + Args: + response_data: The response data dictionary to modify + arguments: The original arguments containing model context + """ + try: + # Get model information from arguments (set by server.py) + resolved_model_name = arguments.get("_resolved_model_name") + model_context = arguments.get("_model_context") + + if resolved_model_name and model_context: + # Extract provider information from model context + provider = model_context.provider + provider_name = provider.get_provider_type().value if provider else "unknown" + + # Create metadata dictionary + metadata = { + "tool_name": self.get_name(), + "model_used": resolved_model_name, + "provider_used": provider_name, + } + + # Preserve existing metadata and add workflow metadata + if "metadata" not in response_data: + response_data["metadata"] = {} + response_data["metadata"].update(metadata) + + logger.debug( + f"[WORKFLOW_METADATA] {self.get_name()}: Added metadata - " + f"model: {resolved_model_name}, provider: {provider_name}" + ) + else: + # Fallback - try to get model info from request + request = self.get_workflow_request_model()(**arguments) + model_name = self.get_request_model_name(request) + + # Basic metadata without provider info + metadata = { + "tool_name": self.get_name(), + "model_used": model_name, + "provider_used": "unknown", + } + + # Preserve existing metadata and add workflow metadata + if "metadata" not in response_data: + response_data["metadata"] = {} + response_data["metadata"].update(metadata) + + logger.debug( + f"[WORKFLOW_METADATA] {self.get_name()}: Added fallback metadata - " + f"model: {model_name}, provider: unknown" + ) + + except Exception as e: + # Don't fail the workflow if metadata addition fails + logger.warning(f"[WORKFLOW_METADATA] {self.get_name()}: Failed to add metadata: {e}") + # Still 
add basic metadata with tool name + response_data["metadata"] = {"tool_name": self.get_name()} + + def _extract_clean_workflow_content_for_history(self, response_data: dict) -> str: + """ + Extract clean content from workflow response suitable for conversation history. + + This method removes internal workflow metadata, continuation offers, and + status information that should not appear when the conversation is + reconstructed for expert models or other tools. + + Args: + response_data: The full workflow response data + + Returns: + str: Clean content suitable for conversation history storage + """ + # Create a clean copy with only essential content for conversation history + clean_data = {} + + # Include core content if present + if "content" in response_data: + clean_data["content"] = response_data["content"] + + # Include expert analysis if present (but clean it) + if "expert_analysis" in response_data: + expert_analysis = response_data["expert_analysis"] + if isinstance(expert_analysis, dict): + # Only include the actual analysis content, not metadata + clean_expert = {} + if "raw_analysis" in expert_analysis: + clean_expert["analysis"] = expert_analysis["raw_analysis"] + elif "content" in expert_analysis: + clean_expert["analysis"] = expert_analysis["content"] + if clean_expert: + clean_data["expert_analysis"] = clean_expert + + # Include findings/issues if present (core workflow output) + if "complete_analysis" in response_data: + complete_analysis = response_data["complete_analysis"] + if isinstance(complete_analysis, dict): + clean_complete = {} + # Include essential analysis data without internal metadata + for key in ["findings", "issues_found", "relevant_context", "insights"]: + if key in complete_analysis: + clean_complete[key] = complete_analysis[key] + if clean_complete: + clean_data["analysis_summary"] = clean_complete + + # Include step information for context but remove internal workflow metadata + if "step_number" in response_data: + 
clean_data["step_info"] = { + "step": response_data.get("step", ""), + "step_number": response_data.get("step_number", 1), + "total_steps": response_data.get("total_steps", 1), + } + + # Exclude problematic fields that should never appear in conversation history: + # - continuation_id (confuses LLMs with old IDs) + # - status (internal workflow state) + # - next_step_required (internal control flow) + # - analysis_status (internal tracking) + # - file_context (internal optimization info) + # - required_actions (internal workflow instructions) + + return json.dumps(clean_data, indent=2, ensure_ascii=False) + + # Core workflow logic methods + + async def handle_work_completion(self, response_data: dict, request, arguments: dict) -> dict: + """ + Handle work completion logic - expert analysis decision and response building. + """ + response_data[f"{self.get_name()}_complete"] = True + + # Check if tool wants to skip expert analysis due to high certainty + if self.should_skip_expert_analysis(request, self.consolidated_findings): + # Handle completion without expert analysis + completion_response = self.handle_completion_without_expert_analysis(request, self.consolidated_findings) + response_data.update(completion_response) + elif self.requires_expert_analysis() and self.should_call_expert_analysis(self.consolidated_findings, request): + # Standard expert analysis path + response_data["status"] = "calling_expert_analysis" + + # Call expert analysis + expert_analysis = await self._call_expert_analysis(arguments, request) + response_data["expert_analysis"] = expert_analysis + + # Handle special expert analysis statuses + if isinstance(expert_analysis, dict) and expert_analysis.get("status") in [ + "files_required_to_continue", + "investigation_paused", + "refactoring_paused", + ]: + # Promote the special status to the main response + special_status = expert_analysis["status"] + response_data["status"] = special_status + response_data["content"] = expert_analysis.get( + 
"raw_analysis", json.dumps(expert_analysis, ensure_ascii=False) + ) + del response_data["expert_analysis"] + + # Update next steps for special status + if special_status == "files_required_to_continue": + response_data["next_steps"] = "Provide the requested files and continue the analysis." + else: + response_data["next_steps"] = expert_analysis.get( + "next_steps", "Continue based on expert analysis." + ) + elif isinstance(expert_analysis, dict) and expert_analysis.get("status") == "analysis_error": + # Expert analysis failed - promote error status + response_data["status"] = "error" + response_data["content"] = expert_analysis.get("error", "Expert analysis failed") + response_data["content_type"] = "text" + del response_data["expert_analysis"] + else: + # Expert analysis was successfully executed - include expert guidance + response_data["next_steps"] = self.get_completion_next_steps_message(expert_analysis_used=True) + + # Add expert analysis guidance as important considerations + expert_guidance = self.get_expert_analysis_guidance() + if expert_guidance: + response_data["important_considerations"] = expert_guidance + + # Prepare complete work summary + work_summary = self._prepare_work_summary() + response_data[f"complete_{self.get_name()}"] = { + "initial_request": self.get_initial_request(request.step), + "steps_taken": len(self.work_history), + "files_examined": list(self.consolidated_findings.files_checked), + "relevant_files": list(self.consolidated_findings.relevant_files), + "relevant_context": list(self.consolidated_findings.relevant_context), + "issues_found": self.consolidated_findings.issues_found, + "work_summary": work_summary, + } + else: + # Tool doesn't require expert analysis or local work was sufficient + if not self.requires_expert_analysis(): + # Tool is self-contained (like planner) + response_data["status"] = f"{self.get_name()}_complete" + response_data["next_steps"] = ( + f"{self.get_name().capitalize()} work complete. 
Present results to the user." + ) + else: + # Local work was sufficient for tools that support expert analysis + response_data["status"] = "local_work_complete" + response_data["next_steps"] = ( + f"Local {self.get_name()} complete with sufficient confidence. Present findings " + "and recommendations to the user based on the work results." + ) + + return response_data + + def handle_work_continuation(self, response_data: dict, request) -> dict: + """ + Handle work continuation - force pause and provide guidance. + """ + response_data["status"] = f"pause_for_{self.get_name()}" + response_data[f"{self.get_name()}_required"] = True + + # Get tool-specific required actions + required_actions = self.get_required_actions( + request.step_number, self.get_request_confidence(request), request.findings, request.total_steps, request + ) + response_data["required_actions"] = required_actions + + # Generate step guidance + response_data["next_steps"] = self.get_step_guidance_message(request) + + return response_data + + def _handle_backtracking(self, backtrack_step: int): + """Handle backtracking to a previous step""" + # Remove findings after the backtrack point + self.work_history = [s for s in self.work_history if s["step_number"] < backtrack_step] + # Reprocess consolidated findings + self._reprocess_consolidated_findings() + + def _update_consolidated_findings(self, step_data: dict): + """Update consolidated findings with new step data""" + self.consolidated_findings.files_checked.update(step_data.get("files_checked", [])) + self.consolidated_findings.relevant_files.update(step_data.get("relevant_files", [])) + self.consolidated_findings.relevant_context.update(step_data.get("relevant_context", [])) + self.consolidated_findings.findings.append(f"Step {step_data['step_number']}: {step_data['findings']}") + if step_data.get("hypothesis"): + self.consolidated_findings.hypotheses.append( + { + "step": step_data["step_number"], + "hypothesis": step_data["hypothesis"], + 
"confidence": step_data["confidence"], + } + ) + if step_data.get("issues_found"): + self.consolidated_findings.issues_found.extend(step_data["issues_found"]) + if step_data.get("images"): + self.consolidated_findings.images.extend(step_data["images"]) + # Update confidence to latest value from this step + if step_data.get("confidence"): + self.consolidated_findings.confidence = step_data["confidence"] + + def _reprocess_consolidated_findings(self): + """Reprocess consolidated findings after backtracking""" + self.consolidated_findings = ConsolidatedFindings() + for step in self.work_history: + self._update_consolidated_findings(step) + + def _prepare_work_summary(self) -> str: + """Prepare a comprehensive summary of the work""" + summary_parts = [ + f"=== {self.get_name().upper()} WORK SUMMARY ===", + f"Total steps: {len(self.work_history)}", + f"Files examined: {len(self.consolidated_findings.files_checked)}", + f"Relevant files identified: {len(self.consolidated_findings.relevant_files)}", + f"Methods/functions involved: {len(self.consolidated_findings.relevant_context)}", + f"Issues found: {len(self.consolidated_findings.issues_found)}", + "", + "=== WORK PROGRESSION ===", + ] + + for finding in self.consolidated_findings.findings: + summary_parts.append(finding) + + if self.consolidated_findings.hypotheses: + summary_parts.extend( + [ + "", + "=== HYPOTHESIS EVOLUTION ===", + ] + ) + for hyp in self.consolidated_findings.hypotheses: + summary_parts.append(f"Step {hyp['step']} ({hyp['confidence']} confidence): {hyp['hypothesis']}") + + if self.consolidated_findings.issues_found: + summary_parts.extend( + [ + "", + "=== ISSUES IDENTIFIED ===", + ] + ) + for issue in self.consolidated_findings.issues_found: + severity = issue.get("severity", "unknown") + description = issue.get("description", "No description") + summary_parts.append(f"[{severity.upper()}] {description}") + + return "\n".join(summary_parts) + + async def _call_expert_analysis(self, arguments: 
dict, request) -> dict: + """Call external model for expert analysis""" + try: + # Model context should be resolved from early validation, but handle fallback for tests + if not self._model_context: + # Try to resolve model context for expert analysis (deferred from early validation) + try: + model_name, model_context = self._resolve_model_context(arguments, request) + self._model_context = model_context + self._current_model_name = model_name + except Exception as e: + logger.error(f"Failed to resolve model context for expert analysis: {e}") + # Use request model as fallback (preserves existing test behavior) + model_name = self.get_request_model_name(request) + from utils.model_context import ModelContext + + model_context = ModelContext(model_name) + self._model_context = model_context + self._current_model_name = model_name + else: + model_name = self._current_model_name + + provider = self._model_context.provider + + # Prepare expert analysis context + expert_context = self.prepare_expert_analysis_context(self.consolidated_findings) + + # Check if tool wants to include files in prompt + if self.should_include_files_in_expert_prompt(): + file_content = self._prepare_files_for_expert_analysis() + if file_content: + expert_context = self._add_files_to_expert_context(expert_context, file_content) + + # Get system prompt for this tool with localization support + base_system_prompt = self.get_system_prompt() + language_instruction = self.get_language_instruction() + system_prompt = language_instruction + base_system_prompt + + # Check if tool wants system prompt embedded in main prompt + if self.should_embed_system_prompt(): + prompt = f"{system_prompt}\n\n{expert_context}\n\n{self.get_expert_analysis_instruction()}" + system_prompt = "" # Clear it since we embedded it + else: + prompt = expert_context + + # Validate temperature against model constraints + validated_temperature, temp_warnings = self.get_validated_temperature(request, self._model_context) + + # Log 
any temperature corrections + for warning in temp_warnings: + logger.warning(warning) + + # Generate AI response - use request parameters if available + model_response = provider.generate_content( + prompt=prompt, + model_name=model_name, + system_prompt=system_prompt, + temperature=validated_temperature, + thinking_mode=self.get_request_thinking_mode(request), + images=list(set(self.consolidated_findings.images)) if self.consolidated_findings.images else None, + ) + + if model_response.content: + content = model_response.content.strip() + + # Try to extract JSON from markdown code blocks if present + if "```json" in content or "```" in content: + json_match = re.search(r"```(?:json)?\s*(.*?)\s*```", content, re.DOTALL) + if json_match: + content = json_match.group(1).strip() + + try: + # Try to parse as JSON + analysis_result = json.loads(content) + return analysis_result + except json.JSONDecodeError as e: + # Log the parse error with more details but don't fail + logger.info( + f"[{self.get_name()}] Expert analysis returned non-JSON response (this is OK for smaller models). " + f"Parse error: {str(e)}. Response length: {len(model_response.content)} chars." + ) + logger.debug(f"First 500 chars of response: {model_response.content[:500]!r}") + + # Still return the analysis as plain text - this is valid + return { + "status": "analysis_complete", + "raw_analysis": model_response.content, + "format": "text", # Indicate it's plain text, not an error + "note": "Analysis provided in plain text format", + } + else: + return {"error": "No response from model", "status": "empty_response"} + + except Exception as e: + logger.error(f"Error calling expert analysis: {e}", exc_info=True) + return {"error": str(e), "status": "analysis_error"} + + def _process_work_step(self, step_data: dict): + """ + Process a single work step and update internal state. + + This method is useful for testing and manual step processing. 
+ It adds the step to work history and updates consolidated findings. + + Args: + step_data: Dictionary containing step information including: + step, step_number, findings, files_checked, etc. + """ + # Store in history + self.work_history.append(step_data) + + # Update consolidated findings + self._update_consolidated_findings(step_data) + + # Common execute method for workflow-based tools + + async def execute(self, arguments: dict[str, Any]) -> list[TextContent]: + """ + Common execute logic for workflow-based tools. + + This method provides common validation and delegates to execute_workflow. + Tools that need custom execute logic can override this method. + """ + try: + # Common validation + if not arguments: + error_data = {"status": "error", "content": "No arguments provided"} + # Add basic metadata even for validation errors + error_data["metadata"] = {"tool_name": self.get_name()} + return [TextContent(type="text", text=json.dumps(error_data, ensure_ascii=False))] + + # Delegate to execute_workflow + return await self.execute_workflow(arguments) + + except Exception as e: + logger.error(f"Error in {self.get_name()} tool execution: {e}", exc_info=True) + error_data = { + "status": "error", + "content": f"Error in {self.get_name()}: {str(e)}", + } # Add metadata to error responses + self._add_workflow_metadata(error_data, arguments) + return [ + TextContent( + type="text", + text=json.dumps(error_data, ensure_ascii=False), + ) + ] + + # Default implementations for methods that workflow-based tools typically don't need + + async def prepare_prompt(self, request) -> str: + """ + Base implementation for workflow tools - compatible with BaseTool signature. + + Workflow tools typically don't need to return a prompt since they handle + their own prompt preparation internally through the workflow execution. 
+ + Args: + request: The validated request object + + Returns: + Empty string since workflow tools manage prompts internally + """ + # Workflow tools handle their prompts internally during workflow execution + return "" + + def format_response(self, response: str, request, model_info=None): + """ + Workflow tools handle their own response formatting. + The BaseWorkflowMixin formats responses internally. + """ + return response diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..8024036 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,21 @@ +""" +Utility functions for Zen MCP Server +""" + +from .file_types import CODE_EXTENSIONS, FILE_CATEGORIES, PROGRAMMING_EXTENSIONS, TEXT_EXTENSIONS +from .file_utils import expand_paths, read_file_content, read_files +from .security_config import EXCLUDED_DIRS +from .token_utils import check_token_limit, estimate_tokens + +__all__ = [ + "read_files", + "read_file_content", + "expand_paths", + "CODE_EXTENSIONS", + "PROGRAMMING_EXTENSIONS", + "TEXT_EXTENSIONS", + "FILE_CATEGORIES", + "EXCLUDED_DIRS", + "estimate_tokens", + "check_token_limit", +] diff --git a/utils/client_info.py b/utils/client_info.py new file mode 100644 index 0000000..e32b7f3 --- /dev/null +++ b/utils/client_info.py @@ -0,0 +1,293 @@ +""" +Client Information Utility for MCP Server + +This module provides utilities to extract and format client information +from the MCP protocol's clientInfo sent during initialization. + +It also provides friendly name mapping and caching for consistent client +identification across the application. 
import logging
from typing import Any, Optional

logger = logging.getLogger(__name__)

# Process-wide cache: clientInfo is sent once at MCP initialization, so the
# first successful extraction is reused for the lifetime of the server.
_client_info_cache: Optional[dict[str, Any]] = None

# Known client names -> friendly display names.
# Matching is case-insensitive and substring-based (key contained in name),
# checked in insertion order, so more specific keys should come first.
CLIENT_NAME_MAPPINGS = {
    # Claude variants
    "claude-ai": "Claude",
    "claude": "Claude",
    "claude-desktop": "Claude",
    "claude-code": "Claude",
    "anthropic": "Claude",
    # Gemini variants
    "gemini-cli-mcp-client": "Gemini",
    "gemini-cli": "Gemini",
    "gemini": "Gemini",
    "google": "Gemini",
    # Other known clients
    "cursor": "Cursor",
    "vscode": "VS Code",
    "codeium": "Codeium",
    "copilot": "GitHub Copilot",
    # Generic MCP clients
    "mcp-client": "MCP Client",
    "test-client": "Test Client",
}

# Fallback friendly name when the client is unknown or unnamed.
DEFAULT_FRIENDLY_NAME = "Claude"


def get_friendly_name(client_name: str) -> str:
    """
    Map a raw clientInfo name to a friendly display name.

    Args:
        client_name: The raw client name from clientInfo

    Returns:
        A friendly name for display (e.g., "Claude", "Gemini"); falls back
        to DEFAULT_FRIENDLY_NAME when the name is empty or unrecognized.
    """
    if not client_name:
        return DEFAULT_FRIENDLY_NAME

    haystack = client_name.lower()

    # First mapping entry whose key appears inside the client name wins
    # (dicts preserve insertion order, so specific keys are tried first).
    return next(
        (friendly for key, friendly in CLIENT_NAME_MAPPINGS.items() if key.lower() in haystack),
        DEFAULT_FRIENDLY_NAME,
    )


def get_cached_client_info() -> Optional[dict[str, Any]]:
    """
    Return the cached client info dictionary, or None if nothing was cached yet.
    """
    return _client_info_cache


def get_client_info_from_context(server: Any) -> Optional[dict[str, Any]]:
    """
    Extract client information from the MCP server's request context.

    The MCP protocol sends clientInfo during initialization containing:
    - name: The client application name (e.g., "Claude Code", "Claude Desktop")
    - version: The client version string

    Adds a friendly_name field and caches the result for subsequent calls.

    Args:
        server: The MCP server instance

    Returns:
        Dictionary with client info or None if not available:
        {"name": "claude-ai", "version": "1.0.0", "friendly_name": "Claude"}
    """
    global _client_info_cache

    # Serve from cache when a previous extraction succeeded.
    if _client_info_cache is not None:
        return _client_info_cache

    try:
        if not server:
            return None

        # Walk server.request_context.session._client_params.clientInfo one
        # attribute at a time, logging exactly why we bailed out at each step.
        # NOTE(review): _client_params is a private MCP session attribute;
        # the AttributeError guards keep this resilient to SDK changes.
        walk = [
            ("request_context", "Server does not have request_context property", "Request context is None"),
            ("session", "Request context does not have session property", "Session is None"),
            ("_client_params", "Session does not have _client_params property", "Client params is None"),
            ("clientInfo", "Client params does not have clientInfo property", "Client info is None"),
        ]
        node = server
        for attr, missing_msg, empty_msg in walk:
            try:
                node = getattr(node, attr)
            except AttributeError:
                logger.debug(missing_msg)
                return None
            if not node:
                logger.debug(empty_msg)
                return None

        # Pull name/version individually so a partial clientInfo still helps.
        result: dict[str, Any] = {}
        for field in ("name", "version"):
            try:
                result[field] = getattr(node, field)
            except AttributeError:
                logger.debug(f"Client info does not have {field} property")

        if not result:
            return None

        # Derive the display name from whatever raw name we managed to read.
        result["friendly_name"] = get_friendly_name(result.get("name", ""))

        _client_info_cache = result
        logger.debug(f"Cached client info: {result}")

        return result

    except Exception as e:
        # Deliberately broad: client identification is best-effort and must
        # never break a request.
        logger.debug(f"Error extracting client info: {e}")
        return None


def format_client_info(client_info: Optional[dict[str, Any]], use_friendly_name: bool = True) -> str:
    """
    Format client information for display.

    Args:
        client_info: Dictionary with client info or None
        use_friendly_name: If True, use the friendly name instead of raw name

    Returns:
        "<raw name> v<version>" when use_friendly_name is False and a version
        is present; otherwise just the (friendly or raw) name.
    """
    if not client_info:
        return DEFAULT_FRIENDLY_NAME

    if use_friendly_name:
        # Friendly display intentionally omits the version string.
        return client_info.get("friendly_name", client_info.get("name", DEFAULT_FRIENDLY_NAME))

    name = client_info.get("name", "Unknown")
    version = client_info.get("version", "")
    return f"{name} v{version}" if version else name


def get_client_friendly_name() -> str:
    """
    Return the cached client's friendly name, defaulting to "Claude"
    when no client info has been cached yet.
    """
    info = get_cached_client_info()
    return info.get("friendly_name", DEFAULT_FRIENDLY_NAME) if info else DEFAULT_FRIENDLY_NAME


def log_client_info(server: Any, logger_instance: Optional[logging.Logger] = None) -> None:
    """
    Log client information extracted from the server.

    Args:
        server: The MCP server instance
        logger_instance: Optional logger to use (defaults to module logger)
    """
    log = logger_instance or logger

    client_info = get_client_info_from_context(server)
    if not client_info:
        log.debug("Could not extract client info from MCP protocol")
        return

    # Log both raw and friendly names so mismatches are easy to debug.
    raw_name = client_info.get("name", "Unknown")
    friendly_name = client_info.get("friendly_name", DEFAULT_FRIENDLY_NAME)
    version = client_info.get("version", "")

    if raw_name != friendly_name:
        log.info(f"MCP Client Connected: {friendly_name} (raw: {raw_name} v{version})")
    else:
        log.info(f"MCP Client Connected: {friendly_name} v{version}")

    # Mirror the event to the activity log; failures here are non-fatal.
    try:
        activity_logger = logging.getLogger("mcp_activity")
        activity_logger.info(f"CLIENT_IDENTIFIED: {friendly_name} (name={raw_name}, version={version})")
    except Exception:
        pass


# Example usage in tools:
#
# from utils.client_info import get_client_friendly_name, get_cached_client_info
#
#   client_name = get_client_friendly_name()   # "Claude", "Gemini", ...
#   client_info = get_cached_client_info()     # full dict or None
#   if client_info:
#       raw = client_info["name"]              # e.g. "claude-ai"
#       ver = client_info["version"]           # e.g. "1.0.0"
#       friendly = client_info["friendly_name"]
#   # Tools can then tailor responses per client, e.g.:
#   #   f"Hello from Zen MCP Server to {client_name}!"
diff --git a/utils/conversation_memory.py b/utils/conversation_memory.py new file mode 100644 index 0000000..4226651 --- /dev/null +++ b/utils/conversation_memory.py @@ -0,0 +1,1095 @@ +""" +Conversation Memory for AI-to-AI Multi-turn Discussions + +This module provides conversation persistence and context reconstruction for +stateless MCP (Model Context Protocol) environments. It enables multi-turn +conversations between Claude and Gemini by storing conversation state in memory +across independent request cycles. + +CRITICAL ARCHITECTURAL REQUIREMENT: +This conversation memory system is designed for PERSISTENT MCP SERVER PROCESSES. +It uses in-memory storage that persists only within a single Python process. + +⚠️ IMPORTANT: This system will NOT work correctly if MCP tool calls are made + as separate subprocess invocations (each subprocess starts with empty memory). + + WORKING SCENARIO: Claude Desktop with persistent MCP server process + FAILING SCENARIO: Simulator tests calling server.py as individual subprocesses + + Root cause of test failures: Each subprocess call loses the conversation + state from previous calls because memory is process-specific, not shared + across subprocess boundaries. + +ARCHITECTURE OVERVIEW: +The MCP protocol is inherently stateless - each tool request is independent +with no memory of previous interactions. This module bridges that gap by: + +1. Creating persistent conversation threads with unique UUIDs +2. Storing complete conversation context (turns, files, metadata) in memory +3. Reconstructing conversation history when tools are called with continuation_id +4. Supporting cross-tool continuation - seamlessly switch between different tools + while maintaining full conversation context and file references + +CROSS-TOOL CONTINUATION: +A conversation started with one tool (e.g., 'analyze') can be continued with +any other tool (e.g., 'codereview', 'debug', 'chat') using the same continuation_id. 
+The second tool will have access to: +- All previous conversation turns and responses +- File context from previous tools (preserved in conversation history) +- Original thread metadata and timing information +- Accumulated knowledge from the entire conversation + +Key Features: +- UUID-based conversation thread identification with security validation +- Turn-by-turn conversation history storage with tool attribution +- Cross-tool continuation support - switch tools while preserving context +- File context preservation - files shared in earlier turns remain accessible +- NEWEST-FIRST FILE PRIORITIZATION - when the same file appears in multiple turns, + references from newer turns take precedence over older ones. This ensures the + most recent file context is preserved when token limits require exclusions. +- Automatic turn limiting (20 turns max) to prevent runaway conversations +- Context reconstruction for stateless request continuity +- In-memory persistence with automatic expiration (3 hour TTL) +- Thread-safe operations for concurrent access +- Graceful degradation when storage is unavailable + +DUAL PRIORITIZATION STRATEGY (Files & Conversations): +The conversation memory system implements sophisticated prioritization for both files and +conversation turns, using a consistent "newest-first" approach during collection but +presenting information in the optimal format for LLM consumption: + +FILE PRIORITIZATION (Newest-First Throughout): +1. When collecting files across conversation turns, the system walks BACKWARDS through + turns (newest to oldest) and builds a unique file list +2. If the same file path appears in multiple turns, only the reference from the + NEWEST turn is kept in the final list +3. 
This "newest-first" ordering is preserved throughout the entire pipeline: + - get_conversation_file_list() establishes the order + - build_conversation_history() maintains it during token budgeting + - When token limits are hit, OLDER files are excluded first +4. This strategy works across conversation chains - files from newer turns in ANY + thread take precedence over files from older turns in ANY thread + +CONVERSATION TURN PRIORITIZATION (Newest-First Collection, Chronological Presentation): +1. COLLECTION PHASE: Processes turns newest-to-oldest to prioritize recent context + - When token budget is tight, OLDER turns are excluded first + - Ensures most contextually relevant recent exchanges are preserved +2. PRESENTATION PHASE: Reverses collected turns to chronological order (oldest-first) + - LLM sees natural conversation flow: "Turn 1 → Turn 2 → Turn 3..." + - Maintains proper sequential understanding while preserving recency prioritization + +This dual approach ensures optimal context preservation (newest-first) with natural +conversation flow (chronological) for maximum LLM comprehension and relevance. + +USAGE EXAMPLE: +1. Tool A creates thread: create_thread("analyze", request_data) → returns UUID +2. Tool A adds response: add_turn(UUID, "assistant", response, files=[...], tool_name="analyze") +3. Tool B continues thread: get_thread(UUID) → retrieves full context +4. Tool B sees conversation history via build_conversation_history() +5. 
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any, Optional

from pydantic import BaseModel

logger = logging.getLogger(__name__)


def _positive_int_from_env(var_name: str, default: int, unit: str) -> int:
    """
    Read a positive integer configuration value from the environment.

    Consolidates the previously duplicated parse/validate/warn logic for
    MAX_CONVERSATION_TURNS and CONVERSATION_TIMEOUT_HOURS.

    Args:
        var_name: Environment variable name to read
        default: Value to use when unset, non-numeric, or non-positive
        unit: Unit label used in warning messages (e.g., "turns", "hours")

    Returns:
        The parsed positive integer, or ``default`` on any invalid input.
    """
    raw = os.getenv(var_name, str(default))
    try:
        value = int(raw)
    except ValueError:
        logger.warning(f"Invalid {var_name} value ('{raw}'), using default of {default} {unit}")
        return default
    if value <= 0:
        logger.warning(f"Invalid {var_name} value ({value}), using default of {default} {unit}")
        return default
    return value


# Configuration constants
# Max conversation turns, default 20 turns (10 exchanges)
MAX_CONVERSATION_TURNS = _positive_int_from_env("MAX_CONVERSATION_TURNS", 20, "turns")

# Conversation timeout in hours, default 3 hours
CONVERSATION_TIMEOUT_HOURS = _positive_int_from_env("CONVERSATION_TIMEOUT_HOURS", 3, "hours")

CONVERSATION_TIMEOUT_SECONDS = CONVERSATION_TIMEOUT_HOURS * 3600


class ConversationTurn(BaseModel):
    """
    Single turn in a conversation.

    Represents one exchange in the AI-to-AI conversation, tracking both
    the content and metadata needed for cross-tool continuation.

    Attributes:
        role: "user" (Claude) or "assistant" (Gemini/O3/etc)
        content: The actual message content/response
        timestamp: ISO timestamp when this turn was created
        files: List of file paths referenced in this specific turn
        images: List of image paths referenced in this specific turn
        tool_name: Which tool generated this turn (for cross-tool tracking)
        model_provider: Provider used (e.g., "google", "openai")
        model_name: Specific model used (e.g., "gemini-2.5-flash", "o3-mini")
        model_metadata: Additional model-specific metadata (e.g., thinking mode, token usage)
    """

    role: str  # "user" or "assistant"
    content: str
    timestamp: str
    files: Optional[list[str]] = None  # Files referenced in this turn
    images: Optional[list[str]] = None  # Images referenced in this turn
    tool_name: Optional[str] = None  # Tool used for this turn
    model_provider: Optional[str] = None  # Model provider (google, openai, etc)
    model_name: Optional[str] = None  # Specific model used
    model_metadata: Optional[dict[str, Any]] = None  # Additional model info
class ThreadContext(BaseModel):
    """
    Complete conversation context for a thread.

    Contains all information needed to reconstruct a conversation state
    across different tools and request cycles. This is the core data
    structure that enables cross-tool continuation.

    Attributes:
        thread_id: UUID identifying this conversation thread
        parent_thread_id: UUID of parent thread (for conversation chains)
        created_at: ISO timestamp when thread was created
        last_updated_at: ISO timestamp of last modification
        tool_name: Name of the tool that initiated this thread
        turns: List of all conversation turns in chronological order
        initial_context: Original request data that started the conversation
    """

    thread_id: str
    parent_thread_id: Optional[str] = None  # Parent thread for conversation chains
    created_at: str
    last_updated_at: str
    tool_name: str  # Tool that created this thread (preserved for attribution)
    turns: list[ConversationTurn]
    initial_context: dict[str, Any]  # Original request parameters


def get_storage():
    """
    Return the thread-safe in-memory storage backend used for
    conversation persistence.
    """
    from .storage_backend import get_storage_backend

    return get_storage_backend()


def create_thread(tool_name: str, initial_request: dict[str, Any], parent_thread_id: Optional[str] = None) -> str:
    """
    Create a new conversation thread and return its thread ID.

    Initializes a new conversation thread for AI-to-AI discussions. Called
    when a tool wants to enable follow-up conversations or when Claude
    explicitly starts a multi-turn interaction.

    Args:
        tool_name: Name of the tool creating this thread (e.g., "analyze", "chat")
        initial_request: Original request parameters (filtered for serialization)
        parent_thread_id: Optional parent thread ID for conversation chains

    Returns:
        str: UUID thread identifier usable for continuation

    Note:
        - Thread expires after the configured timeout (default: 3 hours)
        - Non-serializable parameters are filtered out automatically
        - Any tool can continue the thread using the returned UUID
        - A parent thread creates a chain for history traversal
    """
    new_id = str(uuid.uuid4())
    timestamp = datetime.now(timezone.utc).isoformat()

    # Per-call knobs must not persist with the thread; dropping them also
    # avoids JSON-serialization issues for non-serializable values.
    excluded_keys = ("temperature", "thinking_mode", "model", "continuation_id")
    filtered_context = {key: value for key, value in initial_request.items() if key not in excluded_keys}

    context = ThreadContext(
        thread_id=new_id,
        parent_thread_id=parent_thread_id,  # Link to parent for conversation chains
        created_at=timestamp,
        last_updated_at=timestamp,
        tool_name=tool_name,  # Track which tool initiated this conversation
        turns=[],  # Turns are appended later via add_turn()
        initial_context=filtered_context,
    )

    # Persist with a TTL so abandoned threads do not accumulate forever.
    get_storage().setex(f"thread:{new_id}", CONVERSATION_TIMEOUT_SECONDS, context.model_dump_json())

    logger.debug(f"[THREAD] Created new thread {new_id} with parent {parent_thread_id}")

    return new_id
def get_thread(thread_id: str) -> Optional[ThreadContext]:
    """
    Retrieve thread context from in-memory storage.

    Fetches complete conversation context for cross-tool continuation.
    This is the core function that enables tools to access conversation
    history from previous interactions.

    Args:
        thread_id: UUID of the conversation thread

    Returns:
        ThreadContext: Complete conversation context if found
        None: If thread doesn't exist, expired, or invalid UUID

    Security:
        - Validates UUID format to prevent injection attacks
        - Handles storage connection failures gracefully
        - No error information leakage on failure
    """
    # Reject empty or malformed IDs before ever touching storage.
    if not thread_id or not _is_valid_uuid(thread_id):
        return None

    try:
        raw = get_storage().get(f"thread:{thread_id}")
        return ThreadContext.model_validate_json(raw) if raw else None
    except Exception:
        # Swallow storage/parse errors so backend details never leak out.
        return None


def add_turn(
    thread_id: str,
    role: str,
    content: str,
    files: Optional[list[str]] = None,
    images: Optional[list[str]] = None,
    tool_name: Optional[str] = None,
    model_provider: Optional[str] = None,
    model_name: Optional[str] = None,
    model_metadata: Optional[dict[str, Any]] = None,
) -> bool:
    """
    Add a turn to an existing thread.

    Appends a new conversation turn, preserving the tool and model that
    generated it. This is the core function for building conversation
    history and enabling cross-tool continuation.

    Args:
        thread_id: UUID of the conversation thread
        role: "user" (Claude) or "assistant" (Gemini/O3/etc)
        content: The actual message/response content
        files: Optional list of files referenced in this turn
        images: Optional list of images referenced in this turn
        tool_name: Name of the tool adding this turn (for attribution)
        model_provider: Provider used (e.g., "google", "openai")
        model_name: Specific model used (e.g., "gemini-2.5-flash", "o3-mini")
        model_metadata: Additional model info (e.g., thinking mode, token usage)

    Returns:
        bool: True if the turn was added; False when the thread is missing,
        expired, already at the turn limit, or storage fails.

    Note:
        - Refreshes the thread TTL to the configured timeout on success
        - Turn limits prevent runaway conversations
        - File/image references are preserved for cross-tool context
    """
    logger.debug(f"[FLOW] Adding {role} turn to {thread_id} ({tool_name})")

    context = get_thread(thread_id)
    if context is None:
        logger.debug(f"[FLOW] Thread {thread_id} not found for turn addition")
        return False

    # Enforce the turn cap to prevent runaway conversations.
    if len(context.turns) >= MAX_CONVERSATION_TURNS:
        logger.debug(f"[FLOW] Thread {thread_id} at max turns ({MAX_CONVERSATION_TURNS})")
        return False

    context.turns.append(
        ConversationTurn(
            role=role,
            content=content,
            timestamp=datetime.now(timezone.utc).isoformat(),
            files=files,  # Preserved for cross-tool file context
            images=images,  # Preserved for cross-tool visual context
            tool_name=tool_name,  # Track which tool generated this turn
            model_provider=model_provider,
            model_name=model_name,
            model_metadata=model_metadata,
        )
    )
    context.last_updated_at = datetime.now(timezone.utc).isoformat()

    try:
        # Persist and refresh the TTL so active conversations stay alive.
        get_storage().setex(f"thread:{thread_id}", CONVERSATION_TIMEOUT_SECONDS, context.model_dump_json())
        return True
    except Exception as e:
        logger.debug(f"[FLOW] Failed to save turn to storage: {type(e).__name__}")
        return False


def get_thread_chain(thread_id: str, max_depth: int = 20) -> list[ThreadContext]:
    """
    Traverse the parent chain to get all threads in conversation sequence.

    Follows parent_thread_id links to retrieve the complete conversation
    chain, returned in chronological order (oldest first).

    Args:
        thread_id: Starting thread ID
        max_depth: Maximum chain depth to prevent infinite loops

    Returns:
        list[ThreadContext]: All threads in chain, oldest first
    """
    chain: list[ThreadContext] = []
    visited: set[str] = set()
    current_id = thread_id

    # Walk from the given thread back toward the root (newest -> oldest).
    while current_id and len(chain) < max_depth:
        if current_id in visited:
            # Defensive: a cycle would otherwise loop until max_depth.
            logger.warning(f"[THREAD] Circular reference detected in thread chain at {current_id}")
            break
        visited.add(current_id)

        node = get_thread(current_id)
        if node is None:
            logger.debug(f"[THREAD] Thread {current_id} not found in chain traversal")
            break

        chain.append(node)
        current_id = node.parent_thread_id

    # Flip to chronological order (oldest first) for presentation.
    chain.reverse()

    logger.debug(f"[THREAD] Retrieved chain of {len(chain)} threads for {thread_id}")
    return chain
# Module logger (idempotent re-binding; getLogger returns the same instance).
logger = logging.getLogger(__name__)


def _collect_newest_first(turns, attr: str, label: str, noun: str) -> list[str]:
    """
    Collect unique per-turn path references with newest-first prioritization.

    Shared implementation for file and image collection (previously two
    verbatim-duplicated 40-line functions). Walks turns in REVERSE order
    (newest to oldest); the first occurrence of a path wins, so a reference
    from a newer turn always takes precedence over older duplicates.

    Args:
        turns: Sequence of conversation turns (each exposing ``attr``)
        attr: Turn attribute holding the paths ("files" or "images")
        label: Log prefix tag ("FILES" or "IMAGES")
        noun: Singular noun for log messages ("file" or "image")

    Returns:
        list[str]: Unique paths ordered by newest reference first;
        empty list when there are no turns or no references.
    """
    if not turns:
        logger.debug(f"[{label}] No turns found, returning empty {noun} list")
        return []

    seen: set[str] = set()  # O(1) duplicate detection
    ordered: list[str] = []

    logger.debug(f"[{label}] Collecting {noun}s from {len(turns)} turns (newest first)")

    # REVERSE iteration (len-1 down to 0) is the core of newest-first
    # prioritization: newer turns are visited before older ones, so a
    # duplicate seen later is by definition the older reference.
    for i in range(len(turns) - 1, -1, -1):
        paths = getattr(turns[i], attr)
        if not paths:
            continue
        logger.debug(f"[{label}] Turn {i + 1} has {len(paths)} {noun}s: {paths}")
        for path in paths:
            if path in seen:
                # Already captured from a NEWER turn - skip the older reference.
                logger.debug(f"[{label}] Skipping duplicate {noun}: {path} (newer version already included)")
            else:
                seen.add(path)
                ordered.append(path)
                logger.debug(f"[{label}] Added new {noun}: {path} (from turn {i + 1})")

    logger.debug(f"[{label}] Final {noun} list ({len(ordered)}): {ordered}")
    return ordered


def get_conversation_file_list(context: "ThreadContext") -> list[str]:
    """
    Extract all unique files from conversation turns with newest-first prioritization.

    When the same file appears in multiple turns, the reference from the
    NEWEST turn takes precedence, so the most recent file context survives
    when token limits later force exclusions.

    Example:
        Turn 1: files = ["main.py", "utils.py"]
        Turn 2: files = ["test.py"]
        Turn 3: files = ["main.py", "config.py"]   # main.py appears again

        Result: ["main.py", "config.py", "test.py", "utils.py"]
        (main.py from Turn 3 takes precedence over Turn 1)

    Args:
        context: ThreadContext containing all conversation turns to process

    Returns:
        list[str]: Unique file paths ordered by newest reference first.

    Performance:
        O(n*m) time for n turns with m files each; O(f) extra space.
    """
    return _collect_newest_first(context.turns, "files", "FILES", "file")


def get_conversation_image_list(context: "ThreadContext") -> list[str]:
    """
    Extract all unique images from conversation turns with newest-first prioritization.

    Identical prioritization to get_conversation_file_list(): when the same
    image appears in multiple turns, only the reference from the NEWEST turn
    is kept.

    Example:
        Turn 1: images = ["diagram.png", "flow.jpg"]
        Turn 2: images = ["error.png"]
        Turn 3: images = ["diagram.png", "updated.png"]   # diagram.png again

        Result: ["diagram.png", "updated.png", "error.png", "flow.jpg"]

    Args:
        context: ThreadContext containing all conversation turns to process

    Returns:
        list[str]: Unique image paths ordered by newest reference first.

    Performance:
        O(n*m) time for n turns with m images each; O(i) extra space.
    """
    return _collect_newest_first(context.turns, "images", "IMAGES", "image")
    Args:
        all_files: List of files to consider for inclusion
        max_file_tokens: Maximum tokens available for file content

    Returns:
        Tuple of (files_to_include, files_to_skip, estimated_total_tokens)
    """
    if not all_files:
        return [], [], 0

    files_to_include = []
    files_to_skip = []
    total_tokens = 0

    logger.debug(f"[FILES] Planning inclusion for {len(all_files)} files with budget {max_file_tokens:,} tokens")

    for file_path in all_files:
        try:
            # NOTE(review): loop-invariant import; kept inside the try so an import
            # failure marks the file as skipped rather than aborting the whole plan —
            # consider hoisting once that trade-off is confirmed.
            from utils.file_utils import estimate_file_tokens

            if os.path.exists(file_path) and os.path.isfile(file_path):
                # Use centralized token estimation for consistency
                estimated_tokens = estimate_file_tokens(file_path)

                if total_tokens + estimated_tokens <= max_file_tokens:
                    files_to_include.append(file_path)
                    total_tokens += estimated_tokens
                    logger.debug(
                        f"[FILES] Including {file_path} - {estimated_tokens:,} tokens (total: {total_tokens:,})"
                    )
                else:
                    files_to_skip.append(file_path)
                    logger.debug(
                        f"[FILES] Skipping {file_path} - would exceed budget (needs {estimated_tokens:,} tokens)"
                    )
            else:
                files_to_skip.append(file_path)
                # More descriptive message for missing files
                if not os.path.exists(file_path):
                    logger.debug(
                        f"[FILES] Skipping {file_path} - file no longer exists (may have been moved/deleted since conversation)"
                    )
                else:
                    logger.debug(f"[FILES] Skipping {file_path} - file not accessible (not a regular file)")

        except Exception as e:
            files_to_skip.append(file_path)
            logger.debug(f"[FILES] Skipping {file_path} - error during processing: {type(e).__name__}: {e}")

    logger.debug(
        f"[FILES] Inclusion plan: {len(files_to_include)} include, {len(files_to_skip)} skip, {total_tokens:,} tokens"
    )
    return files_to_include, files_to_skip, total_tokens


def build_conversation_history(context: ThreadContext, model_context=None, read_files_func=None) -> tuple[str, int]:
    """
    Build formatted conversation history for tool prompts with embedded file contents.

    Creates a comprehensive conversation history that includes both conversation turns and
    file contents, with intelligent prioritization to maximize relevant context within
    token limits. This function enables stateless tools to access complete conversation
    context from previous interactions, including cross-tool continuations.

    FILE PRIORITIZATION BEHAVIOR:
    Files from newer conversation turns are prioritized over files from older turns.
    When the same file appears in multiple turns, the reference from the NEWEST turn
    takes precedence. This ensures the most recent file context is preserved when
    token limits require file exclusions.

    CONVERSATION CHAIN HANDLING:
    If the thread has a parent_thread_id, this function traverses the entire chain
    to include complete conversation history across multiple linked threads. File
    prioritization works across the entire chain, not just the current thread.

    CONVERSATION TURN ORDERING STRATEGY:
    The function employs a sophisticated two-phase approach for optimal token utilization:

    PHASE 1 - COLLECTION (Newest-First for Token Budget):
    - Processes conversation turns in REVERSE chronological order (newest to oldest)
    - Prioritizes recent turns within token constraints
    - If token budget is exceeded, OLDER turns are excluded first
    - Ensures the most contextually relevant recent exchanges are preserved

    PHASE 2 - PRESENTATION (Chronological for LLM Understanding):
    - Reverses the collected turns back to chronological order (oldest to newest)
    - Presents conversation flow naturally for LLM comprehension
    - Maintains "--- Turn 1, Turn 2, Turn 3..." sequential numbering
    - Enables LLM to follow conversation progression logically

    This approach balances recency prioritization with natural conversation flow.

    TOKEN MANAGEMENT:
    - Uses model-specific token allocation (file_tokens + history_tokens)
    - Files are embedded ONCE at the start to prevent duplication
    - Turn collection prioritizes newest-first, presentation shows chronologically
    - Stops adding turns when token budget would be exceeded
    - Gracefully handles token limits with informative notes

    Args:
        context: ThreadContext containing the conversation to format
        model_context: ModelContext for token allocation (optional, uses DEFAULT_MODEL fallback)
        read_files_func: Optional function to read files (primarily for testing)

    Returns:
        tuple[str, int]: (formatted_conversation_history, total_tokens_used)
        Returns ("", 0) if no conversation turns exist in the context

    Output Format:
        === CONVERSATION HISTORY (CONTINUATION) ===
        Thread: <thread_id>
        Tool: <original_tool_name>
        Turn <current>/<max>
        You are continuing this conversation thread from where it left off.

        === FILES REFERENCED IN THIS CONVERSATION ===
        The following files have been shared and analyzed during our conversation.
        [NOTE: X files omitted due to size constraints]
        Refer to these when analyzing the context and requests below:

        <embedded file contents>

        === END REFERENCED FILES ===

        Previous conversation turns:

        --- Turn 1 (Claude) ---
        Files used in this turn: file1.py, file2.py

        <turn content>

        --- Turn 2 (Gemini using analyze via google/gemini-2.5-flash) ---
        Files used in this turn: file3.py

        <turn content>

        === END CONVERSATION HISTORY ===

        IMPORTANT: You are continuing an existing conversation thread...
        This is turn X of the conversation - use the conversation history above...

    Cross-Tool Collaboration:
        This formatted history allows any tool to "see" both conversation context AND
        file contents from previous tools, enabling seamless handoffs between analyze,
        codereview, debug, chat, and other tools while maintaining complete context.

    Performance Characteristics:
        - O(n) file collection with newest-first prioritization
        - Intelligent token budgeting prevents context window overflow
        - In-memory persistence with automatic TTL management
        - Graceful degradation when files are inaccessible or too large
    """
    # Get the complete thread chain
    if context.parent_thread_id:
        # This thread has a parent, get the full chain
        chain = get_thread_chain(context.thread_id)

        # Collect all turns from all threads in chain
        all_turns = []
        total_turns = 0

        for thread in chain:
            all_turns.extend(thread.turns)
            total_turns += len(thread.turns)

        # Use centralized file collection logic for consistency across the entire chain
        # This ensures files from newer turns across ALL threads take precedence
        # over files from older turns, maintaining the newest-first prioritization
        # even when threads are chained together
        temp_context = ThreadContext(
            thread_id="merged_chain",
            created_at=context.created_at,
            last_updated_at=context.last_updated_at,
            tool_name=context.tool_name,
            turns=all_turns,  # All turns from entire chain in chronological order
            initial_context=context.initial_context,
        )
        all_files = get_conversation_file_list(temp_context)  # Applies newest-first logic to entire chain
        logger.debug(f"[THREAD] Built history from {len(chain)} threads with {total_turns} total turns")
    else:
        # Single thread, no parent chain
        all_turns = context.turns
        total_turns = len(context.turns)
        all_files = get_conversation_file_list(context)

    if not all_turns:
        return "", 0

    logger.debug(f"[FILES] Found {len(all_files)} unique files in conversation history")

    # Get model-specific token allocation early (needed for both files and turns)
    if model_context is None:
        from config import DEFAULT_MODEL, IS_AUTO_MODE
        from utils.model_context import ModelContext

        # In auto mode, use an intelligent fallback model for token calculations
        # since "auto" is not a real model with a provider
        model_name = DEFAULT_MODEL
        if IS_AUTO_MODE and model_name.lower() == "auto":
            # Use intelligent fallback based on available API keys
            from providers.registry import ModelProviderRegistry

            model_name = ModelProviderRegistry.get_preferred_fallback_model()

        model_context = ModelContext(model_name)

    token_allocation = model_context.calculate_token_allocation()
    max_file_tokens = token_allocation.file_tokens
    max_history_tokens = token_allocation.history_tokens

    logger.debug(f"[HISTORY] Using model-specific limits for {model_context.model_name}:")
    logger.debug(f"[HISTORY] Max file tokens: {max_file_tokens:,}")
    logger.debug(f"[HISTORY] Max history tokens: {max_history_tokens:,}")

    history_parts = [
        "=== CONVERSATION HISTORY (CONTINUATION) ===",
        f"Thread: {context.thread_id}",
        f"Tool: {context.tool_name}",  # Original tool that started the conversation
        f"Turn {total_turns}/{MAX_CONVERSATION_TURNS}",
        "You are continuing this conversation thread from where it left off.",
        "",
    ]

    # Embed files referenced in this conversation with size-aware selection
    if all_files:
        logger.debug(f"[FILES] Starting embedding for {len(all_files)} files")

        # Plan file inclusion based on size constraints
        # CRITICAL: all_files is already ordered by newest-first prioritization from get_conversation_file_list()
        # So when _plan_file_inclusion_by_size() hits token limits, it naturally excludes OLDER files first
        # while preserving the most recent file references - exactly what we want!
        files_to_include, files_to_skip, estimated_tokens = _plan_file_inclusion_by_size(all_files, max_file_tokens)

        if files_to_skip:
            logger.info(f"[FILES] Excluding {len(files_to_skip)} files from conversation history: {files_to_skip}")
            logger.debug("[FILES] Files excluded for various reasons (size constraints, missing files, access issues)")

        if files_to_include:
            history_parts.extend(
                [
                    "=== FILES REFERENCED IN THIS CONVERSATION ===",
                    "The following files have been shared and analyzed during our conversation.",
                    (
                        ""
                        if not files_to_skip
                        else f"[NOTE: {len(files_to_skip)} files omitted (size constraints, missing files, or access issues)]"
                    ),
                    "Refer to these when analyzing the context and requests below:",
                    "",
                ]
            )

            if read_files_func is None:
                from utils.file_utils import read_file_content

                # Process files for embedding
                file_contents = []
                total_tokens = 0
                files_included = 0

                for file_path in files_to_include:
                    try:
                        logger.debug(f"[FILES] Processing file {file_path}")
                        formatted_content, content_tokens = read_file_content(file_path)
                        if formatted_content:
                            file_contents.append(formatted_content)
                            total_tokens += content_tokens
                            files_included += 1
                            logger.debug(
                                f"File embedded in conversation history: {file_path} ({content_tokens:,} tokens)"
                            )
                        else:
                            logger.debug(f"File skipped (empty content): {file_path}")
                    except Exception as e:
                        # More descriptive error handling for missing files
                        try:
                            if not os.path.exists(file_path):
                                logger.info(
                                    f"File no longer accessible for conversation history: {file_path} - file was moved/deleted since conversation (marking as excluded)"
                                )
                            else:
                                logger.warning(
                                    f"Failed to embed file in conversation history: {file_path} - {type(e).__name__}: {e}"
                                )
                        except Exception:
                            # Fallback if path translation also fails
                            logger.warning(
                                f"Failed to embed file in conversation history: {file_path} - {type(e).__name__}: {e}"
                            )
                        continue

                if file_contents:
                    files_content = "".join(file_contents)
                    if files_to_skip:
                        files_content += (
                            f"\n[NOTE: {len(files_to_skip)} additional file(s) were omitted due to size constraints, missing files, or access issues. "
                            f"These were older files from earlier conversation turns.]\n"
                        )
                    history_parts.append(files_content)
                    logger.debug(
                        f"Conversation history file embedding complete: {files_included} files embedded, {len(files_to_skip)} omitted, {total_tokens:,} total tokens"
                    )
                else:
                    history_parts.append("(No accessible files found)")
                    logger.debug(f"[FILES] No accessible files found from {len(files_to_include)} planned files")
            else:
                # Fallback to original read_files function
                # NOTE(review): this branch passes all_files (not files_to_include), so the
                # size-aware plan above is bypassed when a custom reader is supplied —
                # presumably intentional for testing; confirm.
                files_content = read_files_func(all_files)
                if files_content:
                    # Add token validation for the combined file content
                    from utils.token_utils import check_token_limit

                    within_limit, estimated_tokens = check_token_limit(files_content)
                    if within_limit:
                        history_parts.append(files_content)
                    else:
                        # Handle token limit exceeded for conversation files
                        error_message = f"ERROR: The total size of files referenced in this conversation has exceeded the context limit and cannot be displayed.\nEstimated tokens: {estimated_tokens}, but limit is {max_file_tokens}."
                        history_parts.append(error_message)
                else:
                    history_parts.append("(No accessible files found)")

        history_parts.extend(
            [
                "",
                "=== END REFERENCED FILES ===",
                "",
            ]
        )

    history_parts.append("Previous conversation turns:")

    # === PHASE 1: COLLECTION (Newest-First for Token Budget) ===
    # Build conversation turns bottom-up (most recent first) to prioritize recent context within token limits
    # This ensures we include as many recent turns as possible within the token budget by excluding
    # OLDER turns first when space runs out, preserving the most contextually relevant exchanges
    turn_entries = []  # Will store (index, formatted_turn_content) for chronological ordering later
    total_turn_tokens = 0
    file_embedding_tokens = sum(model_context.estimate_tokens(part) for part in history_parts)

    # CRITICAL: Process turns in REVERSE chronological order (newest to oldest)
    # This prioritization strategy ensures recent context is preserved when token budget is tight
    for idx in range(len(all_turns) - 1, -1, -1):
        turn = all_turns[idx]
        turn_num = idx + 1
        # NOTE(review): role labels are hard-coded to Claude/Gemini — presumably matches
        # how turns are recorded by the calling agents; confirm for other providers.
        role_label = "Claude" if turn.role == "user" else "Gemini"

        # Build the complete turn content
        turn_parts = []

        # Add turn header with tool attribution for cross-tool tracking
        turn_header = f"\n--- Turn {turn_num} ({role_label}"
        if turn.tool_name:
            turn_header += f" using {turn.tool_name}"

        # Add model info if available
        if turn.model_provider and turn.model_name:
            turn_header += f" via {turn.model_provider}/{turn.model_name}"

        turn_header += ") ---"
        turn_parts.append(turn_header)

        # Get tool-specific formatting if available
        # This includes file references and the actual content
        tool_formatted_content = _get_tool_formatted_content(turn)
        turn_parts.extend(tool_formatted_content)

        # Calculate tokens for this turn
        turn_content = "\n".join(turn_parts)
        turn_tokens = model_context.estimate_tokens(turn_content)

        # Check if adding this turn would exceed history budget
        if file_embedding_tokens + total_turn_tokens + turn_tokens > max_history_tokens:
            # Stop adding turns - we've reached the limit
            logger.debug(f"[HISTORY] Stopping at turn {turn_num} - would exceed history budget")
            logger.debug(f"[HISTORY] File tokens: {file_embedding_tokens:,}")
            logger.debug(f"[HISTORY] Turn tokens so far: {total_turn_tokens:,}")
            logger.debug(f"[HISTORY] This turn: {turn_tokens:,}")
            logger.debug(f"[HISTORY] Would total: {file_embedding_tokens + total_turn_tokens + turn_tokens:,}")
            logger.debug(f"[HISTORY] Budget: {max_history_tokens:,}")
            break

        # Add this turn to our collection (we'll reverse it later for chronological presentation)
        # Store the original index to maintain proper turn numbering in final output
        turn_entries.append((idx, turn_content))
        total_turn_tokens += turn_tokens

    # === PHASE 2: PRESENTATION (Chronological for LLM Understanding) ===
    # Reverse the collected turns to restore chronological order (oldest first)
    # This gives the LLM a natural conversation flow: Turn 1 → Turn 2 → Turn 3...
    # while still having prioritized recent turns during the token-constrained collection phase
    turn_entries.reverse()

    # Add the turns in chronological order for natural LLM comprehension
    # The LLM will see: "--- Turn 1 (Claude) ---" followed by "--- Turn 2 (Gemini) ---" etc.
    for _, turn_content in turn_entries:
        history_parts.append(turn_content)

    # Log what we included
    included_turns = len(turn_entries)
    total_turns = len(all_turns)
    if included_turns < total_turns:
        logger.info(f"[HISTORY] Included {included_turns}/{total_turns} turns due to token limit")
        history_parts.append(f"\n[Note: Showing {included_turns} most recent turns out of {total_turns} total]")

    history_parts.extend(
        [
            "",
            "=== END CONVERSATION HISTORY ===",
            "",
            "IMPORTANT: You are continuing an existing conversation thread. Build upon the previous exchanges shown above,",
            "reference earlier points, and maintain consistency with what has been discussed.",
            "",
            "DO NOT repeat or summarize previous analysis, findings, or instructions that are already covered in the",
            "conversation history. Instead, provide only new insights, additional analysis, or direct answers to",
            "the follow-up question / concerns / insights. Assume the user has read the prior conversation.",
            "",
            f"This is turn {len(all_turns) + 1} of the conversation - use the conversation history above to provide a coherent continuation.",
        ]
    )

    # Calculate total tokens for the complete conversation history
    complete_history = "\n".join(history_parts)
    from utils.token_utils import estimate_tokens

    total_conversation_tokens = estimate_tokens(complete_history)

    # Summary log of what was built
    user_turns = len([t for t in all_turns if t.role == "user"])
    assistant_turns = len([t for t in all_turns if t.role == "assistant"])
    logger.debug(
        f"[FLOW] Built conversation history: {user_turns} user + {assistant_turns} assistant turns, {len(all_files)} files, {total_conversation_tokens:,} tokens"
    )

    return complete_history, total_conversation_tokens


def _get_tool_formatted_content(turn: ConversationTurn) -> list[str]:
    """
    Get tool-specific formatting for a conversation turn.

    This function attempts to use the tool's custom formatting method if available,
    falling back to default formatting if the tool cannot be found or doesn't
    provide custom formatting.

    Args:
        turn: The conversation turn to format

    Returns:
        list[str]: Formatted content lines for this turn
    """
    if turn.tool_name:
        try:
            # Dynamically import to avoid circular dependencies
            from server import TOOLS

            tool = TOOLS.get(turn.tool_name)
            if tool:
                # Use inheritance pattern - try to call the method directly
                # If it doesn't exist or raises AttributeError, fall back to default
                try:
                    # NOTE(review): this also swallows AttributeErrors raised INSIDE a
                    # tool's own format_conversation_turn implementation, silently
                    # masking real bugs — a hasattr/getattr check would distinguish
                    # "method missing" from "method failed".
                    return tool.format_conversation_turn(turn)
                except AttributeError:
                    # Tool doesn't implement format_conversation_turn - use default
                    pass
        except Exception as e:
            # Log but don't fail - fall back to default formatting
            logger.debug(f"[HISTORY] Could not get tool-specific formatting for {turn.tool_name}: {e}")

    # Default formatting
    return _default_turn_formatting(turn)


def _default_turn_formatting(turn: ConversationTurn) -> list[str]:
    """
    Default formatting for conversation turns.

    This provides the standard formatting when no tool-specific
    formatting is available.

    Args:
        turn: The conversation turn to format

    Returns:
        list[str]: Default formatted content lines
    """
    parts = []

    # Add files context if present
    if turn.files:
        parts.append(f"Files used in this turn: {', '.join(turn.files)}")
        parts.append("")  # Empty line for readability

    # Add the actual content
    parts.append(turn.content)

    return parts


def _is_valid_uuid(val: str) -> bool:
    """
    Validate UUID format for security

    Ensures thread IDs are valid UUIDs to prevent injection attacks
    and malformed requests.
    Args:
        val: String to validate as UUID

    Returns:
        bool: True if valid UUID format, False otherwise
    """
    try:
        uuid.UUID(val)
        return True
    except ValueError:
        # NOTE(review): uuid.UUID raises TypeError (not ValueError) for non-string
        # input such as None, which would propagate to the caller — confirm intended.
        return False
diff --git a/utils/file_types.py b/utils/file_types.py
new file mode 100644
index 0000000..87a1059
--- /dev/null
+++ b/utils/file_types.py
@@ -0,0 +1,271 @@
"""
File type definitions and constants for file processing

This module centralizes all file type and extension definitions used
throughout the MCP server for consistent file handling.
"""

# Programming language file extensions - core code files
PROGRAMMING_LANGUAGES = {
    ".py",  # Python
    ".js",  # JavaScript
    ".ts",  # TypeScript
    ".jsx",  # React JavaScript
    ".tsx",  # React TypeScript
    ".java",  # Java
    ".cpp",  # C++
    ".c",  # C
    ".h",  # C/C++ Header
    ".hpp",  # C++ Header
    ".cs",  # C#
    ".go",  # Go
    ".rs",  # Rust
    ".rb",  # Ruby
    ".php",  # PHP
    ".swift",  # Swift
    ".kt",  # Kotlin
    ".scala",  # Scala
    ".r",  # R
    ".m",  # Objective-C
    ".mm",  # Objective-C++
}

# Script and shell file extensions
SCRIPTS = {
    ".sql",  # SQL
    ".sh",  # Shell
    ".bash",  # Bash
    ".zsh",  # Zsh
    ".fish",  # Fish shell
    ".ps1",  # PowerShell
    ".bat",  # Batch
    ".cmd",  # Command
}

# Configuration and data file extensions
CONFIGS = {
    ".yml",  # YAML
    ".yaml",  # YAML
    ".json",  # JSON
    ".xml",  # XML
    ".toml",  # TOML
    ".ini",  # INI
    ".cfg",  # Config
    ".conf",  # Config
    ".properties",  # Properties
    ".env",  # Environment
}

# Documentation and markup file extensions
DOCS = {
    ".txt",  # Text
    ".md",  # Markdown
    ".rst",  # reStructuredText
    ".tex",  # LaTeX
}

# Web development file extensions
WEB = {
    ".html",  # HTML
    ".css",  # CSS
    ".scss",  # Sass
    ".sass",  # Sass
    ".less",  # Less
}

# Additional text file extensions for logs and data
# NOTE(review): entries like ".gitignore", ".dockerfile", ".makefile", ".cmake"
# are file NAMES, not extensions — Path(".gitignore").suffix is "" and
# "Dockerfile"/"Makefile" have no suffix at all, so any suffix-based lookup
# (get_file_category, is_text_file) will never match these entries; confirm
# whether name-based matching was intended.
TEXT_DATA = {
    ".log",  # Log files
    ".csv",  # CSV
    ".tsv",  # TSV
    ".gitignore",  # Git ignore
    ".dockerfile",  # Dockerfile
    ".makefile",  # Make
    ".cmake",  # CMake
    ".gradle",  # Gradle
    ".sbt",  # SBT
    ".pom",  # Maven POM
    ".lock",  # Lock files
    ".changeset",  # Precommit changeset
}

# Image file extensions - limited to what AI models actually support
# Based on OpenAI and Gemini supported formats: PNG, JPEG, GIF, WebP
IMAGES = {".jpg", ".jpeg", ".png", ".gif", ".webp"}

# Binary executable and library extensions
BINARIES = {
    ".exe",  # Windows executable
    ".dll",  # Windows library
    ".so",  # Linux shared object
    ".dylib",  # macOS dynamic library
    ".bin",  # Binary
    ".class",  # Java class
}

# Archive and package file extensions
ARCHIVES = {
    ".jar",
    ".war",
    ".ear",  # Java archives
    ".zip",
    ".tar",
    ".gz",  # General archives
    ".7z",
    ".rar",  # Compression
    ".deb",
    ".rpm",  # Linux packages
    ".dmg",
    ".pkg",  # macOS packages
}

# Derived sets for different use cases
CODE_EXTENSIONS = PROGRAMMING_LANGUAGES | SCRIPTS | CONFIGS | DOCS | WEB
PROGRAMMING_EXTENSIONS = PROGRAMMING_LANGUAGES  # For line numbering
TEXT_EXTENSIONS = CODE_EXTENSIONS | TEXT_DATA
IMAGE_EXTENSIONS = IMAGES
BINARY_EXTENSIONS = BINARIES | ARCHIVES

# All extensions by category for easy access
# NOTE(review): get_file_category() returns the FIRST category whose set contains
# the extension, so insertion order here defines precedence for any extension
# that ever appears in more than one set.
FILE_CATEGORIES = {
    "programming": PROGRAMMING_LANGUAGES,
    "scripts": SCRIPTS,
    "configs": CONFIGS,
    "docs": DOCS,
    "web": WEB,
    "text_data": TEXT_DATA,
    "images": IMAGES,
    "binaries": BINARIES,
    "archives": ARCHIVES,
}


def get_file_category(file_path: str) -> str:
    """
    Determine the category of a file based on its extension.
    Args:
        file_path: Path to the file

    Returns:
        Category name or "unknown" if not recognized
    """
    from pathlib import Path

    extension = Path(file_path).suffix.lower()

    # First matching category wins (dict insertion order of FILE_CATEGORIES)
    for category, extensions in FILE_CATEGORIES.items():
        if extension in extensions:
            return category

    return "unknown"


def is_code_file(file_path: str) -> bool:
    """Check if a file is a code file (programming language)."""
    from pathlib import Path

    return Path(file_path).suffix.lower() in PROGRAMMING_LANGUAGES


def is_text_file(file_path: str) -> bool:
    """Check if a file is a text file."""
    from pathlib import Path

    return Path(file_path).suffix.lower() in TEXT_EXTENSIONS


def is_binary_file(file_path: str) -> bool:
    """Check if a file is a binary file."""
    from pathlib import Path

    return Path(file_path).suffix.lower() in BINARY_EXTENSIONS


# File-type specific token-to-byte ratios for accurate token estimation
# Based on empirical analysis of file compression characteristics and tokenization patterns
# NOTE(review): from the values (e.g. ".log": 4.5 vs ".json": 2.5) these look like
# BYTES-PER-TOKEN divisors (tokens ≈ size / ratio), not tokens-per-byte — confirm
# against estimate_file_tokens() in utils/file_utils.py, which consumes them.
TOKEN_ESTIMATION_RATIOS = {
    # Programming languages
    ".py": 3.5,  # Python - moderate verbosity
    ".js": 3.2,  # JavaScript - compact syntax
    ".ts": 3.3,  # TypeScript - type annotations add tokens
    ".jsx": 3.1,  # React JSX - JSX tags are tokenized efficiently
    ".tsx": 3.0,  # React TSX - combination of TypeScript + JSX
    ".java": 3.6,  # Java - verbose syntax, long identifiers
    ".cpp": 3.7,  # C++ - preprocessor directives, templates
    ".c": 3.8,  # C - function definitions, struct declarations
    ".go": 3.9,  # Go - explicit error handling, package names
    ".rs": 3.5,  # Rust - similar to Python in verbosity
    ".php": 3.3,  # PHP - mixed HTML/code, variable prefixes
    ".rb": 3.6,  # Ruby - descriptive method names
    ".swift": 3.4,  # Swift - modern syntax, type inference
    ".kt": 3.5,  # Kotlin - similar to modern languages
    ".scala": 3.2,  # Scala - functional programming, concise
    # Scripts and configuration
    ".sh": 4.1,  # Shell scripts - commands and paths
    ".bat": 4.0,  # Batch files - similar to shell
    ".ps1": 3.8,  # PowerShell - more structured than bash
    ".sql": 3.8,  # SQL - keywords and table/column names
    # Data and configuration formats
    ".json": 2.5,  # JSON - lots of punctuation and quotes
    ".yaml": 3.0,  # YAML - structured but readable
    ".yml": 3.0,  # YAML (alternative extension)
    ".xml": 2.8,  # XML - tags and attributes
    ".toml": 3.2,  # TOML - similar to config files
    # Documentation and text
    ".md": 4.2,  # Markdown - natural language with formatting
    ".txt": 4.0,  # Plain text - mostly natural language
    ".rst": 4.1,  # reStructuredText - documentation format
    # Web technologies
    ".html": 2.9,  # HTML - tags and attributes
    ".css": 3.4,  # CSS - properties and selectors
    # Logs and data
    ".log": 4.5,  # Log files - timestamps, messages, stack traces
    ".csv": 3.1,  # CSV - data with delimiters
    # Infrastructure files
    ".dockerfile": 3.7,  # Dockerfile - commands and paths
    ".tf": 3.5,  # Terraform - infrastructure as code
}


def get_token_estimation_ratio(file_path: str) -> float:
    """
    Get the token estimation ratio for a file based on its extension.

    Args:
        file_path: Path to the file

    Returns:
        Token-to-byte ratio for the file type (default: 3.5 for unknown types)
    """
    from pathlib import Path

    extension = Path(file_path).suffix.lower()
    return TOKEN_ESTIMATION_RATIOS.get(extension, 3.5)  # Conservative default


# MIME type mappings for image files - limited to what AI models actually support
# Based on OpenAI and Gemini supported formats: PNG, JPEG, GIF, WebP
IMAGE_MIME_TYPES = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".gif": "image/gif",
    ".webp": "image/webp",
}


def get_image_mime_type(extension: str) -> str:
    """
    Get the MIME type for an image file extension.
    Args:
        extension: File extension (with or without leading dot)

    Returns:
        MIME type string (default: image/jpeg for unknown extensions)
    """
    # Normalize: accept both "png" and ".png", case-insensitively
    if not extension.startswith("."):
        extension = "." + extension
    extension = extension.lower()
    return IMAGE_MIME_TYPES.get(extension, "image/jpeg")
diff --git a/utils/file_utils.py b/utils/file_utils.py
new file mode 100644
index 0000000..d22a585
--- /dev/null
+++ b/utils/file_utils.py
@@ -0,0 +1,864 @@
"""
File reading utilities with directory support and token management

This module provides secure file access functionality for the MCP server.
It implements critical security measures to prevent unauthorized file access
and manages token limits to ensure efficient API usage.

Key Features:
- Path validation and sandboxing to prevent directory traversal attacks
- Support for both individual files and recursive directory reading
- Token counting and management to stay within API limits
- Automatic file type detection and filtering
- Comprehensive error handling with informative messages

Security Model:
- All file access is restricted to PROJECT_ROOT and its subdirectories
- Absolute paths are required to prevent ambiguity
- Symbolic links are resolved to ensure they stay within bounds

CONVERSATION MEMORY INTEGRATION:
This module works with the conversation memory system to support efficient
multi-turn file handling:

1. DEDUPLICATION SUPPORT:
   - File reading functions are called by conversation-aware tools
   - Supports newest-first file prioritization by providing accurate token estimation
   - Enables efficient file content caching and token budget management

2. TOKEN BUDGET OPTIMIZATION:
   - Provides accurate token estimation for file content before reading
   - Supports the dual prioritization strategy by enabling precise budget calculations
   - Enables tools to make informed decisions about which files to include

3. CROSS-TOOL FILE PERSISTENCE:
   - File reading results are used across different tools in conversation chains
   - Consistent file access patterns support conversation continuation scenarios
   - Error handling preserves conversation flow when files become unavailable
"""

import json
import logging
import os
from pathlib import Path
from typing import Optional

from .file_types import BINARY_EXTENSIONS, CODE_EXTENSIONS, IMAGE_EXTENSIONS, TEXT_EXTENSIONS
from .security_config import EXCLUDED_DIRS, is_dangerous_path
from .token_utils import DEFAULT_CONTEXT_WINDOW, estimate_tokens


def _is_builtin_custom_models_config(path_str: str) -> bool:
    """
    Check if path points to the server's built-in custom_models.json config file.

    This only matches the server's internal config, not user-specified CUSTOM_MODELS_CONFIG_PATH.
    We identify the built-in config by checking if it resolves to the server's conf directory.

    Args:
        path_str: Path to check

    Returns:
        True if this is the server's built-in custom_models.json config file
    """
    try:
        path = Path(path_str)

        # Get the server root by going up from this file: utils/file_utils.py -> server_root
        server_root = Path(__file__).parent.parent
        builtin_config = server_root / "conf" / "custom_models.json"

        # Check if the path resolves to the same file as our built-in config
        # This handles both relative and absolute paths to the same file
        return path.resolve() == builtin_config.resolve()

    except Exception:
        # If path resolution fails, it's not our built-in config
        return False


# Module logger (defined after the helper above; both are at import time, so order
# only matters for readability here)
logger = logging.getLogger(__name__)


def is_mcp_directory(path: Path) -> bool:
    """
    Check if a directory is the MCP server's own directory.

    This prevents the MCP from including its own code when scanning projects
    where the MCP has been cloned as a subdirectory.
    Args:
        path: Directory path to check

    Returns:
        True if this is the MCP server directory or a subdirectory
    """
    if not path.is_dir():
        return False

    # Get the directory where the MCP server is running from
    # __file__ is utils/file_utils.py, so parent.parent is the MCP root
    mcp_server_dir = Path(__file__).parent.parent.resolve()

    # Check if the given path is the MCP server directory or a subdirectory
    try:
        # relative_to() raises ValueError when path is outside mcp_server_dir
        path.resolve().relative_to(mcp_server_dir)
        logger.info(f"Detected MCP server directory at {path}, will exclude from scanning")
        return True
    except ValueError:
        # Not a subdirectory of MCP server
        return False


def get_user_home_directory() -> Optional[Path]:
    """
    Get the user's home directory.

    Returns:
        User's home directory path
    """
    # Path.home() always returns a Path; the Optional return type exists so callers
    # can defensively handle a None (see is_home_directory_root)
    return Path.home()


def is_home_directory_root(path: Path) -> bool:
    """
    Check if the given path is the user's home directory root.

    This prevents scanning the entire home directory which could include
    sensitive data and non-project files.

    Args:
        path: Directory path to check

    Returns:
        True if this is the home directory root
    """
    user_home = get_user_home_directory()
    if not user_home:
        return False

    try:
        resolved_path = path.resolve()
        resolved_home = user_home.resolve()

        # Check if this is exactly the home directory
        if resolved_path == resolved_home:
            logger.warning(
                f"Attempted to scan user home directory root: {path}. Please specify a subdirectory instead."
            )
            return True

        # Also check common home directory patterns
        # NOTE(review): this is a substring match over the whole path, so ANY direct
        # child of /Users, /home or C:\Users (e.g. /home/alice) is rejected, even
        # when it is a legitimate per-user project root — confirm this strictness
        # is intended.
        path_str = str(resolved_path).lower()
        home_patterns = [
            "/users/",  # macOS
            "/home/",  # Linux
            "c:\\users\\",  # Windows
            "c:/users/",  # Windows with forward slashes
        ]

        for pattern in home_patterns:
            if pattern in path_str:
                # Extract the user directory path
                # e.g., /Users/fahad or /home/username
                parts = path_str.split(pattern)
                if len(parts) > 1:
                    # Get the part after the pattern
                    after_pattern = parts[1]
                    # Check if we're at the user's root (no subdirectories)
                    if "/" not in after_pattern and "\\" not in after_pattern:
                        logger.warning(
                            f"Attempted to scan user home directory root: {path}. "
                            f"Please specify a subdirectory instead."
                        )
                        return True

    except Exception as e:
        logger.debug(f"Error checking if path is home directory: {e}")

    return False


def detect_file_type(file_path: str) -> str:
    """
    Detect file type for appropriate processing strategy.

    This function is intended for specific file type handling (e.g., image processing,
    binary file analysis, or enhanced file filtering).
    Args:
        file_path: Path to the file to analyze

    Returns:
        str: "text", "binary", or "image"; "unknown" when the file cannot be
        accessed for magic-byte inspection
    """
    path = Path(file_path)

    # Check extension first (fast)
    extension = path.suffix.lower()
    if extension in TEXT_EXTENSIONS:
        return "text"
    elif extension in IMAGE_EXTENSIONS:
        return "image"
    elif extension in BINARY_EXTENSIONS:
        return "binary"

    # Fallback: check magic bytes for text vs binary
    # This is helpful for files without extensions or unknown extensions
    try:
        with open(path, "rb") as f:
            chunk = f.read(1024)
            # Simple heuristic: if we can decode as UTF-8, likely text
            chunk.decode("utf-8")
            return "text"
    except UnicodeDecodeError:
        return "binary"
    except (FileNotFoundError, PermissionError) as e:
        logger.warning(f"Could not access file {file_path} for type detection: {e}")
        return "unknown"


def should_add_line_numbers(file_path: str, include_line_numbers: Optional[bool] = None) -> bool:
    """
    Determine if line numbers should be added to a file.

    Args:
        file_path: Path to the file
        include_line_numbers: Explicit preference, or None for auto-detection

    Returns:
        bool: True if line numbers should be added
    """
    if include_line_numbers is not None:
        return include_line_numbers

    # Default: DO NOT add line numbers
    # Tools that want line numbers must explicitly request them
    return False


def _normalize_line_endings(content: str) -> str:
    """
    Normalize line endings for consistent line numbering.

    Args:
        content: File content with potentially mixed line endings

    Returns:
        str: Content with normalized LF line endings
    """
    # Normalize all line endings to LF for consistent counting
    # (CRLF must be replaced before lone CR to avoid doubling newlines)
    return content.replace("\r\n", "\n").replace("\r", "\n")


def _add_line_numbers(content: str) -> str:
    """
    Add line numbers to text content for precise referencing.

    Args:
        content: Text content to number

    Returns:
        str: Content with line numbers in format "  45│ actual code line"
             Supports files up to 99,999 lines with dynamic width allocation
    """
    # Normalize line endings first
    normalized_content = _normalize_line_endings(content)
    lines = normalized_content.split("\n")

    # Dynamic width allocation based on total line count
    # This supports files of any size by computing required width
    total_lines = len(lines)
    width = len(str(total_lines))
    width = max(width, 4)  # Minimum padding for readability

    # Format with dynamic width and clear separator
    numbered_lines = [f"{i + 1:{width}d}│ {line}" for i, line in enumerate(lines)]

    return "\n".join(numbered_lines)


def resolve_and_validate_path(path_str: str) -> Path:
    """
    Resolves and validates a path against security policies.

    This function ensures safe file access by:
    1. Requiring absolute paths (no ambiguity)
    2. Resolving symlinks to prevent deception
    3. Blocking access to dangerous system directories

    Args:
        path_str: Path string (must be absolute)

    Returns:
        Resolved Path object that is safe to access

    Raises:
        ValueError: If path is not absolute or otherwise invalid
        PermissionError: If path is in a dangerous location
    """
    # Step 1: Create a Path object
    user_path = Path(path_str)

    # Step 2: Security Policy - Require absolute paths
    # Relative paths could be interpreted differently depending on working directory
    if not user_path.is_absolute():
        raise ValueError(f"Relative paths are not supported. Please provide an absolute path.\nReceived: {path_str}")

    # Step 3: Resolve the absolute path (follows symlinks, removes .. and .)
+ # This is critical for security as it reveals the true destination of symlinks + resolved_path = user_path.resolve() + + # Step 4: Check against dangerous paths + if is_dangerous_path(resolved_path): + logger.warning(f"Access denied - dangerous path: {resolved_path}") + raise PermissionError(f"Access to system directory denied: {path_str}") + + # Step 5: Check if it's the home directory root + if is_home_directory_root(resolved_path): + raise PermissionError( + f"Cannot scan entire home directory: {path_str}\n" f"Please specify a subdirectory within your home folder." + ) + + return resolved_path + + +def expand_paths(paths: list[str], extensions: Optional[set[str]] = None) -> list[str]: + """ + Expand paths to individual files, handling both files and directories. + + This function recursively walks directories to find all matching files. + It automatically filters out hidden files and common non-code directories + like __pycache__ to avoid including generated or system files. + + Args: + paths: List of file or directory paths (must be absolute) + extensions: Optional set of file extensions to include (defaults to CODE_EXTENSIONS) + + Returns: + List of individual file paths, sorted for consistent ordering + """ + if extensions is None: + extensions = CODE_EXTENSIONS + + expanded_files = [] + seen = set() + + for path in paths: + try: + # Validate each path for security before processing + path_obj = resolve_and_validate_path(path) + except (ValueError, PermissionError): + # Skip invalid paths silently to allow partial success + continue + + if not path_obj.exists(): + continue + + # Safety checks for directory scanning + if path_obj.is_dir(): + # Check 1: Prevent scanning user's home directory root + if is_home_directory_root(path_obj): + logger.warning(f"Skipping home directory root: {path}. 
def read_file_content(
    file_path: str, max_size: int = 1_000_000, *, include_line_numbers: Optional[bool] = None
) -> tuple[str, int]:
    """
    Read a single file and format it for inclusion in AI prompts.

    This function handles various error conditions gracefully and always
    returns formatted content, even for errors. This ensures the AI model
    gets context about what files were attempted but couldn't be read.

    Args:
        file_path: Path to file (must be absolute)
        max_size: Maximum file size to read (default 1MB to prevent memory issues)
        include_line_numbers: Whether to add line numbers. If None, auto-detects based on file type

    Returns:
        Tuple of (formatted_content, estimated_tokens)
        Content is wrapped with clear delimiters for AI parsing
    """
    logger.debug(f"[FILES] read_file_content called for: {file_path}")
    try:
        # Validate path security before any file operations
        path = resolve_and_validate_path(file_path)
        logger.debug(f"[FILES] Path validated and resolved: {path}")
    except (ValueError, PermissionError) as e:
        # Return error in a format that provides context to the AI
        logger.debug(f"[FILES] Path validation failed for {file_path}: {type(e).__name__}: {e}")
        error_msg = str(e)
        content = f"\n--- ERROR ACCESSING FILE: {file_path} ---\nError: {error_msg}\n--- END FILE ---\n"
        tokens = estimate_tokens(content)
        logger.debug(f"[FILES] Returning error content for {file_path}: {tokens} tokens")
        return content, tokens

    try:
        # Validate file existence and type
        if not path.exists():
            logger.debug(f"[FILES] File does not exist: {file_path}")
            content = f"\n--- FILE NOT FOUND: {file_path} ---\nError: File does not exist\n--- END FILE ---\n"
            return content, estimate_tokens(content)

        if not path.is_file():
            logger.debug(f"[FILES] Path is not a file: {file_path}")
            content = f"\n--- NOT A FILE: {file_path} ---\nError: Path is not a file\n--- END FILE ---\n"
            return content, estimate_tokens(content)

        # Check file size to prevent memory exhaustion
        file_size = path.stat().st_size
        logger.debug(f"[FILES] File size for {file_path}: {file_size:,} bytes")
        if file_size > max_size:
            logger.debug(f"[FILES] File too large: {file_path} ({file_size:,} > {max_size:,} bytes)")
            # NOTE(review): this literal was split mid-string by the patch
            # mangling; reconstructed as a single f-string - confirm spacing
            # against the upstream zen-mcp-server source.
            content = (
                f"\n--- FILE TOO LARGE: {file_path} ---\n"
                f"File size: {file_size:,} bytes (max: {max_size:,})\n--- END FILE ---\n"
            )
            return content, estimate_tokens(content)

        # Determine if we should add line numbers
        add_line_numbers = should_add_line_numbers(file_path, include_line_numbers)
        logger.debug(f"[FILES] Line numbers for {file_path}: {'enabled' if add_line_numbers else 'disabled'}")

        # Read the file with UTF-8 encoding, replacing invalid characters.
        # This ensures we can handle files with mixed encodings.
        logger.debug(f"[FILES] Reading file content for {file_path}")
        with open(path, encoding="utf-8", errors="replace") as f:
            file_content = f.read()

        logger.debug(f"[FILES] Successfully read {len(file_content)} characters from {file_path}")

        # Add line numbers if requested or auto-detected
        if add_line_numbers:
            file_content = _add_line_numbers(file_content)
            logger.debug(f"[FILES] Added line numbers to {file_path}")
        else:
            # Still normalize line endings for consistency
            file_content = _normalize_line_endings(file_content)

        # Format with clear delimiters that help the AI understand file boundaries.
        # NOTE: These markers ("--- BEGIN FILE: ... ---") are distinct from git diff
        # markers ("--- BEGIN DIFF: ... ---") to allow AI to distinguish between
        # complete file content vs. partial diff content when files appear in both
        # sections.
        formatted = f"\n--- BEGIN FILE: {file_path} ---\n{file_content}\n--- END FILE: {file_path} ---\n"
        tokens = estimate_tokens(formatted)
        logger.debug(f"[FILES] Formatted content for {file_path}: {len(formatted)} chars, {tokens} tokens")
        return formatted, tokens

    except Exception as e:
        logger.debug(f"[FILES] Exception reading file {file_path}: {type(e).__name__}: {e}")
        content = f"\n--- ERROR READING FILE: {file_path} ---\nError: {str(e)}\n--- END FILE ---\n"
        tokens = estimate_tokens(content)
        logger.debug(f"[FILES] Returning error content for {file_path}: {tokens} tokens")
        return content, tokens
def read_files(
    file_paths: list[str],
    code: Optional[str] = None,
    max_tokens: Optional[int] = None,
    reserve_tokens: int = 50_000,
    *,
    include_line_numbers: bool = False,
) -> str:
    """
    Read multiple files and optional direct code with smart token management.

    This function implements intelligent token budgeting to maximize the amount
    of relevant content that can be included in an AI prompt while staying
    within token limits. It prioritizes direct code and reads files until
    the token budget is exhausted.

    Args:
        file_paths: List of file or directory paths (absolute paths required)
        code: Optional direct code to include (prioritized over files)
        max_tokens: Maximum tokens to use (defaults to DEFAULT_CONTEXT_WINDOW)
        reserve_tokens: Tokens to reserve for prompt and response (default 50K)
        include_line_numbers: Whether to add line numbers to file content

    Returns:
        str: All file contents formatted for AI consumption
    """
    if max_tokens is None:
        max_tokens = DEFAULT_CONTEXT_WINDOW

    logger.debug(f"[FILES] read_files called with {len(file_paths)} paths")
    logger.debug(
        f"[FILES] Token budget: max={max_tokens:,}, reserve={reserve_tokens:,}, available={max_tokens - reserve_tokens:,}"
    )

    content_parts = []
    total_tokens = 0
    available_tokens = max_tokens - reserve_tokens

    files_skipped = []

    # Priority 1: Handle direct code if provided.
    # Direct code is prioritized because it's explicitly provided by the user.
    if code:
        formatted_code = f"\n--- BEGIN DIRECT CODE ---\n{code}\n--- END DIRECT CODE ---\n"
        code_tokens = estimate_tokens(formatted_code)

        if code_tokens <= available_tokens:
            content_parts.append(formatted_code)
            total_tokens += code_tokens
            available_tokens -= code_tokens

    # Priority 2: Process file paths
    if file_paths:
        # Expand directories to get all individual files
        logger.debug(f"[FILES] Expanding {len(file_paths)} file paths")
        all_files = expand_paths(file_paths)
        logger.debug(f"[FILES] After expansion: {len(all_files)} individual files")

        if not all_files and file_paths:
            # No files found but paths were provided
            logger.debug("[FILES] No files found from provided paths")
            content_parts.append(f"\n--- NO FILES FOUND ---\nProvided paths: {', '.join(file_paths)}\n--- END ---\n")
        else:
            # Read files sequentially until token limit is reached
            logger.debug(f"[FILES] Reading {len(all_files)} files with token budget {available_tokens:,}")
            for i, file_path in enumerate(all_files):
                if total_tokens >= available_tokens:
                    logger.debug(f"[FILES] Token budget exhausted, skipping remaining {len(all_files) - i} files")
                    files_skipped.extend(all_files[i:])
                    break

                file_content, file_tokens = read_file_content(file_path, include_line_numbers=include_line_numbers)
                logger.debug(f"[FILES] File {file_path}: {file_tokens:,} tokens")

                # Check if adding this file would exceed limit
                if total_tokens + file_tokens <= available_tokens:
                    content_parts.append(file_content)
                    total_tokens += file_tokens
                    logger.debug(f"[FILES] Added file {file_path}, total tokens: {total_tokens:,}")
                else:
                    # File too large for remaining budget
                    logger.debug(
                        f"[FILES] File {file_path} too large for remaining budget ({file_tokens:,} tokens, {available_tokens - total_tokens:,} remaining)"
                    )
                    files_skipped.append(file_path)

    # Add informative note about skipped files to help users understand
    # what was omitted and why
    if files_skipped:
        logger.debug(f"[FILES] {len(files_skipped)} files skipped due to token limits")
        skip_note = "\n\n--- SKIPPED FILES (TOKEN LIMIT) ---\n"
        skip_note += f"Total skipped: {len(files_skipped)}\n"
        # Show first 10 skipped files as examples (the index from the original
        # enumerate() was unused, so iterate the slice directly)
        for file_path in files_skipped[:10]:
            skip_note += f"  - {file_path}\n"
        if len(files_skipped) > 10:
            skip_note += f"  ... and {len(files_skipped) - 10} more\n"
        skip_note += "--- END SKIPPED FILES ---\n"
        content_parts.append(skip_note)

    result = "\n\n".join(content_parts) if content_parts else ""
    logger.debug(f"[FILES] read_files complete: {len(result)} chars, {total_tokens:,} tokens used")
    return result


def estimate_file_tokens(file_path: str) -> int:
    """
    Estimate tokens for a file using file-type aware ratios.

    Args:
        file_path: Path to the file

    Returns:
        Estimated token count for the file (0 for missing/unreadable files)
    """
    try:
        if not os.path.exists(file_path) or not os.path.isfile(file_path):
            return 0

        file_size = os.path.getsize(file_path)

        # Get the appropriate bytes-per-token ratio for this file type
        from .file_types import get_token_estimation_ratio

        ratio = get_token_estimation_ratio(file_path)

        return int(file_size / ratio)
    except Exception:
        # Best-effort estimate: any failure is treated as "no contribution"
        return 0
def check_files_size_limit(files: list[str], max_tokens: int, threshold_percent: float = 1.0) -> tuple[bool, int, int]:
    """
    Check if a list of files would exceed token limits.

    Args:
        files: List of file paths to check
        max_tokens: Maximum allowed tokens
        threshold_percent: Percentage of max_tokens to use as threshold (0.0-1.0)

    Returns:
        Tuple of (within_limit, total_estimated_tokens, file_count)
    """
    if not files:
        return True, 0, 0

    total_estimated_tokens = 0
    file_count = 0
    threshold = int(max_tokens * threshold_percent)

    for file_path in files:
        try:
            estimated_tokens = estimate_file_tokens(file_path)
            total_estimated_tokens += estimated_tokens
            if estimated_tokens > 0:  # Only count accessible files
                file_count += 1
        except Exception:
            # Skip files that can't be accessed for size check
            continue

    within_limit = total_estimated_tokens <= threshold
    return within_limit, total_estimated_tokens, file_count


def read_json_file(file_path: str) -> Optional[dict]:
    """
    Read and parse a JSON file with proper error handling.

    Args:
        file_path: Path to the JSON file

    Returns:
        Parsed JSON data as dict, or None if file doesn't exist or invalid
    """
    try:
        if not os.path.exists(file_path):
            return None

        with open(file_path, encoding="utf-8") as f:
            return json.load(f)
    except (json.JSONDecodeError, OSError):
        return None


def write_json_file(file_path: str, data: dict, indent: int = 2) -> bool:
    """
    Write data to a JSON file with proper formatting.

    Args:
        file_path: Path to write the JSON file
        data: Dictionary data to serialize
        indent: JSON indentation level

    Returns:
        True if successful, False otherwise
    """
    try:
        # BUG FIX: os.path.dirname() returns "" for bare filenames, and
        # os.makedirs("") raises FileNotFoundError, which made writes to the
        # current directory silently return False. Only create the directory
        # when there is one (mirrors ensure_directory_exists below).
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=indent, ensure_ascii=False)
        return True
    except (OSError, TypeError):
        return False


def get_file_size(file_path: str) -> int:
    """
    Get file size in bytes with proper error handling.

    Args:
        file_path: Path to the file

    Returns:
        File size in bytes, or 0 if file doesn't exist or error
    """
    try:
        if os.path.exists(file_path) and os.path.isfile(file_path):
            return os.path.getsize(file_path)
        return 0
    except OSError:
        return 0


def ensure_directory_exists(file_path: str) -> bool:
    """
    Ensure the parent directory of a file path exists.

    Args:
        file_path: Path to file (directory will be created for parent)

    Returns:
        True if directory exists or was created, False on error
    """
    try:
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        return True
    except OSError:
        return False


def is_text_file(file_path: str) -> bool:
    """
    Check if a file is likely a text file based on extension and content.

    Args:
        file_path: Path to the file

    Returns:
        True if file appears to be text, False otherwise
    """
    # Delegates to the file_types module; imported lazily to avoid a cycle.
    from .file_types import is_text_file as check_text_type

    return check_text_type(file_path)


def read_file_safely(file_path: str, max_size: int = 10 * 1024 * 1024) -> Optional[str]:
    """
    Read a file with size limits and encoding handling.

    Args:
        file_path: Path to the file
        max_size: Maximum file size in bytes (default 10MB)

    Returns:
        File content as string, or None if file too large or unreadable
    """
    try:
        if not os.path.exists(file_path) or not os.path.isfile(file_path):
            return None

        file_size = os.path.getsize(file_path)
        if file_size > max_size:
            return None

        with open(file_path, encoding="utf-8", errors="ignore") as f:
            return f.read()
    except OSError:
        return None
def check_total_file_size(files: list[str], model_name: str) -> Optional[dict]:
    """
    Check if total file sizes would exceed token threshold before embedding.

    IMPORTANT: This performs STRICT REJECTION at MCP boundary.
    No partial inclusion - either all files fit or request is rejected.
    This forces Claude to make better file selection decisions.

    This function MUST be called with the effective model name (after resolution).
    It should never receive 'auto' or None - model resolution happens earlier.

    Args:
        files: List of file paths to check
        model_name: The resolved model name for context-aware thresholds (required)

    Returns:
        Dict with `code_too_large` response if too large, None if acceptable

    Raises:
        ValueError: If called with an unresolved model name ('auto' or empty)
    """
    if not files:
        return None

    # Validate we have a proper model name (not auto or None) before doing
    # any provider/model work.
    if not model_name or model_name.lower() == "auto":
        raise ValueError(
            f"check_total_file_size called with unresolved model: '{model_name}'. "
            "Model must be resolved before file size checking."
        )

    logger.info(f"File size check: Using model '{model_name}' for token limit calculation")

    # Imported lazily to avoid a module-level import cycle with utils.model_context.
    from utils.model_context import ModelContext

    model_context = ModelContext(model_name)
    token_allocation = model_context.calculate_token_allocation()

    # Dynamic threshold based on model capacity
    context_window = token_allocation.total_tokens
    if context_window >= 1_000_000:  # Gemini-class models
        threshold_percent = 0.8  # Can be more generous
    elif context_window >= 500_000:  # Mid-range models
        threshold_percent = 0.7  # Moderate
    else:  # OpenAI-class models (200K)
        threshold_percent = 0.6  # Conservative

    max_file_tokens = int(token_allocation.file_tokens * threshold_percent)

    # Use centralized file size checking (threshold already applied to max_file_tokens)
    within_limit, total_estimated_tokens, file_count = check_files_size_limit(files, max_file_tokens)

    if not within_limit:
        return {
            "status": "code_too_large",
            "content": (
                f"The selected files are too large for analysis "
                f"(estimated {total_estimated_tokens:,} tokens, limit {max_file_tokens:,}). "
                f"Please select fewer, more specific files that are most relevant "
                f"to your question, then invoke the tool again."
            ),
            "content_type": "text",
            "metadata": {
                "total_estimated_tokens": total_estimated_tokens,
                "limit": max_file_tokens,
                "file_count": file_count,
                "threshold_percent": threshold_percent,
                "model_context_window": context_window,
                "model_name": model_name,
                "instructions": (
                    "Reduce file selection and try again - all files must fit within budget. "
                    "If this persists, please use a model with a larger context window where available."
                ),
            },
        }

    return None  # Proceed with ALL files
If this persists, please use a model with a larger context window where available.", + }, + } + + return None # Proceed with ALL files diff --git a/utils/image_utils.py b/utils/image_utils.py new file mode 100644 index 0000000..621ea9a --- /dev/null +++ b/utils/image_utils.py @@ -0,0 +1,94 @@ +"""Utility helpers for validating image inputs.""" + +import base64 +import binascii +import os +from collections.abc import Iterable + +from utils.file_types import IMAGES, get_image_mime_type + +DEFAULT_MAX_IMAGE_SIZE_MB = 20.0 + +__all__ = ["DEFAULT_MAX_IMAGE_SIZE_MB", "validate_image"] + + +def _valid_mime_types() -> Iterable[str]: + """Return the MIME types permitted by the IMAGES whitelist.""" + return (get_image_mime_type(ext) for ext in IMAGES) + + +def validate_image(image_path: str, max_size_mb: float = None) -> tuple[bytes, str]: + """Validate a user-supplied image path or data URL. + + Args: + image_path: Either a filesystem path or a data URL. + max_size_mb: Optional size limit (defaults to ``DEFAULT_MAX_IMAGE_SIZE_MB``). + + Returns: + A tuple ``(image_bytes, mime_type)`` ready for upstream providers. + + Raises: + ValueError: When the image is missing, malformed, or exceeds limits. + """ + if max_size_mb is None: + max_size_mb = DEFAULT_MAX_IMAGE_SIZE_MB + + if image_path.startswith("data:"): + return _validate_data_url(image_path, max_size_mb) + + return _validate_file_path(image_path, max_size_mb) + + +def _validate_data_url(image_data_url: str, max_size_mb: float) -> tuple[bytes, str]: + """Validate a data URL and return image bytes plus MIME type.""" + try: + header, data = image_data_url.split(",", 1) + mime_type = header.split(";")[0].split(":")[1] + except (ValueError, IndexError) as exc: + raise ValueError(f"Invalid data URL format: {exc}") + + valid_mime_types = list(_valid_mime_types()) + if mime_type not in valid_mime_types: + raise ValueError( + "Unsupported image type: {mime}. 
Supported types: {supported}".format( + mime=mime_type, supported=", ".join(valid_mime_types) + ) + ) + + try: + image_bytes = base64.b64decode(data) + except binascii.Error as exc: + raise ValueError(f"Invalid base64 data: {exc}") + + _validate_size(image_bytes, max_size_mb) + return image_bytes, mime_type + + +def _validate_file_path(file_path: str, max_size_mb: float) -> tuple[bytes, str]: + """Validate an image loaded from the filesystem.""" + try: + with open(file_path, "rb") as handle: + image_bytes = handle.read() + except FileNotFoundError: + raise ValueError(f"Image file not found: {file_path}") + except OSError as exc: + raise ValueError(f"Failed to read image file: {exc}") + + ext = os.path.splitext(file_path)[1].lower() + if ext not in IMAGES: + raise ValueError( + "Unsupported image format: {ext}. Supported formats: {supported}".format( + ext=ext, supported=", ".join(sorted(IMAGES)) + ) + ) + + mime_type = get_image_mime_type(ext) + _validate_size(image_bytes, max_size_mb) + return image_bytes, mime_type + + +def _validate_size(image_bytes: bytes, max_size_mb: float) -> None: + """Ensure the image does not exceed the configured size limit.""" + size_mb = len(image_bytes) / (1024 * 1024) + if size_mb > max_size_mb: + raise ValueError(f"Image too large: {size_mb:.1f}MB (max: {max_size_mb}MB)") diff --git a/utils/model_context.py b/utils/model_context.py new file mode 100644 index 0000000..c4015cc --- /dev/null +++ b/utils/model_context.py @@ -0,0 +1,180 @@ +""" +Model context management for dynamic token allocation. + +This module provides a clean abstraction for model-specific token management, +ensuring that token limits are properly calculated based on the current model +being used, not global constants. + +CONVERSATION MEMORY INTEGRATION: +This module works closely with the conversation memory system to provide +optimal token allocation for multi-turn conversations: + +1. 
DUAL PRIORITIZATION STRATEGY SUPPORT: + - Provides separate token budgets for conversation history vs. files + - Enables the conversation memory system to apply newest-first prioritization + - Ensures optimal balance between context preservation and new content + +2. MODEL-SPECIFIC ALLOCATION: + - Dynamic allocation based on model capabilities (context window size) + - Conservative allocation for smaller models (O3: 200K context) + - Generous allocation for larger models (Gemini: 1M+ context) + - Adapts token distribution ratios based on model capacity + +3. CROSS-TOOL CONSISTENCY: + - Provides consistent token budgets across different tools + - Enables seamless conversation continuation between tools + - Supports conversation reconstruction with proper budget management +""" + +import logging +from dataclasses import dataclass +from typing import Any, Optional + +from config import DEFAULT_MODEL +from providers import ModelCapabilities, ModelProviderRegistry + +logger = logging.getLogger(__name__) + + +@dataclass +class TokenAllocation: + """Token allocation strategy for a model.""" + + total_tokens: int + content_tokens: int + response_tokens: int + file_tokens: int + history_tokens: int + + @property + def available_for_prompt(self) -> int: + """Tokens available for the actual prompt after allocations.""" + return self.content_tokens - self.file_tokens - self.history_tokens + + +class ModelContext: + """ + Encapsulates model-specific information and token calculations. + + This class provides a single source of truth for all model-related + token calculations, ensuring consistency across the system. + """ + + def __init__(self, model_name: str, model_option: Optional[str] = None): + self.model_name = model_name + self.model_option = model_option # Store optional model option (e.g., "for", "against", etc.) 
+ self._provider = None + self._capabilities = None + self._token_allocation = None + + @property + def provider(self): + """Get the model provider lazily.""" + if self._provider is None: + self._provider = ModelProviderRegistry.get_provider_for_model(self.model_name) + if not self._provider: + available_models = ModelProviderRegistry.get_available_model_names() + if available_models: + available_text = ", ".join(available_models) + else: + available_text = ( + "No models detected. Configure provider credentials or set DEFAULT_MODEL to a valid option." + ) + + raise ValueError( + f"Model '{self.model_name}' is not available with current API keys. Available models: {available_text}." + ) + return self._provider + + @property + def capabilities(self) -> ModelCapabilities: + """Get model capabilities lazily.""" + if self._capabilities is None: + self._capabilities = self.provider.get_capabilities(self.model_name) + return self._capabilities + + def calculate_token_allocation(self, reserved_for_response: Optional[int] = None) -> TokenAllocation: + """ + Calculate token allocation based on model capacity and conversation requirements. + + This method implements the core token budget calculation that supports the + dual prioritization strategy used in conversation memory and file processing: + + TOKEN ALLOCATION STRATEGY: + 1. CONTENT vs RESPONSE SPLIT: + - Smaller models (< 300K): 60% content, 40% response (conservative) + - Larger models (≥ 300K): 80% content, 20% response (generous) + + 2. CONTENT SUB-ALLOCATION: + - File tokens: 30-40% of content budget for newest file versions + - History tokens: 40-50% of content budget for conversation context + - Remaining: Available for tool-specific prompt content + + 3. 
CONVERSATION MEMORY INTEGRATION: + - History allocation enables conversation reconstruction in reconstruct_thread_context() + - File allocation supports newest-first file prioritization in tools + - Remaining budget passed to tools via _remaining_tokens parameter + + Args: + reserved_for_response: Override response token reservation + + Returns: + TokenAllocation with calculated budgets for dual prioritization strategy + """ + total_tokens = self.capabilities.context_window + + # Dynamic allocation based on model capacity + if total_tokens < 300_000: + # Smaller context models (O3): Conservative allocation + content_ratio = 0.6 # 60% for content + response_ratio = 0.4 # 40% for response + file_ratio = 0.3 # 30% of content for files + history_ratio = 0.5 # 50% of content for history + else: + # Larger context models (Gemini): More generous allocation + content_ratio = 0.8 # 80% for content + response_ratio = 0.2 # 20% for response + file_ratio = 0.4 # 40% of content for files + history_ratio = 0.4 # 40% of content for history + + # Calculate allocations + content_tokens = int(total_tokens * content_ratio) + response_tokens = reserved_for_response or int(total_tokens * response_ratio) + + # Sub-allocations within content budget + file_tokens = int(content_tokens * file_ratio) + history_tokens = int(content_tokens * history_ratio) + + allocation = TokenAllocation( + total_tokens=total_tokens, + content_tokens=content_tokens, + response_tokens=response_tokens, + file_tokens=file_tokens, + history_tokens=history_tokens, + ) + + logger.debug(f"Token allocation for {self.model_name}:") + logger.debug(f" Total: {allocation.total_tokens:,}") + logger.debug(f" Content: {allocation.content_tokens:,} ({content_ratio:.0%})") + logger.debug(f" Response: {allocation.response_tokens:,} ({response_ratio:.0%})") + logger.debug(f" Files: {allocation.file_tokens:,} ({file_ratio:.0%} of content)") + logger.debug(f" History: {allocation.history_tokens:,} ({history_ratio:.0%} of 
"""
Model Restriction Service

This module provides centralized management of model usage restrictions
based on environment variables. It allows organizations to limit which
models can be used from each provider for cost control, compliance, or
standardization purposes.

Environment Variables:
- OPENAI_ALLOWED_MODELS: Comma-separated list of allowed OpenAI models
- GOOGLE_ALLOWED_MODELS: Comma-separated list of allowed Gemini models
- XAI_ALLOWED_MODELS: Comma-separated list of allowed X.AI GROK models
- OPENROUTER_ALLOWED_MODELS: Comma-separated list of allowed OpenRouter models
- DIAL_ALLOWED_MODELS: Comma-separated list of allowed DIAL models

Example:
    OPENAI_ALLOWED_MODELS=o3-mini,o4-mini
    GOOGLE_ALLOWED_MODELS=flash
    XAI_ALLOWED_MODELS=grok-3,grok-3-fast
    OPENROUTER_ALLOWED_MODELS=opus,sonnet,mistral
"""

import logging
import os
from typing import Any, Optional

from providers.shared import ProviderType

logger = logging.getLogger(__name__)


class ModelRestrictionService:
    """Central authority for environment-driven model allowlists.

    Role
        Interpret ``*_ALLOWED_MODELS`` environment variables, keep their
        entries normalised (lowercase), and answer whether a provider/model
        pairing is permitted.

    Responsibilities
        * Parse, cache, and expose per-provider restriction sets
        * Validate configuration by cross-checking each entry against the
          provider's alias-aware model list
        * Offer helper methods such as ``is_allowed`` and ``filter_models`` to
          enforce policy everywhere model names appear (tool selection, CLI
          commands, etc.).
    """

    # Environment variable name for each provider's allowlist
    ENV_VARS = {
        ProviderType.OPENAI: "OPENAI_ALLOWED_MODELS",
        ProviderType.GOOGLE: "GOOGLE_ALLOWED_MODELS",
        ProviderType.XAI: "XAI_ALLOWED_MODELS",
        ProviderType.OPENROUTER: "OPENROUTER_ALLOWED_MODELS",
        ProviderType.DIAL: "DIAL_ALLOWED_MODELS",
    }

    def __init__(self):
        """Initialize the restriction service by loading from environment."""
        # Providers with no entry here have no restrictions (all models allowed).
        self.restrictions: dict[ProviderType, set[str]] = {}
        self._load_from_env()

    def _load_from_env(self) -> None:
        """Load restrictions from environment variables.

        An unset or empty variable means no restrictions for that provider.
        Entries are stripped and lowercased; blank entries are discarded.
        """
        for provider_type, env_var in self.ENV_VARS.items():
            env_value = os.getenv(env_var)

            if env_value is None or env_value == "":
                # Not set or empty - no restrictions (allow all models)
                logger.debug(f"{env_var} not set or empty - all {provider_type.value} models allowed")
                continue

            # Parse comma-separated list, normalising each entry
            models = set()
            for model in env_value.split(","):
                cleaned = model.strip().lower()
                if cleaned:
                    models.add(cleaned)

            if models:
                self.restrictions[provider_type] = models
                logger.info(f"{provider_type.value} allowed models: {sorted(models)}")
            else:
                # All entries were empty after cleaning - treat as no restrictions
                logger.debug(f"{env_var} contains only whitespace - all {provider_type.value} models allowed")

    def validate_against_known_models(self, provider_instances: dict[ProviderType, Any]) -> None:
        """
        Validate restrictions against known models from providers.

        This should be called after providers are initialized to warn about
        typos or invalid model names in the restriction lists.

        Args:
            provider_instances: Dictionary of provider type to provider instance
        """
        for provider_type, allowed_models in self.restrictions.items():
            provider = provider_instances.get(provider_type)
            if not provider:
                continue

            # Get all supported models using the clean polymorphic interface
            try:
                # Gather canonical models and aliases with consistent formatting
                all_models = provider.list_models(
                    respect_restrictions=False,
                    include_aliases=True,
                    lowercase=True,
                    unique=True,
                )
                supported_models = set(all_models)
            except Exception as e:
                # Validation is best-effort: an unreachable provider just means
                # we cannot cross-check its entries, not that they are invalid.
                logger.debug(f"Could not get model list from {provider_type.value} provider: {e}")
                supported_models = set()

            # Warn (do not fail) on any allowlist entry the provider doesn't know
            for allowed_model in allowed_models:
                if allowed_model not in supported_models:
                    logger.warning(
                        f"Model '{allowed_model}' in {self.ENV_VARS[provider_type]} "
                        f"is not a recognized {provider_type.value} model. "
                        f"Please check for typos. Known models: {sorted(supported_models)}"
                    )

    def is_allowed(self, provider_type: ProviderType, model_name: str, original_name: Optional[str] = None) -> bool:
        """
        Check if a model is allowed for a specific provider.

        Args:
            provider_type: The provider type (OPENAI, GOOGLE, etc.)
            model_name: The canonical model name (after alias resolution)
            original_name: The original model name before alias resolution (optional)

        Returns:
            True if allowed (or no restrictions), False if restricted
        """
        if provider_type not in self.restrictions:
            # No restrictions for this provider
            return True

        allowed_set = self.restrictions[provider_type]

        if len(allowed_set) == 0:
            # Empty set - allowed
            return True

        # Check both the resolved name and original name (if different),
        # so an allowlist written in terms of aliases still matches.
        names_to_check = {model_name.lower()}
        if original_name and original_name.lower() != model_name.lower():
            names_to_check.add(original_name.lower())

        # If any of the names is in the allowed set, it's allowed
        return any(name in allowed_set for name in names_to_check)

    def get_allowed_models(self, provider_type: ProviderType) -> Optional[set[str]]:
        """
        Get the set of allowed models for a provider.

        Args:
            provider_type: The provider type

        Returns:
            Set of allowed model names, or None if no restrictions
        """
        return self.restrictions.get(provider_type)

    def has_restrictions(self, provider_type: ProviderType) -> bool:
        """
        Check if a provider has any restrictions.

        Args:
            provider_type: The provider type

        Returns:
            True if restrictions exist, False otherwise
        """
        return provider_type in self.restrictions

    def filter_models(self, provider_type: ProviderType, models: list[str]) -> list[str]:
        """
        Filter a list of models based on restrictions.

        Args:
            provider_type: The provider type
            models: List of model names to filter

        Returns:
            Filtered list containing only allowed models
        """
        if not self.has_restrictions(provider_type):
            return models

        return [m for m in models if self.is_allowed(provider_type, m)]

    def get_restriction_summary(self) -> dict[str, Any]:
        """
        Get a summary of all restrictions for logging/debugging.

        Returns:
            Dictionary with provider names and their restrictions
        """
        summary = {}
        for provider_type, allowed_set in self.restrictions.items():
            if allowed_set:
                summary[provider_type.value] = sorted(allowed_set)
            else:
                # NOTE(review): _load_from_env never stores an empty set, so this
                # branch looks unreachable today; kept for defensive parity.
                summary[provider_type.value] = "none (provider disabled)"

        return summary


# Global instance (singleton pattern)
_restriction_service: Optional[ModelRestrictionService] = None


def get_restriction_service() -> ModelRestrictionService:
    """
    Get the global restriction service instance.

    Returns:
        The singleton ModelRestrictionService instance
    """
    global _restriction_service
    if _restriction_service is None:
        _restriction_service = ModelRestrictionService()
    return _restriction_service
+""" + +from pathlib import Path + +# Dangerous paths that should never be scanned +# These would give overly broad access and pose security risks +DANGEROUS_PATHS = { + "/", + "/etc", + "/usr", + "/bin", + "/var", + "/root", + "/home", + "C:\\", + "C:\\Windows", + "C:\\Program Files", + "C:\\Users", +} + +# Directories to exclude from recursive file search +# These typically contain generated code, dependencies, or build artifacts +EXCLUDED_DIRS = { + # Python + "__pycache__", + ".venv", + "venv", + "env", + ".env", + "*.egg-info", + ".eggs", + "wheels", + ".Python", + ".mypy_cache", + ".pytest_cache", + ".tox", + "htmlcov", + ".coverage", + "coverage", + # Node.js / JavaScript + "node_modules", + ".next", + ".nuxt", + "bower_components", + ".sass-cache", + # Version Control + ".git", + ".svn", + ".hg", + # Build Output + "build", + "dist", + "target", + "out", + # IDEs + ".idea", + ".vscode", + ".sublime", + ".atom", + ".brackets", + # Temporary / Cache + ".cache", + ".temp", + ".tmp", + "*.swp", + "*.swo", + "*~", + # OS-specific + ".DS_Store", + "Thumbs.db", + # Java / JVM + ".gradle", + ".m2", + # Documentation build + "_build", + "site", + # Mobile development + ".expo", + ".flutter", + # Package managers + "vendor", +} + + +def is_dangerous_path(path: Path) -> bool: + """ + Check if a path is in the dangerous paths list. + + Args: + path: Path to check + + Returns: + True if the path is dangerous and should not be accessed + """ + try: + resolved = path.resolve() + return str(resolved) in DANGEROUS_PATHS or resolved.parent == resolved + except Exception: + return True # If we can't resolve, consider it dangerous diff --git a/utils/storage_backend.py b/utils/storage_backend.py new file mode 100644 index 0000000..0951aab --- /dev/null +++ b/utils/storage_backend.py @@ -0,0 +1,113 @@ +""" +In-memory storage backend for conversation threads + +This module provides a thread-safe, in-memory alternative to Redis for storing +conversation contexts. 
"""
In-memory storage backend for conversation threads

This module provides a thread-safe, in-memory alternative to Redis for storing
conversation contexts. It's designed for ephemeral MCP server sessions where
conversations only need to persist during a single Claude session.

⚠️ PROCESS-SPECIFIC STORAGE: This storage is confined to a single Python process.
  Data stored in one process is NOT accessible from other processes or subprocesses.
  This is why simulator tests that run server.py as separate subprocesses cannot
  share conversation state between tool calls.

Key Features:
- Thread-safe operations using locks
- TTL support with automatic expiration
- Background cleanup thread for memory management
- Singleton pattern for consistent state within a single process
- Drop-in replacement for Redis storage (for single-process scenarios)
"""

import logging
import os
import threading
import time
from typing import Optional

logger = logging.getLogger(__name__)


class InMemoryStorage:
    """Thread-safe in-memory storage for conversation threads.

    Values are kept as ``key -> (value, absolute_expiry_timestamp)``; a
    daemon thread periodically evicts expired entries, and reads also evict
    lazily so expired data is never returned.
    """

    def __init__(self):
        # key -> (value, expiry as time.time() timestamp)
        self._store: dict[str, tuple[str, float]] = {}
        self._lock = threading.Lock()
        # Match Redis behavior: cleanup interval based on conversation timeout.
        # Run cleanup at 1/10th of timeout interval (e.g. 18 mins for 3 hour timeout),
        # never more often than every 5 minutes.
        timeout_hours = int(os.getenv("CONVERSATION_TIMEOUT_HOURS", "3"))
        self._cleanup_interval = max(300, (timeout_hours * 3600) // 10)
        # Event (not a bool flag) so shutdown() can interrupt the worker's wait
        # immediately instead of letting it sleep out a full interval.
        self._shutdown_event = threading.Event()

        # Start background cleanup thread (daemon: never blocks interpreter exit)
        self._cleanup_thread = threading.Thread(target=self._cleanup_worker, daemon=True)
        self._cleanup_thread.start()

        logger.info(
            f"In-memory storage initialized with {timeout_hours}h timeout, cleanup every {self._cleanup_interval//60}m"
        )

    def set_with_ttl(self, key: str, ttl_seconds: int, value: str) -> None:
        """Store value with expiration time."""
        with self._lock:
            expires_at = time.time() + ttl_seconds
            self._store[key] = (value, expires_at)
            logger.debug(f"Stored key {key} with TTL {ttl_seconds}s")

    def get(self, key: str) -> Optional[str]:
        """Retrieve value if not expired; expired entries are evicted lazily."""
        with self._lock:
            if key in self._store:
                value, expires_at = self._store[key]
                if time.time() < expires_at:
                    logger.debug(f"Retrieved key {key}")
                    return value
                else:
                    # Clean up expired entry
                    del self._store[key]
                    logger.debug(f"Key {key} expired and removed")
            return None

    def setex(self, key: str, ttl_seconds: int, value: str) -> None:
        """Redis-compatible setex method."""
        self.set_with_ttl(key, ttl_seconds, value)

    def _cleanup_worker(self):
        """Background thread that periodically cleans up expired entries.

        Event.wait() doubles as the sleep: it returns False after the
        interval (run a sweep) or True as soon as shutdown() is called
        (exit promptly).
        """
        while not self._shutdown_event.wait(self._cleanup_interval):
            self._cleanup_expired()

    def _cleanup_expired(self):
        """Remove all expired entries."""
        with self._lock:
            current_time = time.time()
            expired_keys = [k for k, (_, exp) in self._store.items() if exp < current_time]
            for key in expired_keys:
                del self._store[key]

            if expired_keys:
                logger.debug(f"Cleaned up {len(expired_keys)} expired conversation threads")

    def shutdown(self):
        """Graceful shutdown of background thread.

        Setting the event wakes the worker out of its wait immediately, so
        the join below completes promptly instead of timing out while the
        worker is mid-sleep (the old time.sleep-based loop could not be
        interrupted).
        """
        self._shutdown_event.set()
        if self._cleanup_thread.is_alive():
            self._cleanup_thread.join(timeout=1)


# Global singleton instance, guarded by a lock for first-use creation
_storage_instance = None
_storage_lock = threading.Lock()


def get_storage_backend() -> InMemoryStorage:
    """Get the global storage instance (singleton pattern)."""
    global _storage_instance
    if _storage_instance is None:
        # Double-checked locking: only the first caller pays for the lock
        with _storage_lock:
            if _storage_instance is None:
                _storage_instance = InMemoryStorage()
                logger.info("Initialized in-memory conversation storage")
    return _storage_instance
context window limits. + +Note: The estimation uses a simple character-to-token ratio which is +approximate. For production systems requiring precise token counts, +consider using the actual tokenizer for the specific model. +""" + +# Default fallback for token limit (conservative estimate) +DEFAULT_CONTEXT_WINDOW = 200_000 # Conservative fallback for unknown models + + +def estimate_tokens(text: str) -> int: + """ + Estimate token count using a character-based approximation. + + This uses a rough heuristic where 1 token ≈ 4 characters, which is + a reasonable approximation for English text. The actual token count + may vary based on: + - Language (non-English text may have different ratios) + - Code vs prose (code often has more tokens per character) + - Special characters and formatting + + Args: + text: The text to estimate tokens for + + Returns: + int: Estimated number of tokens + """ + return len(text) // 4 + + +def check_token_limit(text: str, context_window: int = DEFAULT_CONTEXT_WINDOW) -> tuple[bool, int]: + """ + Check if text exceeds the specified token limit. + + This function is used to validate that prepared prompts will fit + within the model's context window, preventing API errors and ensuring + reliable operation. + + Args: + text: The text to check + context_window: The model's context window size (defaults to conservative fallback) + + Returns: + Tuple[bool, int]: (is_within_limit, estimated_tokens) + - is_within_limit: True if the text fits within context_window + - estimated_tokens: The estimated token count + """ + estimated = estimate_tokens(text) + return estimated <= context_window, estimated