Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import argparse
- import os
- import spacy
- import requests
- import json
- import random
- import time
- import threading
- import os
# Prompt prefix: opens the user turn with chat-template control tokens and
# gives the summarization instructions plus the JSON shape the model should
# emit. The chunk text is appended directly after "##Text\n".
seq_start = """<|START_OF_TURN_TOKEN|><|USER_TOKEN|>##Instructions\n\nSummarize this chunk of text. Do not use lists.\n\n##Output\nOutput the summary in the following JSON format:
{
"short_summary": "<include a short summary of the text here>",
"most_important_points": [
"<one important point>",
"<another important point>",
"<another important point>"
]
}\n\n##Text\n"""

# Prompt suffix: closes the user turn and opens the model's turn.
seq_end = """<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"""

# Flag shared between the request thread and the polling thread; the poller
# loops while it is False. A bare module-level `global generated` statement
# defines nothing, so give the name a real initial value instead.
generated = False

# Bearer token sent with every request. main() overwrites it from --password;
# the empty default (same as the CLI default) keeps the helpers importable.
password = ''

# spaCy English pipeline, used for sentence splitting and token counting.
nlp = spacy.load('en_core_web_sm')
def clearConsole():
    """Clear the terminal by running the platform's clear-screen command.

    Uses ``cls`` when ``os.name`` is ``'nt'`` or ``'dos'`` and ``clear``
    everywhere else.
    """
    on_windows = os.name in ('nt', 'dos')
    os.system('cls' if on_windows else 'clear')
def poll_generation_status(api_url, genkey):
    """Poll the server for partial output and echo it until generation ends.

    Intended to run in a background thread while another thread performs the
    blocking generate request. The loop keeps running while the module-level
    ``generated`` flag is False; the requesting thread sets it to True to
    stop the poller.

    Args:
        api_url: Either the server's base URL or the full generate endpoint
            URL (ending in ``/api/v1/generate``). The generate suffix, if
            present, is stripped before the check path is appended.
        genkey: Generation key that was sent with the generate request.
    """
    global generated
    generated = False

    # The script's default --api-url already ends in /api/v1/generate/, so
    # appending the check path to it verbatim yields a doubled, invalid path.
    # Reduce the URL to its base first.
    base_url = api_url.rstrip('/')
    generate_suffix = '/api/v1/generate'
    if base_url.endswith(generate_suffix):
        base_url = base_url[:-len(generate_suffix)]
    check_url = f"{base_url}/api/extra/generate/check"

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {password}'
    }
    payload = {
        'genkey': genkey
    }

    # True once any partial text has been received. It must persist across
    # iterations so that a transient error after output has started does not
    # end polling.
    generating = False
    while generated is False:
        try:
            # Bounded wait so a stalled check request cannot pin this thread.
            response = requests.post(check_url, json=payload, headers=headers, timeout=10)
            if response.status_code == 200:
                result = response.json().get('results')[0].get('text')
                if result:
                    generating = True
                    clearConsole()
                    print(f"\r{result} ", end="", flush=True)
        except Exception:
            # Nothing has been shown yet, so give up on live preview rather
            # than retrying a request that has never worked.
            if generating is False:
                break
        time.sleep(2)
def generate_genkey():
    """Return a generation key: ``"KCP"`` followed by four random digits."""
    suffix = ""
    for _ in range(4):
        suffix += str(random.randint(0, 9))
    return "KCP" + suffix
def read_file(filename):
    """Read ``filename`` as UTF-8 text.

    Returns:
        The file's full contents, or None if anything goes wrong (the error
        is printed rather than raised).
    """
    contents = None
    try:
        with open(filename, mode='r', encoding='utf-8') as source:
            contents = source.read()
    except Exception as err:
        print(f"Error while reading file '{filename}': {err}")
    return contents
def write_file(filename, data):
    """Write ``data`` to ``filename`` as UTF-8 text, replacing existing content.

    Errors are printed rather than raised. The success message is printed
    only after the file has been written and closed, so a failed write never
    reports success.

    Args:
        filename: Destination path.
        data: Text to write.
    """
    try:
        with open(filename, 'w', encoding='utf-8') as file:
            file.write(data)
    except Exception as e:
        print(f"Error while writing to file '{filename}': {e}")
    else:
        print(f"Write succes: {filename}")
