Guest User

Untitled

a guest
Jan 26th, 2024
98
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.05 KB | None | 0 0
import json
import os
import shutil
import sys
import time

from tqdm import tqdm

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2Sampler,
)
  20.  
  21. # Initialize model and cache
  22.  
  23. model_directory = "yi-200kquanted/4.8bpw/"
  24.  
  25. config = ExLlamaV2Config()
  26. config.model_dir = model_directory
  27. config.prepare()
  28.  
  29. model = ExLlamaV2(config)
  30. print("Loading model: " + model_directory)
  31.  
  32. # allocate 18 GB to CUDA:0 and 24 GB to CUDA:1.
  33. # (Call `model.load()` if using a single GPU.)
  34.  
  35. tokenizer = ExLlamaV2Tokenizer(config)
  36.  
  37.  
  38. model.load([18.5, 22])
  39. cache = ExLlamaV2Cache(model, batch_size=1)
  40. # Initialize generator
  41.  
  42. # Generate some text
  43.  
  44. settings = ExLlamaV2Sampler.Settings()
  45. settings.temperature = 0.9
  46. settings.top_k = 0
  47. settings.top_p = 0.7
  48. settings.token_repetition_penalty = 1.1
  49. settings.disallow_tokens(tokenizer, [tokenizer.eos_token_id])
  50.  
  51. max_new_tokens = 20000
  52. generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
  53.  
  54. generator.warmup()
  55.  
  56. import json
  57. def parse_jsonl(file_path):
  58. data_list = []
  59. with open(file_path, 'r') as file:
  60. for line in file:
  61. try:
  62. # Parse each line as a JSON object
  63. json_object = json.loads(line.strip())
  64. data_list.append(json_object)
  65. except json.JSONDecodeError as e:
  66. print(f"Error decoding JSON: {e}. Skipping line.")
  67.  
  68. return data_list
  69.  
  70.  
  71.  
  72.  
  73.  
  74. tmpCollector = {}
  75. # Open JSON file from /dev/shm if that fails, open from ./sample4Sum.json
  76. try:
  77. file_path = '/dev/shm/ppr.jsonl'
  78. inputCollector = parse_jsonl(file_path)
  79.  
  80. except:
  81. file_path = '/dev/shm/pprBkup.jsonl'
  82. inputCollector = parse_jsonl(file_path)
  83. os.system('cp /dev/shm/pprBkup.json /dev/shm/ppr.json')
  84.  
  85.  
  86.  
  87.  
  88. for index, row in tqdm(enumerate(inputCollector), total=len(inputCollector), desc="Processing"):
  89. output = ''
  90. if row['rejected'] != '':
  91. continue
  92.  
  93. prompt = f"{row['system']}\nUser:{row['user']}\nAssistant:"
  94. tmpCollector[index] = prompt
  95. if len(tmpCollector) < 1:
  96. continue
  97. output = generator.generate_simple(list(tmpCollector.values()), settings, max_new_tokens)
  98. for singleComp in output:
  99. for singlePrompt in tmpCollector:
  100. initialPrompt = tmpCollector[singlePrompt]
  101. if not initialPrompt in singleComp:
  102. continue
  103. newGen = singleComp.split('Assistant:')[1]
  104.  
  105. print('>>>>>>>>>>>>'+newGen)
  106. print()
  107. inputCollector[singlePrompt]['rejected']=newGen
  108. break
  109. tmpCollector = {}
  110. with open('/dev/shm/pprBkup.jsonl', 'w') as file:
  111. file.seek(0)
  112. json.dump(inputCollector, file)
  113. file.truncate()
  114. # now the save is done, make a copy of it to /dev/shm/sample4SumBackUP.json
  115. os.system('cp /dev/shm/pprBkup.jsonl /dev/shm/ppr.jsonl')
Advertisement
Add Comment
Please, Sign In to add comment