#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import pandas as pd
from sklearn import preprocessing

import io
import os
import torch
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, accuracy_score
from transformers import (TrainingArguments,
                          Trainer,
                          AutoConfig,
                          AutoTokenizer,
                          AdamW,
                          get_linear_schedule_with_warmup,
                          AutoModelForSequenceClassification)

from sklearn.metrics import precision_recall_fscore_support as score


# In[ ]:


use_wandb = False
epochs = 8


# In[3]:

import random
import numpy as np

def set_seed(seed=12):
    """Set seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed()


# In[4]:

if use_wandb:

    import wandb

    wandb.login(key="69968957548c81fa530d32661ab316213ff08545")
    wandb.init(project="multi-class-text-classifier")

    wandb.run.name = "two-gpts"
    wandb.run.save()

    config = wandb.config
    config.description = "Establishing baseline from GPT2 model"
    config.model_used = "GPT2-Large"


# In[5]:

def get_model_tokenizer(model_name, num_labels, device):
    """Load the tokenizer and sequence-classification model, moving the model to `device`."""

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

    model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path=model_name, num_labels=num_labels)
    model = model.to(device)

    return model, tokenizer


# In[6]:

class Gpt2ClassificationCollator(object):
    """Collate a list of {'text', 'label'} examples into padded tensor batches."""

    def __init__(self, use_tokenizer, labels_encoder, max_sequence_len=None):

        self.use_tokenizer = use_tokenizer
        self.max_sequence_len = use_tokenizer.model_max_length if max_sequence_len is None else max_sequence_len
        # Labels arrive already integer-encoded (see the LabelEncoder in main()),
        # so labels_encoder is stored only for reference.
        self.labels_encoder = labels_encoder

    def __call__(self, sequences):

        texts = [sequence['text'] for sequence in sequences]
        labels = [sequence['label'] for sequence in sequences]
        inputs = self.use_tokenizer(text=texts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_sequence_len)
        inputs.update({'labels': torch.tensor(labels)})

        return inputs

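
# A minimal, hypothetical sanity check of the collator interface (the tokenizer
# name and the two-class label map below are illustrative, not part of the training flow):
#
#     tok = AutoTokenizer.from_pretrained("roberta-large")
#     collator = Gpt2ClassificationCollator(use_tokenizer=tok,
#                                           labels_encoder={"neg": 0, "pos": 1},
#                                           max_sequence_len=32)
#     batch = collator([{'text': "first example", 'label': 0},
#                       {'text': "second example", 'label': 1}])
#     # batch contains 'input_ids', 'attention_mask', and 'labels' tensors,
#     # padded to the longest sequence in the batch.
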
# In[7]:

class ClassificationDataset(Dataset):
    """Wrap a dataframe with `text` and `label` columns as a torch Dataset."""

    def __init__(self, df):

        self.texts = df.text.values.tolist()
        self.labels = df.label.values.tolist()

    def __len__(self):

        return len(self.texts)

    def __getitem__(self, item):

        return {'text': self.texts[item],
                'label': self.labels[item]}


# In[8]:


def train_epoch(dataloader, optimizer_, scheduler_, device_, model):

    predictions_labels = []
    true_labels = []
    total_loss = 0

    model.train()

    for batch in tqdm(dataloader, total=len(dataloader)):

        true_labels += batch['labels'].numpy().flatten().tolist()
        batch = {k: v.type(torch.long).to(device_) for k, v in batch.items()}

        model.zero_grad()

        outputs = model(**batch)

        loss, logits = outputs[:2]

        total_loss += loss.item()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update parameters and advance the learning-rate schedule.
        optimizer_.step()
        scheduler_.step()

        logits = logits.detach().cpu().numpy()

        predictions_labels += logits.argmax(axis=-1).flatten().tolist()

    avg_epoch_loss = total_loss / len(dataloader)

    return true_labels, predictions_labels, avg_epoch_loss


def validation(dataloader, device_, model):

    predictions_labels = []
    true_labels = []
    total_loss = 0

    model.eval()

    for batch in tqdm(dataloader, total=len(dataloader)):

        true_labels += batch['labels'].numpy().flatten().tolist()
        batch = {k: v.type(torch.long).to(device_) for k, v in batch.items()}

        with torch.no_grad():

            outputs = model(**batch)

            loss, logits = outputs[:2]
            logits = logits.detach().cpu().numpy()
            total_loss += loss.item()

            predict_content = logits.argmax(axis=-1).flatten().tolist()

            predictions_labels += predict_content

    avg_epoch_loss = total_loss / len(dataloader)

    return true_labels, predictions_labels, avg_epoch_loss

# In[9]:

import argparse


def train_model(model, epochs, train_dataloader, valid_dataloader, optimizer, scheduler, device, labels_ids, prefix=""):

    best_val_acc = 0
    best_preds = None
    best_labels = None

    if use_wandb:
        wandb.watch(model)

    print('Epoch')
    for epoch in tqdm(range(epochs)):

        print()
        print('Training on batches...')

        train_labels, train_predict, train_loss = train_epoch(train_dataloader, optimizer, scheduler, device, model)
        train_acc = accuracy_score(train_labels, train_predict)

        print('Validation on batches...')
        valid_labels, valid_predict, val_loss = validation(valid_dataloader, device, model)
        val_acc = accuracy_score(valid_labels, valid_predict)

        if use_wandb:
            # Log epoch-level metrics.
            wandb.log({"train_loss": train_loss})
            wandb.log({"train_acc": train_acc})

            wandb.log({"test_loss": val_loss})
            wandb.log({"test_acc": val_acc})

            precision, recall, fscore, support = score(valid_labels, valid_predict)

            # Per-class precision, recall, and F-score keyed by label name.
            for k, n in enumerate(labels_ids):
                wandb.log({f"precision-{n}": precision[k]})
                wandb.log({f"recall-{n}": recall[k]})
                wandb.log({f"fscore-{n}": fscore[k]})

        # Keep the checkpoint with the best validation accuracy.
        if val_acc > best_val_acc:
            torch.save(model, prefix + "best_model.pt")
            best_val_acc = val_acc

            best_preds = valid_predict
            best_labels = valid_labels

        # Print loss and accuracy values to see how training evolves.
        print("  train_loss: %.5f - val_loss: %.5f - train_acc: %.5f - valid_acc: %.5f" % (train_loss, val_loss, train_acc, val_acc))
        print()

    return best_preds, best_labels, best_val_acc


def main():
    # In[10]:

    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float)
    parser.add_argument("--batch_size", type=int)

    args = parser.parse_args()

    lr, batch_size = args.lr, args.batch_size
    print("============")
    print(str(lr) + " " + str(batch_size))

    train = pd.read_csv("../input/upwork02/train.csv").sample(frac=1)
    test = pd.read_csv("../input/upwork02/test.csv").sample(frac=1)

    # Drop classes that are not part of this experiment.
    train = train[(train.label != 'Expertise - "Teach Me"') & (train.label != "Other") & (train.label != "Business Opportunity ")]

    le = preprocessing.LabelEncoder()
    le.fit(train.label)

    # Note: this assumes the test set only contains labels that survive the
    # filtering above; an unseen label would make le.transform raise.
    train.label = le.transform(train.label)
    test.label = le.transform(test.label)

    train["label_names"] = le.inverse_transform(train.label)
    test["label_names"] = le.inverse_transform(test.label)

    labels_ids = dict(zip(le.classes_, le.transform(le.classes_)))

    #######################

    train_df = train
    test_df = test

    model_name = "roberta-large"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, tokenizer = get_model_tokenizer(model_name, len(labels_ids), device)

    gpt2_classification_collator = Gpt2ClassificationCollator(use_tokenizer=tokenizer,
                                                              labels_encoder=labels_ids,
                                                              max_sequence_len=96)

    train_dataset = ClassificationDataset(train_df)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=gpt2_classification_collator)

    valid_dataset = ClassificationDataset(test_df)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=gpt2_classification_collator)

    optimizer = AdamW(model.parameters(),
                      lr=lr,    # default is 5e-5, our notebook had 2e-5
                      eps=1e-8  # default is 1e-8.
                      )

    total_steps = len(train_dataloader) * epochs

    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,  # Default value in run_glue.py
                                                num_training_steps=total_steps)

    best_preds, best_labels, best_val_acc = train_model(model, epochs, train_dataloader, valid_dataloader, optimizer, scheduler, device, labels_ids, prefix="binary_")

    print(best_val_acc)


if __name__ == "__main__":
    main()
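
# Example invocation (hypothetical; assumes this script is saved as train.py and
# the CSVs referenced above exist — adjust the values to your data and hardware):
#
#     python train.py --lr 2e-5 --batch_size 16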