from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
import uvicorn
import httpx  # async HTTP client, used instead of requests
import numpy as np
import uuid
import time
import json
import asyncio
import datetime  # used to stamp the prompt with the current date

# ----------------------
# CONFIG
# ----------------------
SEARXNG_URL = "http://localhost:8888/search"
EMBED_URL = "http://localhost:8081/v1/embeddings"
LLM_URL = "http://localhost:8080/v1/chat/completions"
EMBED_MODEL = "granite-embedding-125m-english-Q8_0"
LLM_MODEL = "Qwen3-1.7B-Q4_K_M"

# ----------------------
# HELPERS
# ----------------------
def clean_query(query: str) -> str:
    """Strip a trailing question mark and drop common stopwords before searching."""
    query = query.strip().rstrip("?")
    stopwords = {
        "'s", "a", "about", "above", "after", "against", "again", "an",
        "and", "are", "as", "at", "be", "before", "below", "between", "description",
        "but", "by", "detail", "details", "did", "during", "do", "does", "explain",
        "for", "from", "further", "give", "has", "have", "he", "her", "help", "his", "him",
        "if", "in", "into", "is", "it", "me", "meant", "no", "noob", "not", "of", "on", "please",
        "or", "she", "simple", "such", "tell", "terms", "that", "the", "their",
        "then", "there", "these", "they", "this", "to", "through", "was",
        "what", "what's", "where", "when", "will", "why", "with", "write",
        "who", "who's", "you're", "understand"
    }
    return " ".join([w for w in query.split() if w.lower() not in stopwords])

# Async so embedding requests can run concurrently via asyncio.gather()
async def get_embedding(text: str) -> np.ndarray:
    async with httpx.AsyncClient() as client:
        r = await client.post(EMBED_URL, json={"model": EMBED_MODEL, "input": text})
        r.raise_for_status()
        return np.array(r.json()["data"][0]["embedding"])

def cosine_similarity(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

# pipeline is an async generator so tokens can be streamed to the client in real time
async def pipeline(user_query: str):
    clean_q = clean_query(user_query)
    async with httpx.AsyncClient() as client:
        # 1. Search results from SearXNG
        r = await client.get(SEARXNG_URL, params={"q": clean_q, "format": "json"})
        r.raise_for_status()
        results = r.json().get("results", [])[:10]
        docs = [f"{res.get('title', '')}. {res.get('content', '')}" for res in results]

        # 2. Rank docs by cosine similarity to the query, fetching embeddings concurrently
        q_emb = await get_embedding(clean_q)
        doc_embeddings = await asyncio.gather(*(get_embedding(d) for d in docs))
        scored_docs = [
            (cosine_similarity(q_emb, d_emb), d)
            for d, d_emb in zip(docs, doc_embeddings)
        ]
        top_docs = [d for _, d in sorted(scored_docs, key=lambda x: x[0], reverse=True)[:4]]

        # 3. LLM call, grounded on the top-ranked snippets and today's date
        current_date = datetime.date.today().strftime("%Y-%m-%d")
        context = "\n\n".join(top_docs)
        prompt = (
            f"The current date is {current_date}. "
            "Answer the question based on the following web information. "
            "Answer as if you already possess the information:\n\n"
            f"{context}\n\nQuestion: {user_query}"
        )

        # Stream the completion from the LLM server and forward tokens as they arrive.
        # timeout=None because generation can run longer than the default timeout.
        async with client.stream("POST", LLM_URL, json={
            "model": LLM_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 512,
            "temperature": 0.7,
            "stream": True  # request streaming (SSE) from the LLM API
        }, timeout=None) as r:
            r.raise_for_status()
            # aiter_lines() yields complete SSE lines, so there is no need to split
            # raw byte chunks by hand or buffer partial lines.
            async for line in r.aiter_lines():
                if not line.startswith("data: "):
                    continue
                json_data = line[len("data: "):].strip()
                if json_data == "[DONE]":
                    yield "[DONE]"  # propagate the DONE signal
                    return
                try:
                    data = json.loads(json_data)
                except json.JSONDecodeError:
                    continue  # skip malformed lines
                # Extract the token; adjust this path if your LLM server's chunk
                # structure differs from the OpenAI delta format.
                choices = data.get("choices", [])
                if choices:
                    token = choices[0].get("delta", {}).get("content")
                    if token:
                        yield token
    yield "[DONE]"  # ensure DONE is always emitted, even if the stream ended without it
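
# The pipeline generator can be exercised on its own before wiring up the API.
# A minimal, optional smoke-test sketch (not part of the server): it assumes the
# SearXNG, embedding, and LLM servers from CONFIG are already running; the helper
# name and the question text below are only illustrative.
#
#     async def _smoke_test():
#         async for token in pipeline("What is retrieval-augmented generation?"):
#             if token == "[DONE]":
#                 break
#             print(token, end="", flush=True)
#
#     asyncio.run(_smoke_test())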

# ----------------------
# OPENAI-COMPATIBLE API
# ----------------------
app = FastAPI()

@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    body = await request.json()
    messages = body.get("messages", [])
    user_query = messages[-1]["content"] if messages else ""
    stream = body.get("stream", False)

    if not stream:
        # Non-streaming: drain the pipeline generator and return one JSON response
        full_answer_tokens = []
        async for token in pipeline(user_query):
            if token == "[DONE]":
                break
            full_answer_tokens.append(token)
        answer = "".join(full_answer_tokens)
        response = {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": LLM_MODEL,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": answer},
                    "finish_reason": "stop"
                }
            ],
            "usage": {
                "prompt_tokens": 0,  # not tracked; would require tokenizing the prompt and context
                "completion_tokens": len(answer.split()),  # rough word-count approximation
                "total_tokens": len(answer.split())  # rough approximation
            }
        }
        return JSONResponse(content=response)

    # Streaming response (Server-Sent Events, OpenAI chunk format)
    async def event_generator():
        completion_id = f"chatcmpl-{uuid.uuid4()}"
        created = int(time.time())
        async for token in pipeline(user_query):
            if token == "[DONE]":
                # Final chunk with finish_reason="stop" before closing the stream
                done_chunk = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": LLM_MODEL,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {},
                            "finish_reason": "stop"
                        }
                    ]
                }
                yield f"data: {json.dumps(done_chunk)}\n\n"
                break
            # Forward each token as it arrives; the client concatenates the deltas
            chunk = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": LLM_MODEL,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"content": token},
                        "finish_reason": None
                    }
                ]
            }
            yield f"data: {json.dumps(chunk)}\n\n"
        # Terminate the SSE stream the way OpenAI does
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")

# ----------------------
# RUN SERVER
# ----------------------
if __name__ == "__main__":
    # A single uvicorn worker is usually enough here: the app is async and
    # I/O-bound, spending its time awaiting the search, embedding, and LLM servers.
    uvicorn.run(app, host="0.0.0.0", port=8000)
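
# ----------------------
# EXAMPLE CLIENT (sketch)
# ----------------------
# A minimal client sketch, run as a separate script, assuming the server above is
# listening on localhost:8000. It streams a completion from the OpenAI-compatible
# endpoint and prints each delta as it arrives; the question text is just an example.
#
#     import httpx, json
#
#     with httpx.stream("POST", "http://localhost:8000/v1/chat/completions", json={
#         "model": "Qwen3-1.7B-Q4_K_M",
#         "messages": [{"role": "user", "content": "What happened in tech news today?"}],
#         "stream": True,
#     }, timeout=None) as r:
#         for line in r.iter_lines():
#             if not line.startswith("data: "):
#                 continue
#             payload = line[len("data: "):]
#             if payload == "[DONE]":
#                 break
#             delta = json.loads(payload)["choices"][0]["delta"]
#             print(delta.get("content", ""), end="", flush=True)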