witw78

test_infer_physreason.py

Apr 20th, 2025
import json
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_path = "/home/xie.zhongwei/workspace/model/Llama-3.1-8B-Instruct"
json_file_paths = [
    # "/home/xie.zhongwei/extracted_data.json",
    "/home/xie.zhongwei/modified_extracted_data.json"
]
output_files = [
    # "original_model_results.json",
    "modified_model_results.json"
]

# Load the model and tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device, trust_remote_code=True).eval()

# System prompt (this one does not work very well; consider trying others; I previously left it empty)
system_prompt = "Please wrap the final answers with specific marks ([Answer] and [/Answer])."
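
# Note: Llama-3.1-8B-Instruct is trained on a chat format, so the plain string
# concatenation used for full_question below may be part of why the system
# prompt works poorly. A minimal sketch of the chat-template alternative,
# assuming the standard transformers tokenizer.apply_chat_template API
# (build_inputs is a hypothetical helper, untested here):
#
# def build_inputs(context, sub_question):
#     messages = [
#         {"role": "system", "content": system_prompt},
#         {"role": "user", "content": f"{context} {sub_question}".strip()},
#     ]
#     return tokenizer.apply_chat_template(
#         messages, add_generation_prompt=True, return_tensors="pt"
#     ).to(device)
#
# outputs = model.generate(build_inputs(context, sub_question), max_new_tokens=96)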

def process_json_file(json_file_path, output_file):
    # Read and parse the JSON data
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            data_list = json.load(f)
    except Exception as e:
        print(f"Failed to read JSON {json_file_path}: {e}")
        return

    results = []
    for item in data_list:
        question_structure = item.get('question_structure', {})
        answer = item.get('answer', [])
        context = question_structure.get('context', '')

        # Answer each sub-question independently.
        # Note: sorted() is lexicographic, so keys beyond sub_question_9 would
        # sort out of order (e.g. sub_question_10 before sub_question_2).
        sub_question_answers = []
        sub_question_keys = [key for key in sorted(question_structure.keys()) if key.startswith('sub_question')]
        for sub_question_key in sub_question_keys:
            sub_question = question_structure[sub_question_key]
            full_question = f"{system_prompt} {context} {sub_question}".strip()
            try:
                inputs = tokenizer(full_question, return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs = model.generate(**inputs, max_new_tokens=96, do_sample=True, top_p=0.85, temperature=0.35)
                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
                sub_question_answers.append(generated_text)
            except Exception as e:
                print(f"Error while processing sub-question {sub_question}: {e}")
                sub_question_answers.append(None)

        results.append({
            'question_structure': question_structure,
            'answer': answer,
            'model_output': sub_question_answers
        })

    # Save the results
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)

    print(f"Results saved to {output_file}, containing {len(results)} valid records.")

for json_file_path, output_file in zip(json_file_paths, output_files):
    process_json_file(json_file_path, output_file)
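
# Downstream parsing sketch (an assumption, not part of the original script):
# the system prompt asks the model to wrap final answers in
# [Answer]...[/Answer], so a result file written above could be post-processed
# like this. The marker names come from system_prompt; extract_answers and the
# file name are illustrative only.
#
# import re
#
# def extract_answers(model_output):
#     """Return the [Answer]...[/Answer] spans from one generated string."""
#     if model_output is None:
#         return []
#     return [m.strip() for m in re.findall(r"\[Answer\](.*?)\[/Answer\]", model_output, re.DOTALL)]
#
# with open("modified_model_results.json", "r", encoding="utf-8") as f:
#     for record in json.load(f):
#         for text in record["model_output"]:
#             print(extract_answers(text))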