Advertisement
Guest User

Untitled

a guest
Sep 18th, 2019
132
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.54 KB | None | 0 0
import os

import pandas as pd
  2.  
  3. TEXT_COL, LABEL_COL = 'text', 'truth'
  4.  
  5. def read_sst5(data_dir, colnames=[LABEL_COL, TEXT_COL]):
  6. datasets = {}
  7. for t in ["train", "dev", "test"]:
  8. df = pd.read_csv(os.path.join(data_dir, f"sst_{t}.txt"), sep='\t', header=None, names=colnames)
  9. df[LABEL_COL] = df[LABEL_COL].str.replace('__label__', '')
  10. df[LABEL_COL] = df[LABEL_COL].astype(int) # Categorical data type for truth labels
  11. df[LABEL_COL] = df[LABEL_COL] - 1 # Zero-index labels for PyTorch
  12. datasets[t] = df
  13. return datasets
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement