Guest User

Alpaca_format_score_PairRM-hf_QA.py

a guest
May 31st, 2024
44
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.87 KB | Source Code | 0 0
  1. # B.2 Smart dedup by score (slow). Modified from https://pastebin.com/qaGG7NSM
  2. import json, ijson
  3. import torch
  4. from transformers import AutoTokenizer
  5. from tqdm import tqdm
  6. from llm_blender.pair_ranker.pairrm import DebertaV2PairRM # pip install git+https://github.com/yuchenlin/LLM-Blender.git
  7.  
  8. reward_name = "llm-blender/PairRM-hf"
  9. rank_model = DebertaV2PairRM.from_pretrained(reward_name, device_map="cuda:0").eval()
  10. tokenizer = AutoTokenizer.from_pretrained(reward_name)
  11. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  12. rank_model.to(device)
  13.  
  14. DEBUG_SCORE = False  # Include the score in the output
  15. SCORE_THRESHOLD = 0.01  # Usefulness minimum score to keep
  16.  
  17. def score_object(obj):
  18.     question = "" + obj['instruction'] + " " + obj['input']
  19.     answer = "" + obj['output'] + ""
  20.     inputs = tokenizer(question + answer, return_tensors='pt').to(device)
  21.  
  22.     with torch.no_grad():
  23.         outputs = rank_model(**inputs)
  24.         score = outputs.logits.item()  # Get the single score
  25.  
  26.     if DEBUG_SCORE:
  27.         score_str = f"{score:.6f}"
  28.         return {**obj, 'score': score_str}, score
  29.     else:
  30.         return obj, score
  31.  
  32. def score_and_sort_data(input_file, output_file):
  33.     with open(input_file, 'r', encoding='utf-8-sig') as f:
  34.         objects = ijson.items(f, 'item')
  35.  
  36.         scored_objects = [score_object(obj) for obj in tqdm(objects, desc="Scoring data")] # Score objects in a list
  37.  
  38.         filtered_objects = []
  39.         for obj, score in scored_objects:
  40.             if score > SCORE_THRESHOLD:
  41.                 filtered_objects.append(obj)
  42.  
  43.         filtered_objects.sort(key=lambda x: x.get('score', 0) if DEBUG_SCORE else score, reverse=True) # Sort by score
  44.  
  45.     with open(output_file, 'w', encoding='utf-8') as f:
  46.         json.dump(filtered_objects, f, ensure_ascii=False, indent=2)
  47.  
  48. score_and_sort_data('en_output.json', 'en_output_scoreQA.json')
Advertisement
Add Comment
Please, Sign In to add comment