Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import pandas as pd
TEXT_COL, LABEL_COL = 'text', 'truth'


def read_sst5(data_dir, colnames=None):
    """Load the SST-5 train/dev/test splits into pandas DataFrames.

    Each split is read from ``<data_dir>/sst_<split>.txt``. Every line is a
    label and a sentence separated by a tab. The label is either
    ``__label__K`` or a bare integer ``K``, and is converted to the
    zero-indexed integer ``K - 1``.

    Args:
        data_dir: Directory containing ``sst_train.txt``, ``sst_dev.txt`` and
            ``sst_test.txt``.
        colnames: Names for the two columns, in file order. Defaults to
            ``[LABEL_COL, TEXT_COL]``. The label post-processing always
            targets ``LABEL_COL``, so any custom names must include it.

    Returns:
        dict mapping ``"train"``, ``"dev"`` and ``"test"`` to DataFrames.

    Raises:
        FileNotFoundError: if a split file is missing.
        ValueError: if a label cannot be parsed as an integer.
    """
    import csv
    import os

    # Avoid a shared mutable default; build a fresh list per call.
    if colnames is None:
        colnames = [LABEL_COL, TEXT_COL]

    datasets = {}
    for t in ["train", "dev", "test"]:
        df = pd.read_csv(
            os.path.join(data_dir, f"sst_{t}.txt"),
            sep='\t',
            header=None,
            names=list(colnames),
            # Sentences are raw text: a leading '"' must not be treated as
            # the start of a quoted field (which would strip/merge text).
            quoting=csv.QUOTE_NONE,
        )
        # Cast to str first so both "__label__K" and bare-integer labels work
        # (a purely numeric column would otherwise have no .str accessor).
        labels = df[LABEL_COL].astype(str).str.replace('__label__', '', regex=False)
        # Plain integer labels, shifted to be zero-indexed (e.g. for PyTorch).
        df[LABEL_COL] = labels.astype(int) - 1
        datasets[t] = df
    return datasets
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement