Advertisement
kasem1910

Untitled

May 8th, 2021
40
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.66 KB | None | 0 0
  1. import json
  2. from pathlib import Path
  3.  
  4. def read_squad(path):
  5. path = Path(path)
  6. with open(path, 'rb') as f:
  7. squad_dict = json.load(f)
  8.  
  9. contexts = []
  10. questions = []
  11. answers = []
  12. for group in squad_dict['data']:
  13. for passage in group['paragraphs']:
  14. context = passage['context']
  15. for qa in passage['qas']:
  16. question = qa['question']
  17. for answer in qa['answers']:
  18. contexts.append(context)
  19. questions.append(question)
  20. answers.append(answer)
  21.  
  22. return contexts, questions, answers
  23.  
  24. train_contexts, train_questions, train_answers = read_squad('train-v2.0.json')
  25. val_contexts, val_questions, val_answers = read_squad('dev-v2.0.json')
  26.  
  27. def add_end_idx(answers, contexts):
  28. for answer, context in zip(answers, contexts):
  29. gold_text = answer['text']
  30. start_idx = answer['answer_start']
  31. end_idx = start_idx + len(gold_text)
  32.  
  33. # sometimes squad answers are off by a character or two – fix this
  34. if context[start_idx:end_idx] == gold_text:
  35. answer['answer_end'] = end_idx
  36. elif context[start_idx-1:end_idx-1] == gold_text:
  37. answer['answer_start'] = start_idx - 1
  38. answer['answer_end'] = end_idx - 1 # When the gold label is off by one character
  39. elif context[start_idx-2:end_idx-2] == gold_text:
  40. answer['answer_start'] = start_idx - 2
  41. answer['answer_end'] = end_idx - 2 # When the gold label is off by two characters
  42.  
  43. add_end_idx(train_answers, train_contexts)
  44. add_end_idx(val_answers, val_contexts)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement