import json

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_path = "/home/xie.zhongwei/workspace/model/Llama-3.1-8B-Instruct"

json_file_paths = [
    # "/home/xie.zhongwei/extracted_data.json",
    "/home/xie.zhongwei/modified_extracted_data.json",
]
output_files = [
    # "original_model_results.json",
    "modified_model_results.json",
]
# Load the model and tokenizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, device_map=device, trust_remote_code=True
).eval()
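# Note (an assumption, not in the original script): an 8B model loaded in
# full fp32 needs roughly 32 GB of memory; passing torch_dtype=torch.bfloat16
# to from_pretrained halves that, e.g.:
#   model = AutoModelForCausalLM.from_pretrained(
#       model_path, device_map=device, torch_dtype=torch.bfloat16,
#       trust_remote_code=True,
#   ).eval()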
# System prompt. (This prompt does not work very well; consider trying a
# different one. I previously left it empty.)
system_prompt = "Please wrap the final answers with specific marks ([Answer] and [/Answer])."
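# Optional sketch, not used below: Llama-3.1-*-Instruct models are trained on
# a chat format, so applying the tokenizer's chat template to a proper
# system/user message pair usually works better than concatenating the system
# prompt into a raw string as done in full_question below. build_chat_prompt
# is a hypothetical helper name.
def build_chat_prompt(tokenizer, system_prompt, user_text):
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_text})
    # add_generation_prompt=True appends the assistant header so the model
    # starts generating the answer directly.
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )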
def process_json_file(json_file_path, output_file):
    # Read and parse the JSON data.
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            data_list = json.load(f)
    except Exception as e:
        print(f"Failed to read JSON {json_file_path}: {e}")
        return
    results = []
    for item in data_list:
        question_structure = item.get('question_structure', {})
        answer = item.get('answer', [])
        context = question_structure.get('context', '')
        sub_question_answers = []
        # Note: sorted() orders keys lexicographically, so 'sub_question10'
        # would sort before 'sub_question2'; this is safe only while each
        # item has fewer than ten sub-questions.
        sub_question_keys = [key for key in sorted(question_structure.keys()) if key.startswith('sub_question')]
        for sub_question_key in sub_question_keys:
            sub_question = question_structure[sub_question_key]
            full_question = f"{system_prompt} {context} {sub_question}".strip()
            try:
                inputs = tokenizer(full_question, return_tensors="pt").to(device)
                with torch.no_grad():
                    outputs = model.generate(**inputs, max_new_tokens=96, do_sample=True, top_p=0.85, temperature=0.35)
                # Decode only the newly generated tokens, not the echoed prompt.
                new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
                generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
                sub_question_answers.append(generated_text)
            except Exception as e:
                print(f"Error while processing sub-question {sub_question}: {e}")
                sub_question_answers.append(None)
        results.append({
            'question_structure': question_structure,
            'answer': answer,
            'model_output': sub_question_answers,
        })
    # Save the results.
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)
    print(f"Results saved to {output_file}, containing {len(results)} valid entries.")
for json_file_path, output_file in zip(json_file_paths, output_files):
    process_json_file(json_file_path, output_file)
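# A minimal sketch, not part of the original script: since the system prompt
# asks the model to wrap final answers in [Answer] and [/Answer] marks, the
# saved model_output strings can be post-processed like this. extract_answer
# is a hypothetical helper name.
import re

def extract_answer(text):
    """Return the text between the first [Answer] ... [/Answer] pair, or None."""
    match = re.search(r"\[Answer\](.*?)\[/Answer\]", text, flags=re.DOTALL)
    return match.group(1).strip() if match else None

# Example: extract_answer("The sum is [Answer] 42 [/Answer].") returns "42".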