Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import base64
import codecs
import hmac
import os
import time
import uuid

import boto3
from flask import Flask, request, jsonify
from langchain.memory import ConversationBufferMemory  # LangChain for context
app = Flask(__name__)

# AWS Bedrock Agent Runtime client.
# The region can be supplied via the AWS_REGION environment variable; the
# literal fallback is a placeholder that must be replaced (or overridden)
# before any call can succeed.
bedrock_runtime = boto3.client(
    "bedrock-agent-runtime",
    region_name=os.environ.get("AWS_REGION", "yourregionhere"),
)

# Mapping of public agent names to their Bedrock agent ID and alias ID.
# Add one entry per agent; each name also gets its own route
# /v1/chat/completions/<name>.
AGENTS = {
    "agent1": {"id": "AgentID", "alias": "AgentAlias"},
    "agent2": {"id": "AgentID", "alias": "AgentAlias"},
}

# Shared secret expected in the "Authorization: Bearer <key>" header.
# Read from the environment so the real key does not have to live in source
# control. The literal fallback only preserves the original behaviour; never
# rely on it in a deployment.
API_KEY = os.environ.get("API_KEY", "mysecretkey123")

# In-process store of LangChain conversation memories.
# Key format: "user_id:conversation_id". Entries are never evicted and are
# lost when the process restarts.
conversation_memories = {}
def get_conversation_memory(user_id, conversation_id):
    """
    Return the LangChain memory for one user's conversation, creating it on first use.

    The memory object is cached in ``conversation_memories`` under the key
    ``"<user_id>:<conversation_id>"``, so later calls with the same pair get the
    same instance back.

    NOTE(review): nothing ever removes entries from the cache, so it grows with
    every new user/conversation pair.
    """
    memory_key = f"{user_id}:{conversation_id}"
    memory = conversation_memories.get(memory_key)
    if memory is None:
        memory = ConversationBufferMemory(return_messages=True)
        conversation_memories[memory_key] = memory
    return memory
def _flush_utf8_decoder(decoder):
    """Return any text still buffered in *decoder*; base64 it if it is not valid UTF-8."""
    try:
        return decoder.decode(b"", final=True)
    except UnicodeDecodeError:
        pending, _ = decoder.getstate()
        decoder.reset()
        return base64.b64encode(pending).decode("utf-8")


def _decode_completion_stream(response_stream):
    """
    Concatenate the payloads of all ``chunk`` events in an ``invoke_agent`` stream.

    Byte payloads go through one incremental UTF-8 decoder, so a multi-byte
    character split across two chunks is reassembled instead of being treated
    as invalid. Bytes that really are not UTF-8 are base64-encoded rather than
    dropped. Non-bytes payloads are converted with ``str``.
    """
    decoder = codecs.getincrementaldecoder("utf-8")()
    parts = []
    for event in response_stream:
        if "chunk" not in event:
            continue
        chunk_data = event["chunk"].get("bytes", b"")
        if isinstance(chunk_data, bytes):
            try:
                parts.append(decoder.decode(chunk_data))
            except UnicodeDecodeError:
                # On error the decoder keeps its previous buffer, so the
                # undecodable input is that buffer plus this chunk.
                pending, _ = decoder.getstate()
                decoder.reset()
                parts.append(base64.b64encode(pending + chunk_data).decode("utf-8"))
        else:
            # Keep output ordered: emit buffered bytes before the non-bytes payload.
            parts.append(_flush_utf8_decoder(decoder))
            parts.append(str(chunk_data))
    parts.append(_flush_utf8_decoder(decoder))
    # join instead of += avoids quadratic string building on long replies.
    return "".join(parts)


def call_bedrock_agent(agent_id, agent_alias, user_id, conversation_id, user_message, history):
    """
    Send a user message to a Bedrock agent and wrap the reply as a chat completion.

    Args:
        agent_id: Bedrock agent ID, passed as ``agentId``.
        agent_alias: Bedrock agent alias ID, passed as ``agentAliasId``.
        user_id: Caller identifier; echoed in the response and used to key memory.
        conversation_id: Conversation identifier; echoed and used to key memory.
        user_message: Text of the latest user turn.
        history: Earlier turns as ``"role: content"`` strings. They are
            prepended to the prompt, because every call uses a new session ID.

    Returns:
        A Flask JSON response in OpenAI ``chat.completion`` shape, or a
        ``(response, 500)`` tuple with an ``"error"`` message if invoking the
        agent or reading its stream fails.
    """
    memory = get_conversation_memory(user_id, conversation_id)

    # All context is sent explicitly in the prompt.
    full_prompt = "\n".join(history) + f"\nUser: {user_message}"
    # A new sessionId per request, so Bedrock cannot carry context over from
    # earlier requests; the prompt above is the only context the agent sees.
    session_id = str(uuid.uuid4())

    try:
        response_stream = bedrock_runtime.invoke_agent(
            agentId=agent_id,
            agentAliasId=agent_alias,
            sessionId=session_id,
            inputText=full_prompt,
        )["completion"]
        response_text = _decode_completion_stream(response_stream)

        # Record the exchange only after it succeeded, so a failed call does
        # not leave a user turn without a reply.
        # NOTE(review): this memory is written but never read in this file;
        # the prompt is built from ``history`` instead.
        memory.chat_memory.add_user_message(user_message)
        memory.chat_memory.add_ai_message(response_text)
    except Exception as e:  # request boundary: report any AWS/stream failure
        app.logger.exception("Bedrock agent %s invocation failed", agent_id)
        return jsonify({"error": str(e)}), 500

    # Return the response along with user_id and conversation_id for tracking.
    # NOTE(review): "id" is derived from agent_id only, so it is not unique per completion.
    return jsonify({
        "id": f"chatcmpl-{agent_id[:8]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": f"bedrock-agent-{agent_id}",
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response_text
            },
            "finish_reason": "stop"
        }],
        "user_id": user_id,
        "conversation_id": conversation_id
    })
@app.before_request
def check_api_key():
    """
    Reject any request that does not carry the shared API key.

    Runs before every route and expects ``Authorization: Bearer <API_KEY>``.

    Returns:
        ``(error_json, 401)`` when the header is missing or not a Bearer token,
        ``(error_json, 403)`` when the token does not match ``API_KEY``, and
        ``None`` (let the request proceed) otherwise.
    """
    bearer_prefix = "Bearer "
    auth_header = request.headers.get("Authorization")
    if not auth_header or not auth_header.startswith(bearer_prefix):
        return jsonify({"error": "Unauthorized"}), 401
    # Take everything after the prefix. split("Bearer ")[1] would truncate a
    # token that itself contains "Bearer ".
    key_value = auth_header[len(bearer_prefix):]
    # Constant-time comparison, so response timing does not reveal how much of
    # the key matched. Compare bytes: compare_digest rejects non-ASCII str.
    if not hmac.compare_digest(key_value.encode("utf-8"), API_KEY.encode("utf-8")):
        return jsonify({"error": "Forbidden"}), 403
    return None
@app.route("/v1/chat/completions", methods=["POST"])
def handle_chat_completions():
    """
    Handle an OpenAI-style chat completions request, choosing the agent by ``model``.

    Request body (JSON object):
        model: ``"bedrock-agent-<agent id>"`` as listed by ``/v1/models``. A
            bare agent id is also accepted. It is matched against the ``id`` of
            ``AGENTS`` entries; the first match wins.
        messages: Non-empty list. The last item's ``content`` is the new user
            message. Every earlier item needs ``role`` and ``content`` and
            becomes a ``"role: content"`` history line.
        user_id / conversation_id: Optional. They fall back to the
            ``X-User-ID`` / ``X-Conversation-ID`` headers, then to a fresh UUID.

    Returns:
        Whatever ``call_bedrock_agent`` returns, or ``(error_json, 400)`` for a
        malformed body or unknown model.
    """
    # silent=True yields None for a missing/non-JSON body instead of raising,
    # so the client receives this endpoint's JSON error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "model" not in data:
        return jsonify({"error": "Missing 'model' field in request"}), 400

    model_id = data["model"]
    if not isinstance(model_id, str):
        return jsonify({"error": "Invalid model specified"}), 400
    # Strip only a leading prefix; str.replace() would also delete the same
    # text anywhere else in the id.
    prefix = "bedrock-agent-"
    agent_id = model_id[len(prefix):] if model_id.startswith(prefix) else model_id
    agent_info = next((info for info in AGENTS.values() if info["id"] == agent_id), None)
    if not agent_info:
        return jsonify({"error": "Invalid model specified"}), 400

    # Validate the message list up front, so malformed input is a 400 rather
    # than a KeyError/IndexError surfacing as a 500.
    messages = data.get("messages")
    if not isinstance(messages, list) or not messages:
        return jsonify({"error": "Invalid request format"}), 400
    *earlier, latest = messages
    if (
        not isinstance(latest, dict)
        or "content" not in latest
        or not all(isinstance(msg, dict) and "role" in msg and "content" in msg for msg in earlier)
    ):
        return jsonify({"error": "Invalid request format"}), 400
    user_message = latest["content"]

    # Take user_id / conversation_id from the body, then the headers, else generate them.
    user_id = data.get("user_id") or request.headers.get("X-User-ID") or str(uuid.uuid4())
    conversation_id = (
        data.get("conversation_id")
        or request.headers.get("X-Conversation-ID")
        or str(uuid.uuid4())
    )

    history = [f"{msg['role']}: {msg['content']}" for msg in earlier]
    return call_bedrock_agent(
        agent_info["id"], agent_info["alias"], user_id, conversation_id, user_message, history
    )
@app.route("/v1/models", methods=["GET"])
def list_models():
    """
    List every configured Bedrock agent as an OpenAI-style model entry.

    Each entry's ``id`` is ``"bedrock-agent-<agent id>"``. Agents that share
    an id produce duplicate entries.
    """
    models = []
    for info in AGENTS.values():
        models.append({"id": f"bedrock-agent-{info['id']}", "object": "model"})
    payload = {"object": "list", "data": models}
    return jsonify(payload)
# Factory for the per-agent routes (/v1/chat/completions/<agent name>):
# each view is bound to one agent, so requests need no "model" field.
def create_agent_view(agent_id, agent_alias):
    """
    Build a Flask view function bound to one Bedrock agent.

    Args:
        agent_id: Bedrock agent ID the view invokes.
        agent_alias: Bedrock agent alias ID the view invokes.

    Returns:
        A view that reads a JSON body whose ``messages`` is a non-empty list.
        The last item's ``content`` is the new user message. Every earlier item
        needs ``role`` and ``content`` and becomes a ``"role: content"`` history
        line. ``user_id`` / ``conversation_id`` come from the body, then the
        ``X-User-ID`` / ``X-Conversation-ID`` headers, then a fresh UUID.
        Malformed bodies get ``(error_json, 400)``.
    """
    def agent_view():
        # silent=True yields None for a missing/non-JSON body instead of
        # raising, so the client receives this view's JSON error.
        data = request.get_json(silent=True)
        messages = data.get("messages") if isinstance(data, dict) else None
        if not isinstance(messages, list) or not messages:
            return jsonify({"error": "Invalid request format"}), 400
        *earlier, latest = messages
        # Reject malformed turns here rather than letting a KeyError/TypeError
        # surface as a 500.
        if (
            not isinstance(latest, dict)
            or "content" not in latest
            or not all(isinstance(msg, dict) and "role" in msg and "content" in msg for msg in earlier)
        ):
            return jsonify({"error": "Invalid request format"}), 400
        user_message = latest["content"]

        user_id = data.get("user_id") or request.headers.get("X-User-ID") or str(uuid.uuid4())
        conversation_id = (
            data.get("conversation_id")
            or request.headers.get("X-Conversation-ID")
            or str(uuid.uuid4())
        )

        history = [f"{msg['role']}: {msg['content']}" for msg in earlier]
        return call_bedrock_agent(agent_id, agent_alias, user_id, conversation_id, user_message, history)

    return agent_view
# Register a dedicated POST endpoint per configured agent,
# e.g. /v1/chat/completions/agent1 with endpoint name "agent_agent1".
for agent_name, agent_info in AGENTS.items():
    agent_route = f"/v1/chat/completions/{agent_name}"
    agent_view_func = create_agent_view(agent_info["id"], agent_info["alias"])
    app.add_url_rule(
        agent_route,
        f"agent_{agent_name}",
        agent_view_func,
        methods=["POST"],
    )
if __name__ == "__main__":
    # Host and port can be overridden via the HOST / PORT environment
    # variables; the defaults are the original hard-coded values.
    # "0.0.0.0" listens on all network interfaces.
    # NOTE(review): app.run starts Flask's development server; use a
    # production WSGI server for real deployments.
    app.run(
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "5001")),
    )
Add Comment
Please, Sign In to add comment