Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import os
- import json
- import time
- import tomllib
- import logging
- import asyncio
- import uuid
- import httpx
- from fastapi import FastAPI, Request, HTTPException
- from fastapi.responses import StreamingResponse, JSONResponse
- from fastapi.middleware.cors import CORSMiddleware
- from typing import List, Dict, Any, Optional
- # Setup logging
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)
- # Load config
- try:
- with open("config.toml", "rb") as f:
- config = tomllib.load(f)
- except FileNotFoundError:
- logger.warning("config.toml not found, using defaults")
- config = {}
- general_config = config.get("general", {})
- PROVIDERS = config.get("providers", [])
- MODEL_PROVIDER_MAP = {}
- ALL_MODELS = []
- for provider in PROVIDERS:
- p_name = provider.get("name", "unknown")
- p_models = provider.get("models", [])
- p_api_base = provider.get("api_base", "").rstrip("/")
- # Resolve API Key
- p_api_key_env = provider.get("api_key_env")
- p_api_key = os.getenv(p_api_key_env) if p_api_key_env else provider.get("api_key")
- if not p_api_base:
- logger.warning(f"Provider {p_name} missing api_base, skipping.")
- continue
- provider_settings = {
- "api_base": p_api_base,
- "api_key": p_api_key,
- "name": p_name
- }
- for m in p_models:
- MODEL_PROVIDER_MAP[m] = provider_settings
- ALL_MODELS.append(m)
- DEFAULT_MODEL = ALL_MODELS[0] if ALL_MODELS else "GLM-4.7"
- app = FastAPI(title="OpenAI-to-Anthropic Raw Proxy")
- app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
- )
- def convert_messages(openai_messages: List[Dict[str, Any]]) -> tuple[Optional[str], List[Dict[str, Any]]]:
- """
- Extracts system prompt and converts messages to Anthropic format.
- """
- system_prompt = None
- anthropic_messages = []
- for msg in openai_messages:
- role = msg.get("role")
- content = msg.get("content")
- if role == "system":
- if system_prompt:
- system_prompt += "\n" + content
- else:
- system_prompt = content
- elif role in ["user", "assistant"]:
- anthropic_messages.append({"role": role, "content": content})
- else:
- # Anthropic API only allows system, user, and assistant in messages
- logger.warning(f"Skipping unsupported role: {role}")
- return system_prompt, anthropic_messages
- @app.get("/v1/models")
- async def list_models():
- return {
- "object": "list",
- "data": [
- {
- "id": model_id,
- "object": "model",
- "created": int(time.time()),
- "owned_by": MODEL_PROVIDER_MAP.get(model_id, {}).get("name", "anthropic-proxy")
- }
- for model_id in ALL_MODELS
- ]
- }
- @app.post("/v1/chat/completions")
- async def chat_completions(request: Request):
- try:
- body = await request.json()
- except Exception:
- raise HTTPException(status_code=400, detail="Invalid JSON")
- # Extract parameters
- model = body.get("model")
- if not model:
- model = DEFAULT_MODEL
- # Find provider
- provider = MODEL_PROVIDER_MAP.get(model)
- if not provider:
- # Strict mapping: only allow explicitly configured models
- # Unknown models are rejected since we can't determine which API key to use
- raise HTTPException(status_code=404, detail=f"Model '{model}' not configured in proxy.")
- else:
- raise HTTPException(status_code=500, detail="No providers configured.")
- target_api_base = provider["api_base"]
- target_api_key = provider["api_key"]
- messages = body.get("messages", [])
- stream = body.get("stream", False)
- temperature = body.get("temperature")
- top_p = body.get("top_p")
- stop = body.get("stop")
- max_tokens = body.get("max_tokens", 8192)
- # Convert messages
- system_prompt, anthropic_messages = convert_messages(messages)
- # Prepare Anthropic payload
- payload = {
- "model": model,
- "messages": anthropic_messages,
- "max_tokens": max_tokens,
- "stream": stream,
- "thinking": {"type": "enabled", "budget_tokens": 1024}
- }
- if system_prompt:
- payload["system"] = system_prompt
- if temperature is not None:
- payload["temperature"] = temperature
- if top_p is not None:
- payload["top_p"] = top_p
- if stop is not None:
- if isinstance(stop, str):
- payload["stop_sequences"] = [stop]
- else:
- payload["stop_sequences"] = stop
- # Headers
- headers = {
- "x-api-key": target_api_key,
- "anthropic-version": "2023-06-01",
- "content-type": "application/json",
- "accept": "application/json"
- }
- url = f"{target_api_base}/v1/messages"
- logger.info(f"Forwarding request to {url} for model {model} (Provider: {provider['name']})")
- client = httpx.AsyncClient(timeout=60.0)
- try:
- if stream:
- req = client.build_request("POST", url, headers=headers, json=payload)
- r = await client.send(req, stream=True)
- if r.status_code != 200:
- error_content = await r.aread()
- logger.error(f"Target API Error: {r.status_code} - {error_content.decode()}")
- await client.aclose()
- return JSONResponse(
- status_code=r.status_code,
- content={"error": {"message": "Upstream error", "details": error_content.decode()}}
- )
- async def sse_generator():
- chat_id = f"chatcmpl-{uuid.uuid4()}"
- created = int(time.time())
- try:
- async for line in r.aiter_lines():
- if not line or not line.startswith("data: "):
- continue
- data_str = line[6:].strip()
- if data_str == "[DONE]":
- yield "data: [DONE]\n\n"
- break
- try:
- event = json.loads(data_str)
- event_type = event.get("type")
- if event_type == "ping":
- continue
- chunk = {
- "id": chat_id,
- "object": "chat.completion.chunk",
- "created": created,
- "model": model,
- "choices": [
- {
- "index": 0,
- "delta": {},
- "finish_reason": None
- }
- ]
- }
- should_yield = False
- if event_type == "content_block_delta":
- delta = event.get("delta", {})
- delta_type = delta.get("type")
- if delta_type == "text_delta":
- chunk["choices"][0]["delta"]["content"] = delta.get("text", "")
- should_yield = True
- elif delta_type == "thinking_delta":
- chunk["choices"][0]["delta"]["reasoning_content"] = delta.get("thinking", "")
- should_yield = True
- elif event_type == "message_start":
- # Initial role
- chunk["choices"][0]["delta"]["role"] = "assistant"
- should_yield = True
- elif event_type == "message_delta":
- stop_reason = event.get("delta", {}).get("stop_reason")
- if stop_reason:
- chunk["choices"][0]["delta"] = {} # Empty delta
- chunk["choices"][0]["finish_reason"] = stop_reason
- should_yield = True
- elif event_type == "message_stop":
- chunk["choices"][0]["delta"] = {}
- chunk["choices"][0]["finish_reason"] = "stop"
- should_yield = True
- if should_yield:
- yield f"data: {json.dumps(chunk)}\n\n"
- except json.JSONDecodeError:
- logger.warning(f"Failed to parse JSON: {data_str}")
- continue
- except Exception as inner_e:
- logger.error(f"Error processing event {data_str}: {inner_e}", exc_info=True)
- raise inner_e
- except Exception as e:
- logger.error(f"Stream error: {e}")
- yield f"data: {{'error': '{str(e)}'}}\n\n"
- finally:
- await r.aclose()
- await client.aclose()
- return StreamingResponse(sse_generator(), media_type="text/event-stream")
- else:
- response = await client.post(url, headers=headers, json=payload)
- await client.aclose()
- if response.status_code != 200:
- return JSONResponse(
- status_code=response.status_code,
- content={"error": {"message": "Upstream error", "details": response.text}}
- )
- anthropic_resp = response.json()
- # Map response to OpenAI format
- content_blocks = anthropic_resp.get("content", [])
- text_content = ""
- reasoning_content = ""
- for block in content_blocks:
- if block.get("type") == "text":
- text_content += block.get("text", "")
- elif block.get("type") == "thinking":
- reasoning_content += block.get("thinking", "")
- openai_resp = {
- "id": anthropic_resp.get("id"),
- "object": "chat.completion",
- "created": int(time.time()),
- "model": model,
- "choices": [
- {
- "index": 0,
- "message": {
- "role": "assistant",
- "content": text_content,
- },
- "finish_reason": anthropic_resp.get("stop_reason")
- }
- ],
- "usage": {
- "prompt_tokens": anthropic_resp.get("usage", {}).get("input_tokens", 0),
- "completion_tokens": anthropic_resp.get("usage", {}).get("output_tokens", 0),
- "total_tokens": anthropic_resp.get("usage", {}).get("input_tokens", 0) + anthropic_resp.get("usage", {}).get("output_tokens", 0)
- }
- }
- if reasoning_content:
- openai_resp["choices"][0]["message"]["reasoning_content"] = reasoning_content
- openai_resp["usage"]["total_tokens"] = openai_resp["usage"]["prompt_tokens"] + openai_resp["usage"]["completion_tokens"]
- return openai_resp
- except Exception as e:
- await client.aclose()
- logger.error(f"Proxy error: {e}")
- raise HTTPException(status_code=500, detail=str(e))
- if __name__ == "__main__":
- import uvicorn
- port = int(general_config.get("port", 5050))
- print(f"Starting raw proxy on port {port}...")
- uvicorn.run(app, host="0.0.0.0", port=port)
Advertisement
Add Comment
Please, Sign In to add comment