import numpy as np
import torch
import torch.nn.functional as F
import higher


# Assumed to be defined elsewhere in the original script (not in this paste):
#   args (with .lr_clf, .lr_features, .epochs), batch_size, inner_steps,
#   criterion, features_old, features_old_opt, clf_opt, critic_opt.

def train_one_epoch(features, clf, critic, db, epoch):
    n_train_iter = db.x_train.shape[0] // batch_size
    losses_aux = []
    losses_ce = []
    ces = []
    num = 0
    correct = 0
    for batch_idx in range(n_train_iter):
        # Sample a meta-batch of tasks: support set (x_s, y_s) and query set
        # (x_q, y_q), each shaped [n_tasks, n_items, ...].
        x_s, y_s, x_q, y_q = db.next("train")
        n_tasks = x_s.size(0)

        # Differentiable copies of the feature extractor and classifier, plus
        # differentiable inner-loop optimizers, so meta-gradients can flow
        # through the adaptation steps.
        diff_features = higher.patch.monkeypatch(features, copy_initial_weights=False)
        inner_opt_features = torch.optim.Adam(features.parameters(), lr=args.lr_features)
        diffopt_features = higher.optim.get_diff_optim(
            inner_opt_features, features.parameters(), diff_features)

        diff_clf = higher.patch.monkeypatch(clf, copy_initial_weights=False)
        inner_opt_clf = torch.optim.Adam(clf.parameters(), lr=args.lr_clf)
        diffopt_clf = higher.optim.get_diff_optim(
            inner_opt_clf, clf.parameters(), diff_clf)

        clf_opt.zero_grad()
        for i in range(n_tasks):
            # Snapshot the current adapted features as the "old" baseline.
            features_old.load_state_dict(diff_features.state_dict())

            # Inner loop: adapt on the support set, regularized by the critic.
            for j in range(inner_steps):
                z_new = diff_features(x_s[i])
                z_old = features_old(x_s[i])
                y_new = diff_clf(z_new)
                y_old = diff_clf(z_old)
                ce_new = criterion(y_new, y_s[i])
                ce_old = criterion(y_old, y_s[i])
                aux_loss = torch.mean(critic(z_new))

                # Plain (non-differentiable) update of the baseline network.
                features_old_opt.zero_grad()
                ce_old.backward(retain_graph=True)
                features_old_opt.step()

                # Anneal the critic's influence over the course of training.
                coef = epoch / args.epochs
                diffopt_clf.step(ce_new + coef * aux_loss)
                diffopt_features.step(ce_new + coef * aux_loss)

                losses_ce.append(ce_new.item())
                losses_aux.append(aux_loss.item())

            # Critic update: reward the critic when the adapted ("new")
            # features beat the baseline ("old") features on the query set.
            critic_opt.zero_grad()
            # Stash the classifier's accumulated meta-gradients so the
            # critic's backward pass does not pollute them (grads may still
            # be None on the first task, hence the guard).
            grads = {p: (p.grad.clone() if p.grad is not None else None)
                     for p in clf.parameters()}
            z_new = diff_features(x_q[i])
            z_old = features_old(x_q[i])
            y_new = diff_clf(z_new)
            y_old = diff_clf(z_old)
            ce_new = criterion(y_new, y_q[i])
            ce_old = criterion(y_old, y_q[i]).detach()
            reward = ce_old - ce_new
            critic_loss = -torch.tanh(reward)
            critic_loss.backward(retain_graph=True)
            # Restore the classifier's meta-gradients.
            for p in clf.parameters():
                p.grad = grads[p]
            critic_opt.step()

            # Meta-gradient: backprop the query loss through the inner loop.
            # The graph is only needed again while more tasks remain.
            ce_new.backward(retain_graph=i < n_tasks - 1)

            # First-order update: copy the adapted features back into the
            # base feature extractor.
            features.load_state_dict(diff_features.state_dict())

            # Bookkeeping: query-set loss and accuracy.
            n_items = x_q.size(1)
            ces.append(F.cross_entropy(y_new, y_q[i]).item())
            pred = y_new.argmax(dim=1, keepdim=True)
            correct += pred.eq(y_q[i].view_as(pred)).sum().item()
            num += n_items
        clf_opt.step()
    return {'s_aux': np.mean(losses_aux),
            's_ce': np.mean(losses_ce),
            'q_acc': correct / num,
            'q_ce': np.mean(ces)}
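
# ----------------------------------------------------------------------
# Minimal usage sketch (not part of the original paste). Everything below
# is an assumption: the architectures, hyperparameters, and the DummyDB
# episodic sampler are placeholders standing in for whatever the original
# script defines. It only illustrates the module-level globals that
# train_one_epoch expects and how an epoch loop might drive it.

import argparse
import torch.nn as nn

args = argparse.Namespace(lr_clf=1e-3, lr_features=1e-3, epochs=100)
batch_size = 32
inner_steps = 5
criterion = nn.CrossEntropyLoss()

features = nn.Sequential(nn.Linear(784, 64), nn.ReLU())      # feature extractor
features_old = nn.Sequential(nn.Linear(784, 64), nn.ReLU())  # baseline copy
clf = nn.Linear(64, 5)     # task head, e.g. 5-way classification
critic = nn.Linear(64, 1)  # scores embeddings for the auxiliary loss

features_old_opt = torch.optim.Adam(features_old.parameters(), lr=args.lr_features)
clf_opt = torch.optim.Adam(clf.parameters(), lr=args.lr_clf)
critic_opt = torch.optim.Adam(critic.parameters(), lr=1e-3)

class DummyDB:
    # Toy episodic sampler: each call yields 4 tasks with 25 support and
    # 25 query items of 5-way classification on random data.
    x_train = torch.randn(64, 784)

    def next(self, split):
        n_tasks, n_s, n_q = 4, 25, 25
        x_s = torch.randn(n_tasks, n_s, 784)
        y_s = torch.randint(0, 5, (n_tasks, n_s))
        x_q = torch.randn(n_tasks, n_q, 784)
        y_q = torch.randint(0, 5, (n_tasks, n_q))
        return x_s, y_s, x_q, y_q

db = DummyDB()

for epoch in range(args.epochs):
    stats = train_one_epoch(features, clf, critic, db, epoch)
    print(epoch, stats)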