Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
MAX_LEN = 100  # fixed sequence length; BERT accepts at most 512 tokens


class text_dataset(Dataset):
    """Torch ``Dataset`` pairing raw texts with labels for a BERT-style model.

    Each item is tokenized with the module-level ``tokenizer`` (assumed to be
    a BERT-style tokenizer exposing ``tokenize`` and ``convert_tokens_to_ids``
    -- TODO confirm against the surrounding script), truncated to ``MAX_LEN``
    tokens, converted to ids, and right-padded with 0 so items batch to a
    uniform shape.
    """

    def __init__(self, X, y):
        # X: sequence of raw text strings; y: parallel sequence of labels.
        self.X = X
        self.y = y

    def __getitem__(self, index):
        """Return ``(ids, label)`` where ids is a fixed-length 1-D tensor."""
        # Slicing truncates to MAX_LEN; a shorter list passes through unchanged,
        # so no explicit length check is needed.
        tokens = tokenizer.tokenize(self.X[index])[:MAX_LEN]
        ids = tokenizer.convert_tokens_to_ids(tokens)
        # Right-pad with 0 (presumably BERT's [PAD] id -- verify against the
        # tokenizer's pad token) to a fixed length of MAX_LEN.
        ids = torch.tensor(ids + [0] * (MAX_LEN - len(ids)))
        # Direct tensor conversion replaces the original's needless
        # one-element list + from_numpy(np.array(...)) round-trip; np.array
        # is kept so the resulting dtype matches the original exactly.
        label = torch.from_numpy(np.array(self.y[index]))
        return ids, label

    def __len__(self):
        return len(self.X)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement