"""
title: mcts
author: av
author_url: https://github.com/av
description: mcts - Monte Carlo Tree Search
version: 0.0.5
"""

import logging
import random
import math
import asyncio
import json
import re

from typing import (
  List,
  Optional,
  AsyncGenerator,
  Callable,
  Awaitable,
  Generator,
  Iterator,
)
from open_webui.constants import TASKS
from open_webui.apps.openai import main as ollama

# ==============================================================================

name = "mcts"
default_max_children = 2
default_exploration_weight = 1.414
default_max_iterations = 2
default_max_simulations = 2
default_thoughts = 2
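
# With these defaults each request runs two MCTS iterations of two simulations
# each, and every expansion adds 2-3 children (see MCTS.expand), so the
# resulting tree stays small enough to render inline as a Mermaid diagram.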

# ==============================================================================

thoughts_prompt = """
<instruction>
Give a suggestion on how this answer can be improved.
WRITE ONLY AN IMPROVEMENT SUGGESTION AND NOTHING ELSE.
YOUR REPLY SHOULD BE A SINGLE SENTENCE.
</instruction>

<question>
{question}
</question>

<draft>
{answer}
</draft>
""".strip()

eval_answer_prompt = """
Given the following text:
"{answer}"

How well does it answer this question:
"{question}"

Rate the answer from 1 to 10, where 1 is completely wrong or irrelevant and 10 is a perfect answer.
Reply with a single number between 1 and 10 only. Do not write anything else, it will be discarded.
THINK CAREFULLY AND USE BEST PRACTICES.
""".strip()

analyze_prompt = """
Iteration Analysis:

Original question: {question}
Best answer found: {best_answer}
Best score achieved: {best_score}

Analyze this iteration of the thought process. Consider the following:
1. What aspects of the best answer made it successful?
2. What patterns or approaches led to higher-scoring thoughts?
3. Were there any common pitfalls or irrelevant tangents in lower-scoring thoughts?
4. How can the thought generation process be improved for the next iteration?

Provide a concise analysis and suggest one specific improvement strategy for the next iteration.
""".strip()

update_prompt = """
<instruction>
Your task is to read the question and the answer below, then analyze the given critique.
When you are done, think about how the answer can be improved based on the critique.
WRITE A REVISED ANSWER THAT ADDRESSES THE CRITIQUE. DO NOT WRITE ANYTHING ELSE.
</instruction>
<question>
{question}
</question>
<draft>
{answer}
</draft>
<critique>
{improvements}
</critique>
""".strip()

initial_prompt = """
<instruction>
Answer the question below. Do not pay attention to unexpected casing, punctuation, or accent marks.
</instruction>

<question>
{question}
</question>
""".strip()
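
# The prompts above are str.format templates; stream_prompt_completion fills
# the {question}, {answer}, {improvements}, ... placeholders before sending
# the result to the model as a single user message.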

# ==============================================================================


def setup_logger():
  logger = logging.getLogger(__name__)
  if not logger.handlers:
    logger.setLevel(logging.DEBUG)
    handler = logging.StreamHandler()
    handler.set_name(name)
    formatter = logging.Formatter(
      "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.propagate = False
  return logger


logger = setup_logger()

# ==============================================================================

mods = [
  "capitalize",
  "diacritic",
  "leetspeak",
  "remove_vowel",
]

def modify_text(text, percentage):
  """Randomly perturb roughly percentage% of the characters in text.

  Returns the modified text and a mapping of modified words back to their
  originals so the changes can be undone with replace_with_mapping.
  """
  if not text:
    return "", {}    # Return empty string and empty mapping if input is empty

  if not 0 <= percentage <= 100:
    raise ValueError("Percentage must be between 0 and 100")

  if percentage == 0:
    # Nothing to modify; without this guard, max(1, ...) below would still
    # touch one character even at 0%
    return text, {}

  words = text.split()
  chars = list(text)
  num_chars_to_modify = max(1, int(len(chars) * (percentage / 100)))
  indices_to_modify = random.sample(range(len(chars)), num_chars_to_modify)
  word_mapping = {}

  for idx in indices_to_modify:
    modification = random.choice(mods)

    # Find the word that contains the current character
    current_length = 0
    for word_idx, word in enumerate(words):
      if current_length <= idx < current_length + len(word):
        original_word = word
        word_start_idx = current_length
        break
      current_length += len(word) + 1    # +1 for the space
    else:
      # If we're here, we're likely dealing with a space or the last character
      continue

    if modification == "capitalize":
      chars[idx] = chars[idx].swapcase()
    elif modification == "diacritic":
      if chars[idx].isalpha():
        diacritics = ["̀", "́", "̂", "̃", "̈", "̄", "̆", "̇", "̊", "̋"]
        chars[idx] = chars[idx] + random.choice(diacritics)
    elif modification == "leetspeak":
      leetspeak_map = {
        "a": "4",
        "e": "3",
        "i": "1",
        "o": "0",
        "s": "5",
        "t": "7",
        "b": "8",
        "g": "9",
        "l": "1",
      }
      chars[idx] = leetspeak_map.get(chars[idx].lower(), chars[idx])
    elif modification == "remove_vowel":
      if chars[idx].lower() in "aeiou":
        chars[idx] = ""

    modified_word = "".join(
      chars[word_start_idx:word_start_idx + len(original_word)]
    )

    if modified_word != original_word:
      # Clean up both the modified word and the original word
      cleaned_modified_word = modified_word.rstrip(".,!?")
      cleaned_original_word = original_word.rstrip(".,!?")
      word_mapping[cleaned_modified_word] = cleaned_original_word

  modified_text = "".join(chars)
  return modified_text, word_mapping
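
# Illustrative only (actual output depends on the RNG), e.g.:
#   modify_text("Hello there", 20)
#   -> ("H3llo thEre", {"H3llo": "Hello", "thEre": "there"})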


def replace_with_mapping(text, mapping):
  for key, value in mapping.items():
    text = text.replace(key, value)
  return text


# ==============================================================================


def escape_mermaid(text):
  return text.replace('"', "&quot;").replace("(", "&#40;").replace(")", "&#41;")


class Node:
  id: str
  content: str
  parent: Optional["Node"]
  max_children: int
  children: List["Node"]
  visits: int
  value: float

  def __init__(self, **kwargs):
    self.id = "".join(random.choices("abcdefghijklmnopqrstuvwxyz", k=4))
    self.content = kwargs.get("content")
    self.parent = kwargs.get("parent")
    self.exploration_weight = kwargs.get(
      "exploration_weight", default_exploration_weight
    )
    self.max_children = kwargs.get("max_children", default_max_children)
    self.children = []
    self.visits = 0
    self.value = 0

  def add_child(self, child: "Node"):
    child.parent = self
    self.children.append(child)
    return child

  def fully_expanded(self):
    return len(self.children) >= self.max_children
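
  # UCT (Upper Confidence bound applied to Trees):
  #   value / visits                      -> exploitation (average score so far)
  #   sqrt(ln(parent visits) / visits)    -> exploration (bonus for rarely tried nodes)
  # epsilon keeps the division defined for unvisited nodes.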
  def uct_value(self):
    epsilon = 1e-6
    exploitation = self.value / (self.visits + epsilon)
    exploration = self.exploration_weight * math.sqrt(
      math.log(self.parent.visits) / (self.visits + epsilon)
    )
    return exploitation + exploration

  def mermaid(self, offset=0, selected=None):
    padding = " " * offset
    msg = f"{padding}{self.id}({self.id}:{self.visits} - {escape_mermaid(self.content[:25])})\n"

    if selected == self.id:
      msg += f"{padding}style {self.id} stroke:#0ff\n"

    for child in self.children:
      msg += child.mermaid(offset + 4, selected)
      msg += f"{padding}{self.id} --> {child.id}\n"

    return msg

  def best_child(self):
    # Follow the most-visited child all the way down to a leaf
    if not self.children:
      return self

    return max(self.children, key=lambda child: child.visits).best_child()

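
# Sketch of the rendered tree (node ids are random 4-letter strings), e.g.:
#   graph LR
#   abcd(abcd:3 - How do I...)
#       efgh(efgh:2 - You can...)
#   abcd --> efgh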

class MCTS:
  question: str
  root: Node
  llm: "Pipe"
  selected: Optional[Node]
  exploration_weight: float

  def __init__(self, **kwargs):
    self.question = kwargs.get("question")
    self.root = kwargs.get("root")
    self.llm = kwargs.get("llm")
    self.selected = None
    self.exploration_weight = kwargs.get(
      "exploration_weight", default_exploration_weight
    )

  async def select(self):
    logger.debug("Selecting node...")
    node = self.root
    while node.children:
      node = self.uct_select(node)
    return node

  async def expand(self, node):
    logger.debug(f"Expanding node {node.id}...")
    await self.llm.progress(f"Thinking about {node.id}...")

    # randint is inclusive, so this adds default_thoughts or default_thoughts + 1 children
    for _ in range(random.randint(default_thoughts, default_thoughts + 1)):
      await self.llm.emit_replace(self.mermaid(node))
      await self.llm.emit_message("Thought: ")
      thought = await self.llm.generate_thought(node.content)
      await self.llm.emit_message("\n\n---\n\nSolution:\n")

      new_content = await self.llm.update_approach(node.content, thought)
      child = Node(content=new_content, parent=node)
      node.add_child(child)

    return random.choice(node.children)

  async def simulate(self, node):
    logger.debug(f"Simulating node {node.id}...")
    await self.llm.progress(f"Thinking about {node.id}...")
    await self.llm.emit_replace(self.mermaid())

    return await self.llm.evaluate_answer(node.content)

  def backpropagate(self, node, score):
    logger.debug(f"Backpropagating from {node.id}...")
    while node:
      node.visits += 1
      node.value += score
      node = node.parent

  def uct_select(self, node):
    logger.debug(f"Selecting uct {node.id}...")
    return max(node.children, key=lambda child: child.uct_value())

  def best_child(self):
    return self.root.best_child()

  async def search(self, num_simulations):
    logger.debug("Starting search...")

    for _ in range(num_simulations):
      # Classic MCTS loop: select a leaf, expand it, score the result,
      # then propagate the score back up to the root
      leaf = await self.select()
      self.selected = leaf
      if not leaf.fully_expanded():
        leaf = await self.expand(leaf)
      score = await self.simulate(leaf)
      self.backpropagate(leaf, score)

    return self.selected

  def mermaid(self, selected=None):
    selected_id = selected.id if selected else (
      self.selected.id if self.selected else None
    )
    return f"""
```mermaid
graph LR
{self.root.mermaid(0, selected_id)}
```
"""

# ==============================================================================

EventEmitter = Callable[[dict], Awaitable[None]]
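
# Open WebUI event protocol, as used by the emit_* helpers below:
#   {"type": "message", "data": {"content": ...}}  appends to the streamed reply
#   {"type": "replace", "data": {"content": ...}}  replaces the reply so far
#   {"type": "status",  "data": {...}}             updates the progress indicator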


class Pipe:
  __current_event_emitter__: EventEmitter
  __current_node__: Node
  __question__: str
  __model__: str

  def __init__(self):
    self.type = "manifold"

  def pipes(self) -> list[dict[str, str]]:
    ollama.get_all_models()
    models = ollama.app.state.MODELS

    out = [
      {
        "id": f"{name}-{key}",
        "name": f"{name} {models[key]['name']}"
      } for key in models
    ]
    logger.debug(f"Available models: {out}")

    return out

  def resolve_model(self, body: dict) -> str:
    # Incoming ids look like "<function id>.<name>-<model>"; drop the function
    # prefix and the pipe name to recover the underlying model id
    model_id = body.get("model")
    without_pipe = ".".join(model_id.split(".")[1:])
    return without_pipe.replace(f"{name}-", "")

  def resolve_question(self, body: dict) -> str:
    return body.get("messages")[-1].get("content").strip()

  async def pipe(
    self,
    body: dict,
    __user__: dict,
    __event_emitter__=None,
    __task__=None,
    __model__=None,
  ) -> str | Generator | Iterator:
    model = self.resolve_model(body)
    base_question = self.resolve_question(body)

    if __task__ == TASKS.TITLE_GENERATION:
      content = await self.get_completion(model, body.get("messages"))
      return f"{name}: {content}"

    logger.debug(f"Pipe {name} received: {body}")
    question, mapping = modify_text(base_question, 0)
    logger.debug(f"Question: {question}")

    # TODO: concurrency
    self.__model__ = model
    self.__question__ = base_question
    self.__current_event_emitter__ = __event_emitter__

    best_answer = None
    best_score = -float("inf")

    await self.progress("Preparing initial thoughts...")
    initial_reply = await self.stream_prompt_completion(
      initial_prompt, question=question
    )

    root = Node(content=initial_reply)
    mcts = MCTS(root=root, llm=self)

    logger.debug("Starting MCTS...")
    for i in range(default_max_iterations):
      logger.debug(f"Iteration {i + 1}/{default_max_iterations}...")

      await mcts.search(default_max_simulations)
      logger.debug(mcts.mermaid())

      best_child = mcts.best_child()
      score = await self.evaluate_answer(best_child.content)

      if score > best_score:
        best_score = score
        best_answer = best_child.content

    await self.emit_replace(mcts.mermaid(best_child))
    await self.emit_message(f"{best_answer}")
    await asyncio.sleep(0.2)
    await self.done()

    return ""

  async def progress(
    self,
    message: str,
  ):
    logger.debug(f"Progress: {message}")
    await self.emit_status("info", message, False)

  async def done(self):
    await self.emit_status("info", "Fin.", True)

  async def emit_message(self, message: str):
    await self.__current_event_emitter__(
      {
        "type": "message",
        "data": {
          "content": message
        }
      }
    )

  async def emit_replace(self, message: str):
    await self.__current_event_emitter__(
      {
        "type": "replace",
        "data": {
          "content": message
        }
      }
    )

  async def emit_status(self, level: str, message: str, done: bool):
    await self.__current_event_emitter__(
      {
        "type": "status",
        "data": {
          "status": "complete" if done else "in_progress",
          "level": level,
          "description": message,
          "done": done,
        },
      }
    )

  async def get_streaming_completion(
    self,
    model: str,
    messages,
  ) -> AsyncGenerator[str, None]:
    response = await ollama.generate_chat_completion(
      {
        "model": model,
        "messages": messages,
        "stream": True
      }
    )

    async for chunk in response.body_iterator:
      for part in self.get_chunk_content(chunk):
        yield part

  async def get_message_completion(self, model: str, content):
    async for chunk in self.get_streaming_completion(
      model, [{
        "role": "user",
        "content": content
      }]
    ):
      yield chunk

  async def get_completion(self, model: str, messages):
    response = await ollama.generate_chat_completion(
      {
        "model": model,
        "messages": messages,
        "stream": False
      }
    )

    return self.get_response_content(response)

  async def stream_prompt_completion(self, prompt, **format_args):
    complete = ""
    async for chunk in self.get_message_completion(
      self.__model__,
      prompt.format(**format_args),
    ):
      complete += chunk
      await self.emit_message(chunk)
    return complete
  async def generate_thought(self, answer):
    return await self.stream_prompt_completion(
      thoughts_prompt, answer=answer, question=self.__question__
    )

  async def analyze_iteration(self, best_answer, best_score):
    return await self.stream_prompt_completion(
      analyze_prompt,
      question=self.__question__,
      best_answer=best_answer,
      best_score=best_score
    )

  async def update_approach(self, answer, improvements):
    return await self.stream_prompt_completion(
      update_prompt,
      question=self.__question__,
      answer=answer,
      improvements=improvements
    )

  async def evaluate_answer(self, answer):
    result = await self.stream_prompt_completion(
      eval_answer_prompt,
      answer=answer,
      question=self.__question__,
    )
    try:
      # Take the first number in the reply as the 1-10 score
      score = re.search(r"\d+", result).group()
      return int(score)
    except AttributeError:
      logger.error(f"AnswerEval: unable to parse \"{result[:100]}\"")
      return 0

  def get_response_content(self, response):
    try:
      return response["choices"][0]["message"]["content"]
    except (KeyError, IndexError):
      logger.error(
        f"ResponseError: unable to extract content from \"{str(response)[:100]}\""
      )
      return ""

  def get_chunk_content(self, chunk):
    # Streamed chunks arrive as SSE lines: "data: {json}", ending with "data: [DONE]"
    chunk_str = chunk.decode("utf-8")
    if chunk_str.startswith("data: "):
      chunk_str = chunk_str[6:]

    chunk_str = chunk_str.strip()

    if chunk_str == "[DONE]" or not chunk_str:
      return

    try:
      chunk_data = json.loads(chunk_str)
      if "choices" in chunk_data and len(chunk_data["choices"]) > 0:
        delta = chunk_data["choices"][0].get("delta", {})
        if "content" in delta:
          yield delta["content"]
    except json.JSONDecodeError:
      logger.error(f"ChunkDecodeError: unable to parse \"{chunk_str[:100]}\"")
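
# ==============================================================================
# Open WebUI treats this file as a "manifold" pipe: pipes() advertises one
# virtual "mcts <model>" entry per available model, and pipe() handles each
# chat request by running the MCTS loop above against the selected model.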