index ede2865..de5eb9f 100755
--- a/examples/maml-omniglot.py
+++ b/examples/maml-omniglot.py
@@ -30,6 +30,7 @@ import higher
 
 from omniglot_loaders import OmniglotNShot
 
+
 def main():
     argparser = argparse.ArgumentParser()
     argparser.add_argument('--n_way', type=int, help='n way', default=5)
@@ -58,23 +59,22 @@ def main():
     # Before higher, models could *not* be created like this
     # and the parameters needed to be manually updated and copied
     # for the updates.
+
     net = nn.Sequential(
-        nn.Conv2d(1, 64, 3, 2),
-        nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
-        nn.Conv2d(64, 64, 3, 2),
+        nn.Conv2d(1, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
         nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
-        nn.Conv2d(64, 64, 3, 2),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
         nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
-        nn.Conv2d(64, 64, 2, 1),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
         nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
+        nn.MaxPool2d(2, 2),
         Flatten(),
-        nn.Linear(64, args.n_way)
-    ).to(device)
-
+        nn.Linear(64, args.n_way)).to(device)
     # We will use Adam to (meta-)optimize the initial parameters
     # to be adapted.
     meta_opt = optim.Adam(net.parameters(), lr=1e-3)
@@ -90,6 +90,7 @@ def train(db, net, device, meta_opt, epoch, log):
     net.train()
     n_train_iter = db.x_train.shape[0] // db.batchsz
+
     for batch_idx in range(n_train_iter):
         # Sample a batch of support and query images and labels.
         x_spt, y_spt, x_qry, y_qry = db.next()
@@ -107,7 +108,7 @@ def train(db, net, device, meta_opt, epoch, log):
         # Initialize the inner optimizer to adapt the parameters to
         # the support set.
         n_inner_iter = 1
-        inner_opt = torch.optim.SGD(net.parameters(), lr=4e-1)
+        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
 
         qry_losses = []
         qry_accs = []
@@ -167,6 +168,7 @@ def test(db, net, device, epoch, log):
     qry_losses = []
     qry_accs = []
+
     for batch_idx in range(n_test_iter):
         x_spt, y_spt, x_qry, y_qry = db.next('test')
@@ -180,8 +182,8 @@ def test(db, net, device, epoch, log):
         # TODO: Maybe pull this out into a separate module so it
         # doesn't have to be duplicated between `train` and `test`?
-        n_inner_iter = 3
-        inner_opt = torch.optim.SGD(net.parameters(), lr=4e-1)
+        n_inner_iter = 5
+        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
 
         for i in range(task_num):
             with higher.innerloop_ctx(net, inner_opt) as (fnet, diffopt):
@@ -195,7 +197,7 @@ def test(db, net, device, epoch, log):
                 # The query loss and acc induced by these parameters.
                 qry_logits = fnet(x_qry[i]).detach()
-                qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
+                qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none').detach()
                 qry_losses.append(qry_loss)
                 qry_accs.append(
                     qry_logits.argmax(dim=1) == y_qry[i]
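
The new backbone swaps the strided convolutions for the standard MAML Omniglot stack: 3x3 conv, then BatchNorm2d with momentum=1 and affine=True, ReLU, and 2x2 max pooling. With momentum=1, PyTorch's BatchNorm running statistics are overwritten by the current batch statistics at every step, so the functional copies that higher creates never depend on stale running averages. The patch references a Flatten() module it does not define; below is a minimal sketch of such a module plus a shape check for the patched network, assuming 28x28 single-channel inputs (the resolution this example's OmniglotNShot loader is typically configured with). Both the Flatten definition and the 28x28 assumption are illustrative, not part of the patch.

import torch
import torch.nn as nn

# Minimal sketch of the Flatten module the diff references but does not
# define (assumption: it is the usual reshape-to-vector layer; recent
# PyTorch also ships nn.Flatten, which behaves the same here).
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

# Shape check for the patched backbone, assuming 28x28 Omniglot inputs:
# 28 -conv3-> 26 -pool2-> 13 -conv3-> 11 -pool2-> 5 -conv3-> 3 -pool2-> 1,
# so the head sees 64 * 1 * 1 = 64 features, matching nn.Linear(64, n_way).
if __name__ == '__main__':
    n_way = 5
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        Flatten(),
        nn.Linear(64, n_way))
    out = net(torch.randn(4, 1, 28, 28))
    assert out.shape == (4, n_way)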
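
The inner-loop learning rate drops from 4e-1 to 1e-1, and the test-time adaptation budget grows from 3 to 5 steps. These hyperparameters feed the pattern at the heart of the example: higher.innerloop_ctx wraps the model and a plain SGD optimizer into a functional model plus a differentiable optimizer, so the query loss can backpropagate through the support-set updates into the initial parameters. A minimal sketch of that train step, simplified from the example's train() (the function name maml_train_step and its signature are mine, not the example's):

import higher
import torch
import torch.nn.functional as F

def maml_train_step(net, meta_opt, x_spt, y_spt, x_qry, y_qry,
                    n_inner_iter=1, inner_lr=1e-1):
    inner_opt = torch.optim.SGD(net.parameters(), lr=inner_lr)
    meta_opt.zero_grad()
    task_num = x_spt.size(0)
    for i in range(task_num):
        with higher.innerloop_ctx(
                net, inner_opt, copy_initial_weights=False) as (fnet, diffopt):
            # Adapt to the support set; diffopt.step keeps the graph
            # differentiable across the update.
            for _ in range(n_inner_iter):
                spt_loss = F.cross_entropy(fnet(x_spt[i]), y_spt[i])
                diffopt.step(spt_loss)
            # The query loss backprops through the inner updates into
            # the initial parameters; gradients accumulate across tasks.
            qry_loss = F.cross_entropy(fnet(x_qry[i]), y_qry[i])
            qry_loss.backward()
    meta_opt.step()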
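
The test-side change detaches the query loss in addition to the already detached logits. Since F.cross_entropy of a detached tensor already carries no graph, the extra .detach() is defensive, but the intent is clear: nothing appended to qry_losses or qry_accs should keep the inner-loop autograd graph alive across iterations, which would steadily grow memory during evaluation. Note that torch.no_grad() is not a substitute here, because the inner loop still needs gradients to adapt on the support set. A sketch of that evaluation loop follows; passing track_higher_grads=False to innerloop_ctx is an option higher provides for exactly this situation, not something this patch adds:

import higher
import torch
import torch.nn.functional as F

def maml_test_step(net, x_spt, y_spt, x_qry, y_qry,
                   n_inner_iter=5, inner_lr=1e-1):
    inner_opt = torch.optim.SGD(net.parameters(), lr=inner_lr)
    qry_losses, qry_accs = [], []
    for i in range(x_spt.size(0)):
        with higher.innerloop_ctx(
                net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
            # Support-set adaptation still uses autograd, but no
            # higher-order graph is built at test time.
            for _ in range(n_inner_iter):
                spt_loss = F.cross_entropy(fnet(x_spt[i]), y_spt[i])
                diffopt.step(spt_loss)
            # Detach so the stored tensors hold no graph.
            qry_logits = fnet(x_qry[i]).detach()
            qry_loss = F.cross_entropy(
                qry_logits, y_qry[i], reduction='none').detach()
            qry_losses.append(qry_loss)
            qry_accs.append(qry_logits.argmax(dim=1) == y_qry[i])
    return (torch.cat(qry_losses).mean().item(),
            torch.cat(qry_accs).float().mean().item())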