Advertisement
Guest User

Untitled

a guest
Jan 29th, 2020
142
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.50 KB | None | 0 0
  def train_one_epoch(features, clf, critic, db, epoch):
      """Run one epoch of meta-training and return averaged metrics.

      For every batch drawn from ``db`` the function patches ``features`` and
      ``clf`` with ``higher`` so that their updates are differentiable, then
      for each task in the batch:

      1. snapshots the patched feature weights into ``features_old``;
      2. runs ``inner_steps`` support-set updates, where the patched models
         are stepped on ``ce + coef * mean(critic(z))`` and ``features_old``
         is stepped on plain cross-entropy;
      3. on the query set, trains ``critic`` on ``-tanh(ce_old - ce_new)``
         and back-propagates the query cross-entropy of the patched models.

      ``clf_opt`` is stepped once per batch.

      Relies on names defined outside this function: ``batch_size``,
      ``inner_steps``, ``args`` (``lr_clf``, ``lr_features``, ``epochs``),
      ``criterion``, ``features_old``, ``features_old_opt``, ``critic_opt``,
      ``clf_opt``, plus the ``torch``, ``higher``, ``F`` and ``np`` imports.

      Args:
          features: module producing embeddings; patched via ``higher``.
          clf: module mapping embeddings to class scores; patched via
              ``higher``.
          critic: module whose mean output on the new embeddings is used as
              the auxiliary loss.
          db: data source exposing ``x_train`` (only ``shape[0]`` is read)
              and ``next("train")`` returning ``(x_s, y_s, x_q, y_q)``.
              Dimension 0 of each tensor indexes tasks; dimension 1 of
              ``x_q`` is the number of query items per task.
          epoch: current epoch index; the auxiliary-loss weight is
              ``epoch / args.epochs``.

      Returns:
          dict with keys ``'s_aux'`` and ``'s_ce'`` (mean support auxiliary
          and cross-entropy losses), ``'q_acc'`` (query accuracy) and
          ``'q_ce'`` (mean query cross-entropy). Entries are NaN when no
          batch or no query item was processed.
      """
      n_train_iter = db.x_train.shape[0] // batch_size

      losses_aux = []
      losses_ce = []
      ces = []
      num = 0
      correct = 0

      # Weight of the critic's auxiliary loss; constant for the whole epoch.
      coef = epoch / args.epochs

      for batch_idx in range(n_train_iter):
          x_s, y_s, x_q, y_q = db.next("train")
          n_tasks = x_s.size(0)

          # Differentiable (functional) views of the feature extractor and the
          # classifier. copy_initial_weights=False keeps the original
          # parameters as the starting point of the patched modules.
          diff_features = higher.patch.monkeypatch(features, copy_initial_weights=False)
          inner_opt_features = torch.optim.Adam(features.parameters(), lr=args.lr_features)
          diffopt_features = higher.optim.get_diff_optim(inner_opt_features,
                                                         features.parameters(),
                                                         diff_features)

          diff_clf = higher.patch.monkeypatch(clf, copy_initial_weights=False)
          inner_opt_clf = torch.optim.Adam(clf.parameters(), lr=args.lr_clf)
          diffopt_clf = higher.optim.get_diff_optim(inner_opt_clf,
                                                    clf.parameters(), diff_clf)

          clf_opt.zero_grad()
          for i in range(n_tasks):
              # Start the "old" feature extractor from the current patched
              # weights so both branches begin each task from the same point.
              features_old.load_state_dict(diff_features.state_dict())
              for j in range(inner_steps):
                  z_new = diff_features(x_s[i])
                  z_old = features_old(x_s[i])

                  y_new = diff_clf(z_new)
                  y_old = diff_clf(z_old)

                  ce_new = criterion(y_new, y_s[i])
                  ce_old = criterion(y_old, y_s[i])

                  aux_loss = torch.mean(critic(z_new))

                  # Baseline branch: plain cross-entropy step on features_old.
                  features_old_opt.zero_grad()
                  ce_old.backward(retain_graph=True)
                  features_old_opt.step()

                  # Critic-guided branch: differentiable steps on the patched
                  # classifier and feature extractor.
                  diffopt_clf.step(ce_new + coef * aux_loss)
                  diffopt_features.step(ce_new + coef * aux_loss)

                  losses_ce.append(ce_new.item())
                  losses_aux.append(aux_loss.item())

              critic_opt.zero_grad()

              # Save the classifier gradients accumulated so far; the critic's
              # backward pass below also writes into them and must not leak
              # into the classifier update.
              grads = {p: p.grad.clone() for p in clf.parameters()}

              z_new = diff_features(x_q[i])
              z_old = features_old(x_q[i])

              y_new = diff_clf(z_new)
              y_old = diff_clf(z_old)

              ce_new = criterion(y_new, y_q[i])
              ce_old = criterion(y_old, y_q[i]).detach()

              # The critic is rewarded when the critic-guided branch beats the
              # plain baseline on the query set.
              reward = ce_old - ce_new
              critic_loss = -torch.tanh(reward)
              critic_loss.backward(retain_graph=True)

              # Restore the classifier gradients saved before the critic pass.
              for p in clf.parameters():
                  p.grad.data = grads[p]

              # NOTE(review): critic_opt.step() updates the critic in place
              # before ce_new.backward() traverses a graph that appears to
              # include the critic (through aux_loss in the inner steps) --
              # confirm this ordering is intended for the PyTorch version used.
              critic_opt.step()
              # Keep the graph alive until the last task, since the patched
              # modules are shared across the tasks of this batch.
              ce_new.backward(retain_graph=i < n_tasks - 1)
              features.load_state_dict(diff_features.state_dict())
              n_items = x_q.size(1)
              ces.append(F.cross_entropy(y_new, y_q[i]).item())

              pred = y_new.data.max(1, keepdim=True)[1]
              correct += pred.eq(y_q[i].data.view_as(pred)).sum().item()
              num += n_items

          clf_opt.step()

      # Report after ALL batches have been processed; guard the empty case so
      # an epoch without data yields NaN metrics instead of raising.
      nan = float('nan')
      return {'s_aux': np.mean(losses_aux) if losses_aux else nan,
              's_ce': np.mean(losses_ce) if losses_ce else nan,
              'q_acc': correct / num if num else nan,
              'q_ce': np.mean(ces) if ces else nan}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement