# Commit notes:
# - Core architecture from zen-mcp-server
# - OpenRouter and Gemini provider configuration
# - Content variant generator tool (first marketing tool)
# - Chat tool for marketing strategy
# - Version and model listing tools
# - Configuration system with .env support
# - Logging infrastructure
# - Ready for Claude Desktop integration
# (Source file: 352 lines, 11 KiB, Python)
"""
|
|
Zen-Marketing MCP Server - Main server implementation

AI-powered marketing tools for Claude Desktop, specialized for technical B2B content
creation, multi-platform campaigns, and content variation testing.

Based on Zen MCP Server architecture by Fahad Gilani.
|
|
"""
|
|
|
|
import asyncio
|
|
import atexit
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
from logging.handlers import RotatingFileHandler
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
# Load environment variables from a .env file next to this script, if
# python-dotenv is installed. Variables already set in the process environment
# win (override=False). Only the optional import is guarded, so real errors
# raised while loading the file are not silently swallowed.
try:
    from dotenv import load_dotenv
except ImportError:
    # python-dotenv is optional; without it only the real environment is used.
    pass
else:
    script_dir = Path(__file__).parent
    env_file = script_dir / ".env"
    load_dotenv(dotenv_path=env_file, override=False)
|
|
|
|
from mcp.server import Server
|
|
from mcp.server.models import InitializationOptions
|
|
from mcp.server.stdio import stdio_server
|
|
from mcp.types import (
|
|
GetPromptResult,
|
|
Prompt,
|
|
PromptMessage,
|
|
PromptsCapability,
|
|
ServerCapabilities,
|
|
TextContent,
|
|
Tool,
|
|
ToolsCapability,
|
|
)
|
|
|
|
from config import DEFAULT_MODEL, __version__
|
|
from tools.chat import ChatTool
|
|
from tools.contentvariant import ContentVariantTool
|
|
from tools.listmodels import ListModelsTool
|
|
from tools.models import ToolOutput
|
|
from tools.version import VersionTool
|
|
|
|
# Resolve the LOG_LEVEL environment variable to a logging level *name*.
# Whitespace and case are normalized. Anything that is not a real level name
# (a typo, or a non-level attribute of the logging module such as
# "BASIC_FORMAT") falls back to "INFO". That way every
# getattr(logging, log_level, ...) lookup below yields an int that setLevel()
# accepts.
log_level = os.getenv("LOG_LEVEL", "INFO").strip().upper()
if not isinstance(getattr(logging, log_level, None), int):
    log_level = "INFO"
|
|
|
|
|
|
class LocalTimeFormatter(logging.Formatter):
    """Formatter whose default timestamp looks like ``YYYY-MM-DD HH:MM:SS,mmm``.

    The record time is converted with the formatter's ``converter``
    (``time.localtime`` unless overridden), so timestamps use the host's
    local time zone.
    """

    def formatTime(self, record, datefmt=None):
        """Return the creation time of *record* as a string.

        Args:
            record: The ``logging.LogRecord`` being formatted.
            datefmt: Optional ``time.strftime`` format. When given it is used
                verbatim and no milliseconds are appended.

        Returns:
            The formatted timestamp.
        """
        ct = self.converter(record.created)
        if datefmt:
            return time.strftime(datefmt, ct)
        t = time.strftime("%Y-%m-%d %H:%M:%S", ct)
        # Truncate rather than round: rounding e.g. 999.7 ms would print
        # ",1000" (four digits, and inconsistent with the seconds field).
        return f"{t},{int(record.msecs):03d}"
|
|
|
|
|
|
# Root logger configuration: one stderr handler, formatted with local-time
# timestamps, at the level named by ``log_level``. (Presumably stderr is used
# because stdout carries the stdio_server protocol stream.)
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

stderr_handler = logging.StreamHandler(sys.stderr)
stderr_handler.setFormatter(LocalTimeFormatter(log_format))
stderr_handler.setLevel(getattr(logging, log_level, logging.INFO))

root_logger = logging.getLogger()
# Remove any handlers installed before this module ran so output is not duplicated.
root_logger.handlers.clear()
root_logger.setLevel(getattr(logging, log_level, logging.INFO))
root_logger.addHandler(stderr_handler)
|
|
|
|
def _rotating_file_handler(path, max_bytes, backup_count, level, formatter):
    """Build a UTF-8 RotatingFileHandler with the given size limits, level and formatter."""
    handler = RotatingFileHandler(
        path,
        maxBytes=max_bytes,
        backupCount=backup_count,
        encoding="utf-8",
    )
    handler.setLevel(level)
    handler.setFormatter(formatter)
    return handler


# Best-effort file logging under ./logs next to this script:
#   - mcp_server.log   : everything the root logger emits (20 MiB x 5 backups)
#   - mcp_activity.log : records of the "mcp_activity" logger (10 MiB x 2 backups);
#                        they also propagate to the root handlers.
# Any failure is reported on stderr and the server keeps running without files.
try:
    log_dir = Path(__file__).parent / "logs"
    log_dir.mkdir(exist_ok=True)

    file_handler = _rotating_file_handler(
        log_dir / "mcp_server.log",
        20 * 1024 * 1024,
        5,
        getattr(logging, log_level, logging.INFO),
        LocalTimeFormatter(log_format),
    )
    logging.getLogger().addHandler(file_handler)

    mcp_logger = logging.getLogger("mcp_activity")
    mcp_file_handler = _rotating_file_handler(
        log_dir / "mcp_activity.log",
        10 * 1024 * 1024,
        2,
        logging.INFO,
        LocalTimeFormatter("%(asctime)s - %(message)s"),
    )
    mcp_logger.addHandler(mcp_file_handler)
    mcp_logger.setLevel(logging.INFO)
    mcp_logger.propagate = True

    logging.info(f"Logging to: {log_dir / 'mcp_server.log'}")
    logging.info(f"Process PID: {os.getpid()}")
except Exception as exc:
    print(f"Warning: Could not set up file logging: {exc}", file=sys.stderr)
|
|
|
|
logger = logging.getLogger(__name__)

# The single MCP server instance; the request handlers in this module register
# on it through its decorators.
server: Server = Server("zen-marketing")

# Tools that stay enabled even when they are listed in DISABLED_TOOLS.
ESSENTIAL_TOOLS = {"listmodels", "version"}
|
|
|
|
|
|
def parse_disabled_tools_env() -> set[str]:
    """Return the lower-cased tool names listed in the DISABLED_TOOLS variable.

    The variable is a comma-separated list. Whitespace around entries is
    ignored, as are empty entries. An unset or blank variable yields an
    empty set.
    """
    raw = os.getenv("DISABLED_TOOLS", "").strip()
    names: set[str] = set()
    if raw:
        for entry in raw.split(","):
            cleaned = entry.strip()
            if cleaned:
                names.add(cleaned.lower())
    return names
|
|
|
|
|
|
def filter_disabled_tools(all_tools: dict[str, Any]) -> dict[str, Any]:
    """Drop the tools named in DISABLED_TOOLS, except essential ones.

    Tools in ESSENTIAL_TOOLS are always kept. Configuration mistakes are logged
    as warnings instead of being ignored silently:
      - DISABLED_TOOLS names that match no registered tool;
      - attempts to disable an essential tool.

    Args:
        all_tools: Mapping of tool name to tool instance.

    Returns:
        *all_tools* itself when nothing is disabled. Otherwise, a new dict
        without the disabled entries, in the original insertion order.
    """
    disabled_tools = parse_disabled_tools_env()
    if not disabled_tools:
        logger.info("All tools enabled")
        return all_tools

    # Surface likely typos: these names would otherwise have no effect at all.
    unknown = sorted(disabled_tools.difference(all_tools))
    if unknown:
        logger.warning(f"Ignoring unknown tools in DISABLED_TOOLS: {unknown}")
    protected = sorted(disabled_tools.intersection(ESSENTIAL_TOOLS).intersection(all_tools))
    if protected:
        logger.warning(f"Essential tools cannot be disabled, keeping: {protected}")

    enabled_tools = {}
    for tool_name, tool_instance in all_tools.items():
        if tool_name in ESSENTIAL_TOOLS or tool_name not in disabled_tools:
            enabled_tools[tool_name] = tool_instance
        else:
            logger.info(f"Tool '{tool_name}' disabled via DISABLED_TOOLS")

    logger.info(f"Active tools: {sorted(enabled_tools.keys())}")
    return enabled_tools
|
|
|
|
|
|
# Tool registry: MCP tool name -> tool instance, after DISABLED_TOOLS filtering.
TOOLS = filter_disabled_tools(
    {
        "chat": ChatTool(),
        "contentvariant": ContentVariantTool(),
        "listmodels": ListModelsTool(),
        "version": VersionTool(),
    }
)
|
|
|
|
# Prompt templates, keyed by the TOOLS key they belong to. "name" is the prompt
# name exposed to clients (it may differ from the key, e.g. "variations").
# "template" is rendered with str.format(model=...).
PROMPT_TEMPLATES = dict(
    chat=dict(
        name="chat",
        description="Chat and brainstorm marketing ideas",
        template="Chat about marketing strategy with {model}",
    ),
    contentvariant=dict(
        name="variations",
        description="Generate content variations for A/B testing",
        template="Generate content variations for testing with {model}",
    ),
    listmodels=dict(
        name="listmodels",
        description="List available AI models",
        template="List all available models",
    ),
    version=dict(
        name="version",
        description="Show server version",
        template="Show Zen-Marketing server version",
    ),
)
|
|
|
|
|
|
def _api_key_configured(env_var: str, placeholder: str) -> bool:
    """Return True if *env_var* holds a non-blank value other than *placeholder*."""
    value = (os.getenv(env_var) or "").strip()
    return bool(value) and value != placeholder


def configure_providers():
    """Register every AI provider whose API key is configured.

    A key counts as configured when its environment variable is non-blank
    and not the ``your_..._here`` placeholder (presumably the value shipped
    in .env.example).

    Raises:
        ValueError: If neither GEMINI_API_KEY nor OPENROUTER_API_KEY is usable.
    """
    logger.debug("Checking environment variables for API keys...")

    from providers import ModelProviderRegistry
    from providers.gemini import GeminiModelProvider
    from providers.openrouter import OpenRouterProvider
    from providers.shared import ProviderType

    # Evaluate each key once, so logging and registration cannot disagree.
    has_gemini = _api_key_configured("GEMINI_API_KEY", "your_gemini_api_key_here")
    has_openrouter = _api_key_configured("OPENROUTER_API_KEY", "your_openrouter_api_key_here")

    valid_providers = []
    if has_gemini:
        valid_providers.append("Gemini")
        logger.info("Gemini API key found - Gemini models available")
    if has_openrouter:
        valid_providers.append("OpenRouter")
        logger.info("OpenRouter API key found - Multiple models available")

    # Register providers
    if has_gemini:
        ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)
    if has_openrouter:
        ModelProviderRegistry.register_provider(ProviderType.OPENROUTER, OpenRouterProvider)

    if not valid_providers:
        error_msg = (
            "No valid API keys found. Please configure at least one:\n"
            "  - GEMINI_API_KEY for Google Gemini\n"
            "  - OPENROUTER_API_KEY for OpenRouter (minimax, etc.)\n"
            "Create a .env file based on .env.example"
        )
        logger.error(error_msg)
        raise ValueError(error_msg)

    logger.info(f"Configured providers: {', '.join(valid_providers)}")
    logger.info(f"Default model: {DEFAULT_MODEL}")
|
|
|
|
|
|
# MCP request handlers
@server.list_tools()
async def handle_list_tools() -> list[Tool]:
    """Describe every enabled tool (name, description, JSON input schema) to the client."""
    # inputSchema is written before description so the tool methods are
    # called in the same order as before (schema first).
    return [
        Tool(
            name=key,
            inputSchema=impl.get_input_schema(),
            description=impl.get_description(),
        )
        for key, impl in TOOLS.items()
    ]
|
|
|
|
|
|
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict) -> list[TextContent]:
    """Run the tool called *name* with *arguments* and return its text output.

    Errors are reported back to the client as text instead of being raised:
    an unknown tool name yields "Unknown tool: ...". Any ``Exception`` raised
    while running the tool or reading its result yields
    "Tool execution failed: ...".

    Returns:
        A single-element list with the tool's (or the error's) text.
    """
    logger.info(f"Tool call: {name}")

    tool = TOOLS.get(name)
    if tool is None:
        error_msg = f"Unknown tool: {name}"
        logger.error(error_msg)
        return [TextContent(type="text", text=error_msg)]

    try:
        result: ToolOutput = await tool.execute(arguments)

        if result.is_error:
            logger.error(f"Tool {name} failed: {result.text}")
        else:
            logger.info(f"Tool {name} completed successfully")

        return [TextContent(type="text", text=result.text)]

    except Exception as e:
        error_msg = f"Tool execution failed: {e}"
        # logger.exception keeps the full traceback in the logs; previously only
        # the exception class name was recorded, which made failures hard to debug.
        logger.exception(f"{error_msg} ({type(e).__name__})")
        return [TextContent(type="text", text=error_msg)]
|
|
|
|
|
|
@server.list_prompts()
async def handle_list_prompts() -> list[Prompt]:
    """Advertise one prompt per template whose tool is currently enabled.

    Each prompt accepts a single optional "model" argument.
    """
    return [
        Prompt(
            name=info["name"],
            description=info["description"],
            arguments=[{"name": "model", "description": "AI model to use", "required": False}],
        )
        for key, info in PROMPT_TEMPLATES.items()
        if key in TOOLS
    ]
|
|
|
|
|
|
@server.get_prompt()
async def handle_get_prompt(name: str, arguments: dict | None = None) -> GetPromptResult:
    """Render the prompt template whose public name is *name*.

    Args:
        name: The prompt name (a template's "name" field, which can differ from
            its PROMPT_TEMPLATES key, e.g. "variations" for "contentvariant").
        arguments: Optional prompt arguments. "model" is substituted into the
            template; a missing, empty, or null value falls back to DEFAULT_MODEL.

    Returns:
        A GetPromptResult holding one user message with the rendered text.

    Raises:
        ValueError: If no template has that name or its tool is disabled.
    """
    # "or" (not a .get default) so an explicitly empty/null model also uses the default.
    model = (arguments or {}).get("model") or DEFAULT_MODEL

    for tool_name, template_info in PROMPT_TEMPLATES.items():
        # Serve only prompts of enabled tools, matching what handle_list_prompts lists.
        if template_info["name"] == name and tool_name in TOOLS:
            prompt_text = template_info["template"].format(model=model)
            return GetPromptResult(
                description=template_info["description"],
                messages=[PromptMessage(role="user", content=TextContent(type="text", text=prompt_text))],
            )

    raise ValueError(f"Unknown prompt: {name}")
|
|
|
|
|
|
def cleanup():
    """Shutdown hook (registered with atexit in main): log, then flush and close logging."""
    shutdown_notice = "Zen-Marketing MCP Server shutting down"
    logger.info(shutdown_notice)
    # Flush and close every handler so the notice above reaches the log files.
    logging.shutdown()
|
|
|
|
|
|
async def main():
    """Entry point: log a startup banner, configure providers, and serve MCP over stdio.

    Registers ``cleanup`` with atexit. If no AI provider has a usable API key,
    the configuration error is logged and the process exits with status 1.
    """
    separator = "=" * 80
    for banner_line in (
        separator,
        f"Zen-Marketing MCP Server v{__version__}",
        f"Python: {sys.version}",
        f"Working directory: {os.getcwd()}",
        separator,
    ):
        logger.info(banner_line)

    atexit.register(cleanup)

    try:
        configure_providers()
    except ValueError as exc:
        logger.error(f"Configuration error: {exc}")
        sys.exit(1)

    logger.info(f"Enabled tools: {sorted(TOOLS)}")

    init_options = InitializationOptions(
        server_name="zen-marketing",
        server_version=__version__,
        capabilities=ServerCapabilities(
            tools=ToolsCapability(),
            prompts=PromptsCapability(),
        ),
    )

    async with stdio_server() as (reader, writer):
        logger.info("Server started, awaiting requests...")
        await server.run(reader, writer, init_options)
|
|
|
|
|
|
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C during a manual run: exit quietly instead of dumping an asyncio traceback.
        logger.info("Interrupted by user")
|