Advertisement
Guest User

Untitled

a guest
Jan 29th, 2020
83
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.43 KB | None | 0 0
  1.  
  def train_one_epoch(features, clf, critic, db):
      """Train for one epoch and return support/query statistics.

      For every batch returned by ``db.next("train")`` each task is adapted
      for ``inner_steps`` differentiable Adam steps on its support set, using
      the cross-entropy of the adapted features plus ``mean(critic(z_new))``
      as the inner objective.  The task is then scored on its query set: the
      reward ``ce_old - ce_new`` drives ``critic_opt`` through
      ``-tanh(reward)``, and the query cross-entropy is back-propagated so
      that ``clf_opt`` can step once per batch.

      The function relies on module-level names that are not passed in:
      ``args``, ``batch_size``, ``inner_steps``, ``criterion``,
      ``features_old``, ``features_old_opt``, ``clf_opt`` and ``critic_opt``.

      Args:
          features: Feature extractor; overwritten after every task with the
              state dict of its adapted (functional) copy.
          clf: Classifier applied on top of the extracted features.
          critic: Module whose mean output on the new features is used as the
              auxiliary inner loss.
          db: Data source exposing ``x_train`` and ``next("train")``; the
              latter returns ``(x_s, y_s, x_q, y_q)``, each indexed by task
              along dimension 0.

      Returns:
          dict with the keys ``'s_aux'`` and ``'s_ce'`` (mean auxiliary and
          cross-entropy loss over all inner steps), ``'q_acc'`` (query
          accuracy) and ``'q_ce'`` (mean query cross-entropy).  ``'q_acc'``
          is NaN when no query item was processed.
      """
      n_train_iter = db.x_train.shape[0] // batch_size

      losses_aux = []
      losses_ce = []
      ces = []
      num = 0
      correct = 0

      for _batch in range(n_train_iter):
          x_s, y_s, x_q, y_q = db.next("train")
          n_tasks = x_s.size(0)

          # Functional copies that share the initial weights with the real
          # modules, so gradients of the adapted weights reach the originals.
          diff_features = higher.patch.monkeypatch(
              features, copy_initial_weights=False)
          inner_opt_features = torch.optim.Adam(
              features.parameters(), lr=args.lr_features)
          diffopt_features = higher.optim.get_diff_optim(
              inner_opt_features, features.parameters(), diff_features)

          diff_clf = higher.patch.monkeypatch(clf, copy_initial_weights=False)
          inner_opt_clf = torch.optim.Adam(clf.parameters(), lr=args.lr_clf)
          diffopt_clf = higher.optim.get_diff_optim(
              inner_opt_clf, clf.parameters(), diff_clf)

          # TODO: too long for one epoch
          # TODO: cross entropy is not optimised?
          features_old.load_state_dict(features.state_dict())
          clf_opt.zero_grad()
          for i in range(n_tasks):
              for _step in range(inner_steps):
                  z_new = diff_features(x_s[i])
                  z_old = features_old(x_s[i])

                  y_new = diff_clf(z_new)
                  y_old = diff_clf(z_old)

                  ce_new = criterion(y_new, y_s[i])
                  ce_old = criterion(y_old, y_s[i])

                  aux_loss = torch.mean(critic(z_new))

                  # Baseline extractor: plain cross-entropy step only.
                  features_old_opt.zero_grad()
                  ce_old.backward(retain_graph=True)
                  features_old_opt.step()

                  # Adapted path: cross-entropy plus the critic's signal.
                  inner_loss = ce_new + aux_loss
                  diffopt_clf.step(inner_loss)
                  diffopt_features.step(inner_loss)

                  losses_ce.append(ce_new.item())
                  losses_aux.append(aux_loss.item())

              critic_opt.zero_grad()

              # Remember the classifier gradients so the critic's backward
              # pass below does not leak into them.
              saved_grads = _snapshot_grads(clf.parameters())

              z_new = diff_features(x_q[i])
              z_old = features_old(x_q[i])

              y_new = diff_clf(z_new)
              y_old = diff_clf(z_old)

              ce_new = criterion(y_new, y_q[i])
              ce_old = criterion(y_old, y_q[i]).detach()

              reward = ce_old - ce_new
              critic_loss = -torch.tanh(reward)
              critic_loss.backward(retain_graph=True)

              _restore_grads(saved_grads)

              critic_opt.step()
              # The graph is only needed again if another task follows.
              ce_new.backward(retain_graph=i < n_tasks - 1)

              features.load_state_dict(diff_features.state_dict())
              n_items = x_q.size(1)
              ces.append(F.cross_entropy(y_new, y_q[i]).item())

              pred = y_new.data.max(1, keepdim=True)[1]
              correct += pred.eq(y_q[i].data.view_as(pred)).sum().item()
              num += n_items

          clf_opt.step()

      # Statistics cover every batch of the epoch.
      return {'s_aux': np.mean(losses_aux),
              's_ce': np.mean(losses_ce),
              'q_acc': correct / num if num else float('nan'),
              'q_ce': np.mean(ces)}


  def _snapshot_grads(params):
      """Return ``{param: detached copy of param.grad}`` (``None`` if unset)."""
      return {p: None if p.grad is None else p.grad.detach().clone()
              for p in params}


  def _restore_grads(saved):
      """Write the gradients captured by ``_snapshot_grads`` back in place."""
      for p, grad in saved.items():
          if grad is None:
              # There was no gradient before; drop whatever was accumulated.
              p.grad = None
          else:
              with torch.no_grad():
                  p.grad.copy_(grad)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement