Guest User

Untitled

a guest
Jun 5th, 2025
21
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 13.21 KB | None | 0 0
  1. from google.adk.agents import Agent
  2. from config import GEMINI_MODEL_NAME
  3. from google.adk.planners import BuiltInPlanner
  4. from google.genai import types
  5. import os
  6. import asyncio
  7. from typing import AsyncGenerator
  8. import uuid
  9. from pydantic import BaseModel
  10. from google.adk.agents.callback_context import CallbackContext
  11. from google.adk.models import LlmResponse, LlmRequest
  12. from typing import Dict, List, Optional, Any
  13. import uvicorn
  14.  
  15.  
class PartialResultFunctionCall(BaseModel):
  """One (possibly partial) result emitted by a streaming tool call.

  Instances are pushed into PartialResultFunctionCallQueue and popped one at a
  time; the agent keeps re-polling until tool_response_complete is True.
  """
  # Payload produced so far (a price-alert string in this file).
  partial_result: Any = ""
  # Correlation id tying this result to one logical streaming tool call.
  partial_result_function_call_id: str
  # True only on the final result; tells the consumer to stop polling.
  tool_response_complete: bool = False
  20.  
  21.  
  22. class PartialResultFunctionCallQueue:
  23.   """
  24.    contains the results from the function calls
  25.    """
  26.   _instance = None
  27.  
  28.   @staticmethod
  29.   def instance():
  30.     if PartialResultFunctionCallQueue._instance is None:
  31.       PartialResultFunctionCallQueue._instance = PartialResultFunctionCallQueue()
  32.     return PartialResultFunctionCallQueue._instance
  33.  
  34.   def __init__(self):
  35.     if PartialResultFunctionCallQueue._instance is not None:
  36.       raise Exception("This class is a singleton!")
  37.     else:
  38.       PartialResultFunctionCallQueue._instance = self
  39.       self.queue = {}
  40.  
  41.   def push_result(self, partial_result_function_call_id: str,
  42.                   result: PartialResultFunctionCall) -> None:
  43.     if not self.queue.get(partial_result_function_call_id):
  44.       self.queue[partial_result_function_call_id] = []
  45.     self.queue[partial_result_function_call_id].append(result)
  46.  
  47.   def pop_result(self, partial_result_function_call_id: str) -> Optional[PartialResultFunctionCall]:
  48.     if partial_result_function_call_id in self.queue and self.queue[partial_result_function_call_id]:
  49.       return self.queue[partial_result_function_call_id].pop(0)
  50.     return None
  51.  
  52.  
  53. function_call_partial_result_queue = PartialResultFunctionCallQueue.instance()
  54.  
  55.  
  56. async def monitor_stock_price(stock_symbol: str) -> AsyncGenerator[str, None]:
  57.   """This function will monitor the price for the given stock_symbol in a continuous, streaming and asynchronously way."""
  58.   print(f"Start monitor stock price for {stock_symbol}!")
  59.  
  60.   # Let's mock stock price change.
  61.   await asyncio.sleep(4)
  62.   price_alert1 = f"the price for {stock_symbol} is 300"
  63.   yield price_alert1
  64.   print(price_alert1)
  65.  
  66.   await asyncio.sleep(4)
  67.   price_alert1 = f"the price for {stock_symbol} is 400"
  68.   yield price_alert1
  69.   print(price_alert1)
  70.  
  71.   await asyncio.sleep(4)
  72.   price_alert1 = f"the price for {stock_symbol} is 900"
  73.   yield price_alert1
  74.   print(price_alert1)
  75.  
  76.   await asyncio.sleep(4)
  77.   price_alert1 = f"the price for {stock_symbol} is 500"
  78.   yield price_alert1
  79.   print(price_alert1)
  80.  
  81.  
  82. async def core_monitor_gold_price(partial_result_function_call_id: str = "") -> None:
  83.   """This function will monitor the price for gold in a continuous, streaming and asynchronously way.
  84.  
  85.  Args:
  86.    partial_result_function_call_id: Optional, required only function is called subsequent times. For the first call, it can be empty.
  87.  """
  88.   print("Start monitor gold price!")
  89.  
  90.   # Let's mock gold price change.
  91.   await asyncio.sleep(4)
  92.   price_alert1 = "the price for gold is 300"
  93.   function_call_partial_result_queue.push_result(
  94.       partial_result_function_call_id,
  95.       PartialResultFunctionCall(partial_result=price_alert1,
  96.                                 partial_result_function_call_id=partial_result_function_call_id,
  97.                                 tool_response_complete=False))
  98.   print(price_alert1)
  99.  
  100.   await asyncio.sleep(4)
  101.   price_alert1 = "the price for gold is 400"
  102.   function_call_partial_result_queue.push_result(
  103.       partial_result_function_call_id,
  104.       PartialResultFunctionCall(partial_result=price_alert1,
  105.                                 partial_result_function_call_id=partial_result_function_call_id,
  106.                                 tool_response_complete=False))
  107.   print(price_alert1)
  108.  
  109.   await asyncio.sleep(4)
  110.   price_alert1 = "the price for gold is 900"
  111.   function_call_partial_result_queue.push_result(
  112.       partial_result_function_call_id,
  113.       PartialResultFunctionCall(partial_result=price_alert1,
  114.                                 partial_result_function_call_id=partial_result_function_call_id,
  115.                                 tool_response_complete=False))
  116.   print(price_alert1)
  117.  
  118.   await asyncio.sleep(4)
  119.   price_alert1 = "the price for gold is 500"
  120.   function_call_partial_result_queue.push_result(
  121.       partial_result_function_call_id,
  122.       PartialResultFunctionCall(partial_result=price_alert1,
  123.                                 partial_result_function_call_id=partial_result_function_call_id,
  124.                                 tool_response_complete=True))
  125.   print(price_alert1)
  126.  
  127.  
  128. async def listen_to_results_until_tool_result_complete(
  129.     partial_result_function_call_id: str) -> PartialResultFunctionCall:
  130.   """
  131.    """
  132.   while not function_call_partial_result_queue.queue[partial_result_function_call_id]:
  133.     await asyncio.sleep(0.1)
  134.   return function_call_partial_result_queue.pop_result(partial_result_function_call_id)
  135.  
  136.  
  137. async def streaming_tool_caller(function_name: str,
  138.                                 partial_result_function_call_id: str = None) -> Optional[Any]:
  139.   """Call the streaming function.
  140.  
  141.  Args:
  142.    function_name: The name of the streaming function to call.
  143.  """
  144.   if function_name == "monitor_gold_price":
  145.     # before the first call generate a unique partial_result_function_call_id
  146.     if not partial_result_function_call_id:
  147.       partial_result_function_call_id = str(uuid.uuid4())
  148.       function_call_partial_result_queue.queue[partial_result_function_call_id] = []
  149.     # return each result along with metadata (whether to call the function again or not)
  150.     # the results from the function should be kept in a result queue and popped from it to return to the agent
  151.     # if there are more than one results in the queue, return all of them, leaving the queue empty each time
  152.     # if more than one results are used, then the tool_response_complete of the last result should be set to used to figure out whether to call the function again
  153.  
  154.     # call core_monitor_gold_price in background thread
  155.     asyncio.create_task(core_monitor_gold_price(partial_result_function_call_id))
  156.     # wait until the queue is non empty, and pop the result from queue as soon as the result is available
  157.     return await listen_to_results_until_tool_result_complete(partial_result_function_call_id)
  158.  
  159.  
async def monitor_gold_price() -> Optional[Any]:
  """
   This function will monitor the price for gold.
   """
  # Delegates to the streaming dispatcher; despite the original "-> None"
  # annotation this returns the first PartialResultFunctionCall so the agent
  # can keep polling (annotation corrected accordingly).
  return await streaming_tool_caller("monitor_gold_price")
  165.  
  166.  
  167. # --- Define the Callback Function ---
  168. def simple_before_model_modifier(callback_context: CallbackContext,
  169.                                  llm_request: LlmRequest) -> Optional[LlmResponse]:
  170.   """Inspects/modifies the LLM request or skips the call."""
  171.   agent_name = callback_context.agent_name
  172.   print(f"[Callback] Before model call for agent: {agent_name}")
  173.  
  174.   # print llm request which is a pydantic model to string
  175.   print(f"[Callback] LLM Request: {str(llm_request)}")
  176.  
  177.   # Inspect the last user message in the request contents
  178.   print(f"[Callback] Last part: {llm_request.contents[-1].parts[-1]}")
  179.   print(f"[Callback] Last part role: {llm_request.contents[-1].role}")
  180.  
  181.   # if last part is a function response, and it's a partial response, skip the model call
  182.   if llm_request.contents[-1].parts[-1].function_response:
  183.     print(
  184.         f"[Callback] Last part function response: {llm_request.contents[-1].parts[-1].function_response}"
  185.     )
  186.     print(
  187.         f"[Callback] Last part function response name: {llm_request.contents[-1].parts[-1].function_response.name}"
  188.     )
  189.     function_response = llm_request.contents[-1].parts[-1].function_response
  190.     if function_response:
  191.       function_response_value = function_response.response
  192.       function_response_name = function_response.name
  193.       print(f"[Callback] Last part function response value: {function_response_value}")
  194.       print(f"[Callback] Last part function response name: {function_response_name}")
  195.       if function_response_value and \
  196.       'result' in function_response_value and \
  197.       isinstance(function_response_value['result'], PartialResultFunctionCall):
  198.         function_response_value = function_response_value['result']
  199.         print(f"[Callback] Last part function response partial result: {function_response_value}")
  200.         if not function_response_value.tool_response_complete:
  201.           # simulate LLMResult and return new tool call using function_response_name and partial_result_function_call_id
  202.           return LlmResponse(content=types.Content(
  203.               role="model",
  204.               parts=[
  205.                   types.Part(thought=True,
  206.                              text="Partial result update from tool (" + function_response_name +
  207.                              "): " + str(function_response_value.partial_result)),
  208.                   types.Part(function_call=types.FunctionCall(
  209.                       name="listen_to_results_until_tool_result_complete",
  210.                       args={
  211.                           "partial_result_function_call_id":
  212.                               function_response_value.partial_result_function_call_id
  213.                       }))
  214.               ]))
  215.   return None
  216.  
  217.   # if function_response and function_response.response:
  218.   #     if not function_response.response.tool_response_complete:
  219.   # result = function_response.response['result']
  220.   # for res in result:
  221.   #     print(f"[Callback] Last part function response response result: {res}")
  222.   #     print(f"[Callback] Last part function response response result type: {type(res)}")
  223.   #     print(f"[Callback] Last part function response response result partial_result: {res.partial_result}")
  224.   #     print(f"[Callback] Last part function response response result partial_result_function_call_id: {res.partial_result_function_call_id}")
  225.   #     print(f"[Callback] Last part function response response result tool_response_complete: {res.tool_response_complete}")
  226.   #     if not res.tool_response_complete:
  227.   #         print("[Callback] Last part function response response result is not complete")
  228.   #     else:
  229.   #         print("[Callback] Last part function response response result is complete")
  230.   # last_model_message = ""
  231.   # if llm_request.contents and llm_request.contents[-1].role == 'model':
  232.   #     if llm_request.contents[-1].parts and \
  233.   #         llm_request.contents[-1].parts[-1].function_response:
  234.   #         last_model_message = llm_request.contents[-1].parts[-1].function_response
  235.   #         print(f"[Callback] Function response: '{last_model_message}'")
  236.   #     else:
  237.   #         last_model_message = llm_request.contents[-1].parts[0].text
  238.   #         print(f"[Callback] Inspecting last model message: '{last_model_message}'")
  239.  
  240.   # --- Modification Example ---
  241.   # Add a prefix to the system instruction
  242.   # original_instruction = llm_request.config.system_instruction or types.Content(role="system", parts=[])
  243.   # prefix = "[Modified by Callback] "
  244.   # # Ensure system_instruction is Content and parts list exists
  245.   # if not isinstance(original_instruction, types.Content):
  246.   #      # Handle case where it might be a string (though config expects Content)
  247.   #      original_instruction = types.Content(role="system", parts=[types.Part(text=str(original_instruction))])
  248.   # if not original_instruction.parts:
  249.   #     original_instruction.parts.append(types.Part(text="")) # Add an empty part if none exist
  250.  
  251.   # Modify the text of the first part
  252.   # modified_text = prefix + (original_instruction.parts[0].text or "")
  253.   # original_instruction.parts[0].text = modified_text
  254.   # llm_request.config.system_instruction = original_instruction
  255.   # print(f"[Callback] Modified system instruction to: '{modified_text}'")
  256.  
  257.   # return None
  258.   # --- Skip Example ---
  259.   # Check if the last user message contains "BLOCK"
  260.   # if "BLOCK" in last_user_message.upper():
  261.   #     print("[Callback] 'BLOCK' keyword found. Skipping LLM call.")
  262.   #     # Return an LlmResponse to skip the actual LLM call
  263.   #     return LlmResponse(
  264.   #         content=types.Content(
  265.   #             role="model",
  266.   #             parts=[types.Part(text="LLM call was blocked by before_model_callback.")],
  267.   #         )
  268.   #     )
  269.   # else:
  270.   #     print("[Callback] Proceeding with LLM call.")
  271.   #     # Return None to allow the (modified) request to go to the LLM
  272.   #     return None
  273.  
  274.  