def chunkify(sentences, chunk_size, count_tokens=None):
    """Greedily pack consecutive sentences into chunks of bounded token size.

    Sentences are added to the current chunk, joined by single spaces, until
    the next one would push the chunk's token total above ``chunk_size``; a
    new chunk is then started. A sentence that alone exceeds ``chunk_size``
    is reported with a warning and skipped, and processing continues with
    the sentences after it.

    Args:
        sentences: Iterable of sentence strings, in document order.
        chunk_size: Maximum number of tokens allowed in one chunk.
        count_tokens: Optional callable mapping a sentence string to its
            token count. Defaults to the number of tokens the module-level
            spaCy pipeline ``nlp`` produces for the sentence.

    Returns:
        List of non-empty chunk strings; empty when there is nothing to emit.
    """
    if count_tokens is None:
        def count_tokens(sentence):
            return len(nlp(sentence))

    chunks = []
    current = []        # sentences belonging to the chunk being built
    current_tokens = 0  # token total of `current`, in count_tokens units

    def flush():
        # Emit the chunk under construction, skipping blank results so that
        # an empty prompt is never produced.
        joined = ' '.join(current).strip()
        if joined:
            chunks.append(joined)

    for sentence in sentences:
        n_tokens = count_tokens(sentence)
        if n_tokens > chunk_size:
            print(f"Warning: Sentence is longer than {chunk_size} tokens: {sentence}")
            # Skip only this sentence; stopping here would silently discard
            # the remainder of the document.
            continue
        if current and current_tokens + n_tokens > chunk_size:
            flush()
            current = []
            current_tokens = 0
        current.append(sentence)
        # Keep the running total in the same unit as the per-sentence count
        # so the size limit is applied consistently.
        current_tokens += n_tokens

    flush()
    return chunks
def generation_from_api(text, api_url, chunk_size):
    """Send one chunk to the Kobold API for summarization and return the reply.

    While the blocking generate request is in flight, a background thread
    runs ``poll_generation_status`` to echo partial output. The poller is
    always signalled to stop and joined before this function returns,
    whatever the outcome of the request.

    Args:
        text: Chunk of text to summarize; wrapped in ``seq_start``/``seq_end``.
        api_url: URL the generate request is POSTed to.
        chunk_size: Token budget per chunk; ``max_length`` is half of it and
            ``max_context_length`` is twice it.

    Returns:
        The generated text on HTTP 200; the unmodified input ``text`` on
        HTTP 503 (server busy); None on any other status or on an exception.
    """
    global generated

    genkey = generate_genkey()
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {password}'
    }
    payload = {
        'prompt': seq_start + text + seq_end,
        'max_length': int(chunk_size / 2),
        'max_context_length': chunk_size * 2,
        'rep_pen': 1,
        'temperature': 0,
        'top_p': 1,
        'top_k': 0,
        'top_a': 0,
        'min_p': .05,
        'tfs': 1,
        'typical': 1,
        'n': 1,
        'genkey': genkey,
        'quiet': 'quiet'
    }

    generated = False
    poll_thread = threading.Thread(target=poll_generation_status, args=(api_url, genkey))
    poll_thread.start()
    try:
        response = requests.post(api_url, json=payload, headers=headers)
        if response.status_code == 200:
            return response.json().get('results')[0].get('text')
        if response.status_code == 503:
            print("Server is busy; please try again later.")
            # NOTE(review): this hands back the raw chunk, which the caller
            # stores alongside real summaries - confirm this is intended.
            return text
        print(f"Kobold API responded with status code {response.status_code}: {response.text}")
        return None
    except Exception as e:
        print(f"Error communicating with Kobold API: {e}")
        return None
    finally:
        # Stop the poller on every exit path. Without this, an error or a
        # non-200 reply leaves the non-daemon poll thread looping forever
        # and the process can never exit.
        generated = True
        poll_thread.join()
def main():
    """Command-line entry point: summarize a text file chunk by chunk.

    Reads the input file, splits it into sentences with spaCy, packs the
    sentences into chunks, sends each chunk to the Kobold API, and writes
    the space-joined responses to ``<input name>.summarized.json`` in the
    input file's directory. Processing stops at the first chunk for which
    the API call returns None; whatever was collected so far is still
    written.
    """
    global password

    clearConsole()

    parser = argparse.ArgumentParser(description='Chunk a UTF-8 text file and send to LLM to editing.')
    parser.add_argument('filename', help='text file')
    parser.add_argument('--api-url', default='http://172.16.0.219:5001/api/v1/generate/',
                        help='the URL of the Kobold API')
    # type=int lets argparse reject a non-numeric value with a usage error
    # instead of a traceback from a later int() call.
    parser.add_argument('--chunksize', default=512, type=int, help='max tokens per chunk')
    parser.add_argument('--password', default='', help='server password')
    args = parser.parse_args()

    # The request helpers read the password from this module-level global.
    password = args.password
    chunk_size = args.chunksize

    content = read_file(args.filename)
    if content is None:
        return

    doc = nlp(content)
    sentences = [sent.text for sent in doc.sents]
    chunks = chunkify(sentences, chunk_size)

    generations = []
    for chunk in chunks:
        response = generation_from_api(chunk, args.api_url, chunk_size)
        if response is None:
            break
        generations.append(response)

    # NOTE(review): the responses are joined with spaces, so the output file
    # as a whole is not a single JSON document despite its .json extension.
    new_file_content = ' '.join(generations)
    new_filename = os.path.join(os.path.dirname(args.filename), os.path.basename(args.filename) + ".summarized.json")
    write_file(new_filename, new_file_content)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment