Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from google.adk.agents import Agent
- from config import GEMINI_MODEL_NAME
- from google.adk.planners import BuiltInPlanner
- from google.genai import types
- import os
- import asyncio
- from typing import AsyncGenerator
- import uuid
- from pydantic import BaseModel
- from google.adk.agents.callback_context import CallbackContext
- from google.adk.models import LlmResponse, LlmRequest
- from typing import Dict, List, Optional, Any
- import uvicorn
class PartialResultFunctionCall(BaseModel):
    """One partial result emitted by a streaming tool call."""

    # Payload of this update (the mocked gold monitor stores a price message string here).
    partial_result: Any = ""
    # Id shared by every partial result of the same call; also used as the queue key.
    partial_result_function_call_id: str
    # True only on the final update; the before-model callback stops requesting more results then.
    tool_response_complete: bool = False
class PartialResultFunctionCallQueue:
    """Process-wide store of partial tool results, keyed by call id.

    Producers push results under a ``partial_result_function_call_id`` and
    consumers pop them oldest-first. Obtain the shared object via
    ``instance()``; constructing a second instance directly is an error.
    """

    _instance: Optional["PartialResultFunctionCallQueue"] = None

    @staticmethod
    def instance() -> "PartialResultFunctionCallQueue":
        """Return the shared instance, creating it on first use."""
        if PartialResultFunctionCallQueue._instance is None:
            PartialResultFunctionCallQueue._instance = PartialResultFunctionCallQueue()
        return PartialResultFunctionCallQueue._instance

    def __init__(self) -> None:
        """Register this object as the singleton.

        Raises:
            RuntimeError: if an instance already exists; use ``instance()`` instead.
        """
        if PartialResultFunctionCallQueue._instance is not None:
            # RuntimeError subclasses Exception, so existing `except Exception` handlers still work.
            raise RuntimeError("This class is a singleton!")
        PartialResultFunctionCallQueue._instance = self
        # Maps call id -> pending results, oldest first.
        self.queue: Dict[str, List["PartialResultFunctionCall"]] = {}

    def push_result(self, partial_result_function_call_id: str,
                    result: "PartialResultFunctionCall") -> None:
        """Append ``result`` to the pending list for ``partial_result_function_call_id``.

        An existing list (even an empty one) is reused rather than replaced, so
        references other code holds to it stay valid.
        """
        self.queue.setdefault(partial_result_function_call_id, []).append(result)

    def pop_result(self, partial_result_function_call_id: str) -> Optional["PartialResultFunctionCall"]:
        """Remove and return the oldest pending result for the id, or None if there is none."""
        pending = self.queue.get(partial_result_function_call_id)
        return pending.pop(0) if pending else None


# Module-wide shared queue used by the producers and the listener tool.
function_call_partial_result_queue = PartialResultFunctionCallQueue.instance()
async def monitor_stock_price(stock_symbol: str) -> AsyncGenerator[str, None]:
    """Monitor the price of the given stock_symbol continuously, streaming each update asynchronously.

    Args:
        stock_symbol: The stock symbol to monitor.

    Yields:
        A message with the latest (mocked) price, one roughly every 4 seconds.
    """
    print(f"Start monitor stock price for {stock_symbol}!")
    # Mocked price changes; each one is yielded first and printed once the consumer resumes us.
    for price in (300, 400, 900, 500):
        await asyncio.sleep(4)
        update = f"the price for {stock_symbol} is {price}"
        yield update
        print(update)
async def core_monitor_gold_price(partial_result_function_call_id: str = "") -> None:
    """Mock gold-price feed that publishes readings to the shared partial-result queue.

    Every 4 seconds the next mocked price is pushed as a PartialResultFunctionCall
    under ``partial_result_function_call_id`` and then printed. Only the final
    reading is marked ``tool_response_complete=True``.

    Args:
        partial_result_function_call_id: Key under which the readings are queued.
            streaming_tool_caller passes a freshly generated id; defaults to "".
    """
    print("Start monitor gold price!")
    # Let's mock gold price change.
    mocked_prices = (300, 400, 900, 500)
    final_index = len(mocked_prices) - 1
    for index, price in enumerate(mocked_prices):
        await asyncio.sleep(4)
        alert = f"the price for gold is {price}"
        function_call_partial_result_queue.push_result(
            partial_result_function_call_id,
            PartialResultFunctionCall(
                partial_result=alert,
                partial_result_function_call_id=partial_result_function_call_id,
                # Only the last reading tells the callback to stop polling.
                tool_response_complete=(index == final_index),
            ),
        )
        print(alert)
async def listen_to_results_until_tool_result_complete(
        partial_result_function_call_id: str) -> PartialResultFunctionCall:
    """Wait for and return the next partial result of an in-progress streaming tool call.

    Call this with the partial_result_function_call_id from a previous partial
    result to receive the next update from the same tool. The returned result has
    tool_response_complete set to True when it is the final update.

    Args:
        partial_result_function_call_id: Id carried by the previous partial result.
    """
    queues = function_call_partial_result_queue.queue
    if partial_result_function_call_id not in queues:
        # Same exception type as a plain dict lookup, but with an actionable message.
        raise KeyError(
            f"Unknown partial_result_function_call_id: {partial_result_function_call_id!r}")
    # Poll: producers only append to a plain list, so there is no event to await.
    # NOTE(review): once the final result for an id has been consumed nothing else
    # is pushed, so calling again for that id waits indefinitely.
    while not queues[partial_result_function_call_id]:
        await asyncio.sleep(0.1)
    return function_call_partial_result_queue.pop_result(partial_result_function_call_id)
# Strong references to fire-and-forget producer tasks. asyncio documents that the
# event loop keeps only weak references to tasks, so an unreferenced task can be
# garbage-collected before it finishes.
_background_tasks: set = set()


async def streaming_tool_caller(function_name: str,
                                partial_result_function_call_id: Optional[str] = None) -> Optional[Any]:
    """Start a streaming tool if needed and return its next partial result.

    Args:
        function_name: Name of the streaming tool. Only "monitor_gold_price" is
            supported; any other name returns None.
        partial_result_function_call_id: Id of an in-progress call. When empty or
            None, a new id is generated, its queue entry is created and the
            producer is started in the background (exactly once per id).

    Returns:
        The next PartialResultFunctionCall for the call, or None for an
        unsupported function_name.
    """
    if function_name != "monitor_gold_price":
        return None

    if not partial_result_function_call_id:
        # First call: allocate a unique id and start the producer for it.
        partial_result_function_call_id = str(uuid.uuid4())
        function_call_partial_result_queue.queue[partial_result_function_call_id] = []
        producer = asyncio.create_task(core_monitor_gold_price(partial_result_function_call_id))
        _background_tasks.add(producer)
        producer.add_done_callback(_background_tasks.discard)

    # Block until a result for this id is available and hand it back to the agent.
    return await listen_to_results_until_tool_result_complete(partial_result_function_call_id)
async def monitor_gold_price() -> None:
    """Monitor the price of gold.

    Starts the gold price stream and returns its first partial result.
    """
    # NOTE(review): annotated `-> None`, but this returns whatever
    # streaming_tool_caller returns (the first partial result). Left unchanged
    # because this function is registered as an agent tool and its signature
    # presumably feeds the tool declaration — confirm before changing.
    first_result = await streaming_tool_caller(function_name="monitor_gold_price")
    return first_result
# --- Define the Callback Function ---
def simple_before_model_modifier(callback_context: "CallbackContext",
                                 llm_request: "LlmRequest") -> Optional["LlmResponse"]:
    """Short-circuit the model call while a streaming tool still has partial results.

    Inspects the last part of the pending request. If it is a function response
    whose "result" is a PartialResultFunctionCall that is not yet complete, the
    model is skipped. Instead a synthetic model turn is returned, containing:

    * a thought part with the partial result, and
    * a function call to listen_to_results_until_tool_result_complete, which
      fetches the next update.

    Args:
        callback_context: Callback context; only ``agent_name`` is read (for logging).
        llm_request: The pending request; ``contents`` is inspected, never modified.

    Returns:
        A synthetic LlmResponse to skip the model call, or None to let the
        request proceed unchanged.
    """
    print(f"[Callback] Before model call for agent: {callback_context.agent_name}")
    print(f"[Callback] LLM Request: {str(llm_request)}")

    # Guard against an empty request or a last content without parts (parts may be None).
    if not llm_request.contents:
        return None
    last_content = llm_request.contents[-1]
    if not last_content.parts:
        return None
    last_part = last_content.parts[-1]
    print(f"[Callback] Last part: {last_part}")
    print(f"[Callback] Last part role: {last_content.role}")

    function_response = last_part.function_response
    if not function_response:
        return None

    response_value = function_response.response
    response_name = function_response.name
    print(f"[Callback] Last part function response: {function_response}")
    print(f"[Callback] Last part function response value: {response_value}")
    print(f"[Callback] Last part function response name: {response_name}")

    # Non-dict tool return values are expected under the "result" key
    # (presumably ADK's wrapping of tool results — confirm).
    partial = response_value.get("result") if response_value else None
    if partial is None or not isinstance(partial, PartialResultFunctionCall):
        return None
    print(f"[Callback] Last part function response partial result: {partial}")
    if partial.tool_response_complete:
        # Final result: let the model see it and answer normally.
        return None

    # Simulate a model turn that surfaces the update and asks for the next one.
    return LlmResponse(content=types.Content(
        role="model",
        parts=[
            types.Part(
                thought=True,
                text=f"Partial result update from tool ({response_name}): {partial.partial_result}"),
            types.Part(function_call=types.FunctionCall(
                name="listen_to_results_until_tool_result_complete",
                args={"partial_result_function_call_id": partial.partial_result_function_call_id},
            )),
        ],
    ))
# Tools exposed to the agent. listen_to_results_until_tool_result_complete is
# included because the before-model callback emits synthetic calls to it, which
# presumably must resolve to a registered tool.
tools = [monitor_stock_price, monitor_gold_price, listen_to_results_until_tool_result_complete]

# Thinking stays on unless THINKING_ENABLED is set to something other than "true" (case-insensitive).
thinking_enabled = os.getenv("THINKING_ENABLED", "True").lower() == "true"

agent_instruction = """
You are an agent that can monitor the stock price for a given stock symbol, and gold price.
"""

# Built-in planner with thoughts included (thinking_budget=500), only when thinking is enabled.
planner = (
    BuiltInPlanner(thinking_config=types.ThinkingConfig(include_thoughts=True, thinking_budget=500))
    if thinking_enabled
    else None
)

root_agent = Agent(
    name="PriceMonitorAgent",
    model=GEMINI_MODEL_NAME,
    description="A dummy agent that can monitor the stock price for a given stock symbol, and gold price.",
    instruction=agent_instruction,
    tools=tools,
    # Intercepts incomplete partial tool results before each model call.
    before_model_callback=simple_before_model_modifier,
    planner=planner,
)
if __name__ == "__main__":
    # The original called uvicorn.run(app, ...) but `app` is never defined in this
    # module, so running it raised NameError. Build ADK's FastAPI app instead.
    # NOTE(review): get_fast_api_app's keyword names have changed across ADK
    # releases (e.g. `agent_dir` vs `agents_dir`), and agents_dir must be the
    # directory that CONTAINS the agent package. The path below assumes this file
    # lives inside that package — confirm both against the installed ADK version
    # and the project layout.
    from google.adk.cli.fast_api import get_fast_api_app

    app = get_fast_api_app(
        agents_dir=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        web=True,
    )
    uvicorn.run(app, host="0.0.0.0", port=8000)
Advertisement
Add Comment
Please, Sign In to add comment