Guest User

llamaindex QA generate

a guest
Nov 10th, 2023
1,015
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.67 KB | None | 0 0
  1. import os
  2. import re
  3. from pathlib import Path
  4. import openai
  5. from llama_index import ServiceContext
  6. from llama_index.llms import OpenAI
  7. from llama_index import VectorStoreIndex
  8. from pathlib import Path
  9. from llama_hub.file.pdf.base import PDFReader
  10. from llama_hub.file.unstructured.base import UnstructuredReader
  11. from llama_hub.file.pymu_pdf.base import PyMuPDFReader
  12. from llama_index import Document
  13. from llama_index.callbacks import CallbackManager
  14. from llama_index.evaluation import DatasetGenerator
  15. from llama_index.indices.list import SummaryIndex
  16. from llama_index.node_parser import SimpleNodeParser
  17. # try evaluation modules
  18. from llama_index.evaluation import RelevancyEvaluator, FaithfulnessEvaluator
  19. from llama_index import PromptTemplate
  20. import json
  21. from copy import deepcopy
  22. import random
  23. from llama_index import ServiceContext
  24. from llama_index.llms import ChatMessage
  25. from langchain.text_splitter import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
  26. from langchain.document_loaders import UnstructuredMarkdownLoader, DirectoryLoader
  27. import tiktoken
  28. from llama_index.text_splitter import SentenceSplitter
  29.  
# --- OpenAI-compatible endpoint configuration ---
# NOTE(review): the key below is a placeholder; real keys must never be
# hard-coded in source — load them from the environment or a secret store.
os.environ["OPENAI_API_KEY"] = "sk-111111111111111111111111111111111111111111111111"
# Point the OpenAI client at a local OpenAI-compatible server on port 5001.
os.environ["OPENAI_API_BASE"] = "http://localhost:5001/v1"
openai.api_key = os.environ["OPENAI_API_KEY"]
openai.api_base = os.environ["OPENAI_API_BASE"]


# Module-level accumulators. Only `docs` is used below (and it is rebuilt
# from the input file before use); the other three are never read.
docs = []
cleaned_text = []
doc_text_cleaned = []
docs_raw = []
  40.  
  41. def cleanup_text(input_text):
  42.     # Remove all double brackets and characters within
  43.     input_text = re.sub(r'\[.*?\]', '', input_text)
  44.     # Remove all lines starting with #
  45.     input_text = re.sub(r'^#.*', '', input_text, flags=re.MULTILINE)
  46.     # Remove all asterisks
  47.     input_text = re.sub(r'\*', '', input_text)
  48.     # Remove all empty lines
  49.     input_text = re.sub(r'^\s*$', '', input_text, flags=re.MULTILINE)
  50.     # Remove all commas
  51.     input_text = input_text.replace(',', '')
  52.     # Remove all double spaces
  53.     input_text = re.sub(r' +', ' ', input_text)
  54.     # Replace newline characters with spaces
  55.     input_text = input_text.replace('\n', ' ')
  56.     return input_text
  57.  
  58.  
  59. with open('output/Test.txt', 'r') as docs0:
  60.     doc_text = docs0.read()
  61.  
  62. metadata = {"file_name": "Test.pdf"}
  63. docs = [Document(text=doc_text, metadata=metadata)]
  64.  
  65. #print(docs[0].page_content())
  66.  
  67. callback_manager = CallbackManager([])
  68.  
  69. model = OpenAI(model="gpt-3.5-turbo", temperature=0.9),
  70.  
# Service context used for question generation and answering below.
service_context = ServiceContext.from_defaults(
    llm=OpenAI(model="gpt-3.5-turbo", temperature=0.9),
    callback_manager=callback_manager,
)

# NOTE(review): identical to `service_context` above — the "_large" name
# presumably intended a larger-context model; confirm and adjust if so.
# It is unused in the visible code.
service_context_large = ServiceContext.from_defaults(
    llm=OpenAI(model="gpt-3.5-turbo", temperature=0.9),
    callback_manager=callback_manager,
)
  80.  
  81. text_splitter = SentenceSplitter(
  82.   separator="\n\n",
  83.   chunk_size=512,
  84.   chunk_overlap=0,
  85.   paragraph_separator="\n\n\n",
  86.   secondary_chunking_regex="[^,.;。]+[,.;。]?",
  87.   tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
  88. )
  89.  
  90. node_parser = SimpleNodeParser.from_defaults(text_splitter=text_splitter)
  91. nodes = node_parser.get_nodes_from_documents(docs)
  92.  
  93. num_questions_per_chunk = 20
  94. question_gen_query = (
  95.     "You are a Teacher/ Professor. Your task is to setup "
  96.     "a quiz/examination. Using the provided context, "
  97.     f"formulate {num_questions_per_chunk} that captures an important fact from the "
  98.     "context. \n"
  99.     "You MUST obey the following criteria:\n"
  100.     "- Restrict the question to the context information provided.\n"
  101.     "- Do NOT create a question that cannot be answered from the context.\n"
  102.     "- Phrase the question so that it does NOT refer to specific context. "
  103.     'For instance, do NOT put phrases like "given provided context" or "in this work" in the question, '
  104.     "because if the question is asked elsewhere it wouldn't be provided specific context. Replace these terms "
  105.     "with specific details.\n"
  106.     "BAD questions:\n"
  107.     "What did the author do in his childhood\n"
  108.     "What were the main findings in this report\n\n"
  109.     "GOOD questions:\n"
  110.     "What did Barack Obama do in his childhood\n"
  111.     "What were the main findings in the original Transformers paper by Vaswani et al.\n\n"
  112.     "Generate the questions below:\n"
  113. )
  114.  
  115. fp = open("data/qa_pairs.jsonl", "w")
  116. for idx, node in enumerate(nodes):
  117.     print (node.text)
  118.  
  119.     dataset_generator = DatasetGenerator(
  120.         [node],
  121.         question_gen_query=question_gen_query,
  122.         service_context=service_context,
  123.         metadata_mode="all",
  124.     )
  125.  
  126.     node_questions_0 = dataset_generator.generate_questions_from_nodes(num=20)
  127.     print(f"[Node {idx}] Generated questions:\n {node_questions_0}")
  128.     # for each question, get a response
  129.     for question in node_questions_0:
  130.         index = SummaryIndex([node], service_context=service_context)
  131.         query_engine = index.as_query_engine()
  132.         response = query_engine.query(question)
  133.         out_dict = {"query": question, "response": str(response)}
  134.         print(f"[Node {idx}] Outputs: {out_dict}")
  135.         fp.write(json.dumps(out_dict) + "\n")
  136.  
  137. fp.close()
  138.  
  139. query_eval_tmpl = PromptTemplate(
  140.     "Your task is to evaluate the following: If the response for the query isn't able to answer the question provided.\n"
  141.     "If query isn't able to answer the question, answer NO.\n"
  142.     "Otherwise answer YES.\n"
  143.     "To elaborate, you might get an answer like the following: 'The context does not contain the answer to this question.'"
  144.     "Please return NO in that case. "
  145.     "You be given the query and response. Return YES or NO as the answer.\n"
  146.     "Query: \n {query_str}\n"
  147.     "Response: \n {response_str}\n"
  148.     "Answer: "
  149. )
  150.  
  151. eval_llm = OpenAI(model="gpt-3.5-turbo")
  152.  
  153. def filter_data(path: str, out_path: str):
  154.     fp = open(path, "r")
  155.     out_fp = open(out_path, "w")
  156.     new_lines = []
  157.     for idx, line in enumerate(fp):
  158.         qa_pair = json.loads(line)
  159.         eval = eval_llm.complete(
  160.             query_eval_tmpl.format(
  161.                 query_str=qa_pair["query"], response_str=qa_pair["response"]
  162.             )
  163.         )
  164.  
  165.         print(f"[{idx}] QA Pair: {qa_pair} \n Eval: {eval}")
  166.         if "NO" in str(eval):
  167.             continue
  168.         else:
  169.             # new_lines.append(line)
  170.             out_fp.write(line)
  171.            
  172. filter_data("data/qa_pairs.jsonl", "data/qa_pairs_2.jsonl")
  173.  
  174. def split_train_val(path: str, out_train_path: str, out_val_path: str, train_split=0.7):
  175.     with open(path, "r") as fp:
  176.         lines = fp.readlines()
  177.  
  178.         # shuffle the lines to make sure that the "train questions" cover most fo the context
  179.         shuffled_lines = deepcopy(lines)
  180.         random.shuffle(shuffled_lines)
  181.  
  182.         split_idx = int(train_split * len(shuffled_lines))
  183.         train_lines = shuffled_lines[:split_idx]
  184.         val_lines = shuffled_lines[split_idx:]
  185.         with open(out_train_path, "w") as out_fp:
  186.             out_fp.write("".join(train_lines))
  187.  
  188.         with open(out_val_path, "w") as out_fp:
  189.             out_fp.write("".join(val_lines))
  190.            
  191. split_train_val(
  192.     "data/qa_pairs_2.jsonl", "data/qa_pairs_train.jsonl", "data/qa_pairs_val.jsonl"
  193. )
  194.  
  195. vp = open("data/qa_pairs_val.jsonl", "r")
  196. fp = open("data/qa_pairs_train.jsonl", "r")
  197. out_fp = open("data/qa_pairs_mistral.jsonl", "w")
  198. out_vp = open("data/qa_pairs_mistral_val.jsonl", "w")
  199.  
  200. for line in fp:
  201.     qa_pair = json.loads(line)
  202.     out_dict = {
  203.         "input": qa_pair["query"],
  204.         "output": qa_pair["response"]
  205.     }
  206.     out_fp.write(json.dumps(out_dict) + "\n")
  207.  
  208. for line in vp:
  209.     vp_pair = json.loads(line)
  210.     out_dict_vp = {
  211.         "input": vp_pair["query"],
  212.         "output": vp_pair["response"]
  213.     }
  214.     out_vp.write(json.dumps(out_dict_vp) + "\n")
Add Comment
Please, Sign In to add comment