genericPaster

Kobold_summarizer

Jun 12th, 2024
325
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.92 KB | None | 0 0
  1. import argparse
  2. import os
  3. import spacy
  4. import requests
  5. import json
  6. import random
  7. import time
  8. import threading
  9. import os
  10.  
  11. seq_start = """<|START_OF_TURN_TOKEN|><|USER_TOKEN|>##Instructions\n\nSummarize this chunk of text. Do not use lists.\n\n##Output\nOutput the summary in the following JSON format:
  12. {
  13.  "short_summary": "<include a short summary of the text here>",
  14.  "most_important_points": [
  15.    "<one important point>",
  16.    "<another important point>",
  17.    "<another important point>"
  18.  ]
  19. }\n\n##Text\n"""
  20. seq_end = """<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"""
  21. global generated
  22. nlp = spacy.load('en_core_web_sm')
  23.  
  24. def clearConsole():
  25.     command = 'clear'
  26.     if os.name in ('nt', 'dos'):
  27.         command = 'cls'
  28.     os.system(command)
  29.  
  30. def poll_generation_status(api_url, genkey):
  31.     global generated
  32.     generated = False
  33.     headers = {
  34.         'Content-Type': 'application/json',
  35.         'Authorization': f'Bearer {password}'
  36.     }
  37.     payload = {
  38.         'genkey': genkey
  39.     }
  40.     while generated is False:
  41.         generating = False
  42.         try:
  43.             response = requests.post(f"{api_url}/api/extra/generate/check", json=payload, headers=headers)
  44.             if response.status_code == 200:
  45.                 result = response.json().get('results')[0].get('text')
  46.                 if result:
  47.                     if result == '' and generating is False:
  48.                         generated = False
  49.                         continue
  50.                     if result == '' and generating is True:
  51.                         generated = True
  52.                         continue
  53.                     else:
  54.                         generating = True
  55.                         clearConsole()
  56.                         print(f"\r{result} ", end="", flush=True)
  57.         except Exception as e:
  58.             if generating is False:
  59.                 break
  60.         time.sleep(2)
  61.        
  62. def generate_genkey():
  63.     digits = "".join(str(random.randint(0, 9)) for _ in range(4))
  64.     genkey = f"KCP{digits}"
  65.     return genkey
  66.  
  67. def read_file(filename):
  68.     try:
  69.         with open(filename, 'r', encoding='utf-8') as file:
  70.             return file.read()
  71.     except Exception as e:
  72.         print(f"Error while reading file '{filename}': {e}")
  73.         return None
  74.  
  75. def write_file(filename, data):
  76.     try:
  77.         with open(filename, 'w', encoding='utf-8') as file:
  78.             print(f"Write succes: {filename}")
  79.             file.write(data)
  80.     except Exception as e:
  81.         print(f"Error while writing to file '{filename}': {e}")
  82.  
  83. def chunkify(sentences, chunk_size):
  84.     chunk = ''
  85.     chunks = []
  86.     for sentence in sentences:
  87.         sentence_tokens = nlp(sentence)
  88.         if len(sentence_tokens) > chunk_size:
  89.             print(f"Warning: Sentence is longer than {chunk_size} tokens: {sentence}")
  90.             break
  91.         if len(chunk.split()) + len(sentence_tokens) <= chunk_size:
  92.             chunk += ' ' + sentence
  93.         else:
  94.             chunks.append(chunk.strip())
  95.             chunk = sentence
  96.     chunks.append(chunk.strip())
  97.     return chunks
  98.  
  99. def generation_from_api(text, api_url, chunk_size):
  100.     genkey = generate_genkey()
  101.     #api = {api_url} + "/api/v1/generate/"
  102.     headers = {
  103.         'Content-Type': 'application/json',
  104.         'Authorization': f'Bearer {password}'
  105.     }
  106.     payload = {
  107.         'prompt': seq_start + text + seq_end,
  108.         'max_length': int(chunk_size / 2),
  109.         'max_context_length': chunk_size * 2,
  110.         'rep_pen': 1,
  111.         'temperature': 0,
  112.         'top_p': 1,
  113.         'top_k': 0,
  114.         'top_a': 0,
  115.         'min_p': .05,
  116.         'tfs': 1,
  117.         'typical': 1,
  118.         'n': 1,
  119.         'genkey': genkey,
  120.         'quiet': 'quiet'
  121.     }
  122.     global generated
  123.     generated = False
  124.     poll_thread = threading.Thread(target=poll_generation_status, args=(api_url, genkey))
  125.     poll_thread.start()
  126.     try:
  127.         response = requests.post(api_url, json=payload, headers=headers)
  128.         if response.status_code == 200:
  129.             generated = True
  130.             poll_thread.join()
  131.             return response.json().get('results')[0].get('text')
  132.         elif response.status_code == 503:
  133.             print("Server is busy; please try again later.")
  134.             return text
  135.         else:
  136.             print(f"Kobold API responded with status code {response.status_code}: {response.text}")
  137.             return None
  138.     except Exception as e:
  139.         print(f"Error communicating with Kobold API: {e}")
  140.         return None
  141.  
  142.        
  143. def main():
  144.     clearConsole()
  145.     generations = []
  146.     edited_text = []
  147.     parser = argparse.ArgumentParser(description='Chunk a UTF-8 text file and send to LLM to editing.')
  148.     parser.add_argument('filename', help='text file')
  149.     parser.add_argument('--api-url', default='http://172.16.0.219:5001/api/v1/generate/',
  150.                         help='the URL of the Kobold API')
  151.     parser.add_argument('--chunksize', default=512, help='max tokens per chunk')
  152.     parser.add_argument('--password', default='', help='server password')
  153.  
  154.     args = parser.parse_args()
  155.     global password
  156.     password = args.password
  157.     chunk_size = int(args.chunksize)
  158.     content = read_file(args.filename)
  159.  
  160.     if content is None:
  161.         return
  162.     doc = nlp(content)
  163.     text = ''.join([token.text_with_ws for token in doc if token.lang_ == 'en'])
  164.     if text is None:
  165.         return
  166.     sentences = [sent.text for sent in doc.sents]
  167.     chunks = chunkify(sentences, chunk_size)
  168.  
  169.     for chunk in chunks:
  170.         if (response := generation_from_api(chunk, args.api_url, chunk_size)) is not None:
  171.             generations.append(response)
  172.         else:
  173.             break
  174.     new_file_content = ' '.join(generations)
  175.     new_filename = os.path.join(os.path.dirname(args.filename), os.path.basename(args.filename) + ".summarized.json")
  176.     write_file(new_filename, new_file_content)
  177.  
  178. if __name__ == "__main__":
  179.     main()
Advertisement
Add Comment
Please, Sign In to add comment