# Tools exposed to the model. listen_to_results_until_tool_result_complete is
# included so the model can poll an in-progress streaming tool call.
tools = [monitor_stock_price, monitor_gold_price, listen_to_results_until_tool_result_complete]

# Planner stays None unless thinking is enabled below.
planner = None

# THINKING_ENABLED env var toggles the built-in planner; defaults to enabled.
thinking_enabled = os.getenv("THINKING_ENABLED", "True").lower() == "true"
agent_instruction = """
You are an agent that can monitor the stock price for a given stock symbol, and gold price.
"""

if thinking_enabled:
  # include_thoughts surfaces the model's reasoning; budget caps thinking tokens.
  planner = BuiltInPlanner(
      thinking_config=types.ThinkingConfig(include_thoughts=True, thinking_budget=500))

# Root ADK agent; the before_model_callback intercepts partial tool results
# and keeps polling instead of calling the model.
root_agent = Agent(
    name="PriceMonitorAgent",
    model=GEMINI_MODEL_NAME,
    description=(
        "A dummy agent that can monitor the stock price for a given stock symbol, and gold price."),
    instruction=(agent_instruction),
    tools=tools,
    before_model_callback=simple_before_model_modifier,
    planner=planner)
  297.  
if __name__ == "__main__":
  # NOTE(review): `app` is never defined or imported in this file, so running
  # the module directly raises NameError. Confirm the intended ASGI app
  # (e.g. one built from root_agent via the ADK FastAPI helper) and bring it
  # into scope before relying on this entry point.
  uvicorn.run(app, host="0.0.0.0", port=8000)
  300.  
Advertisement
Add Comment
Please, Sign In to add comment