diff --git a/examples/maml-omniglot.py b/examples/maml-omniglot.py
index ede2865..de5eb9f 100755
--- a/examples/maml-omniglot.py
+++ b/examples/maml-omniglot.py
@@ -30,6 +30,7 @@ import higher
 
 from omniglot_loaders import OmniglotNShot
 
+
 def main():
     argparser = argparse.ArgumentParser()
     argparser.add_argument('--n_way', type=int, help='n way', default=5)
@@ -58,23 +59,22 @@ def main():
     # Before higher, models could *not* be created like this
     # and the parameters needed to be manually updated and copied
     # for the updates.
+
     net = nn.Sequential(
-        nn.Conv2d(1, 64, 3, 2),
-        nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
-        nn.Conv2d(64, 64, 3, 2),
+        nn.Conv2d(1, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
         nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
-        nn.Conv2d(64, 64, 3, 2),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
         nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
-        nn.Conv2d(64, 64, 2, 1),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1, affine=True),
         nn.ReLU(inplace=True),
-        nn.BatchNorm2d(64),
+        nn.MaxPool2d(2,2),
         Flatten(),
-        nn.Linear(64, args.n_way)
-    ).to(device)
-
+        nn.Linear(64, args.n_way)).to(device)
     # We will use Adam to (meta-)optimize the initial parameters
     # to be adapted.
     meta_opt = optim.Adam(net.parameters(), lr=1e-3)
@@ -90,6 +90,7 @@ def train(db, net, device, meta_opt, epoch, log):
     net.train()
     n_train_iter = db.x_train.shape[0] // db.batchsz
 
+
     for batch_idx in range(n_train_iter):
         # Sample a batch of support and query images and labels.
         x_spt, y_spt, x_qry, y_qry = db.next()
@@ -107,7 +108,7 @@ def train(db, net, device, meta_opt, epoch, log):
         # Initialize the inner optimizer to adapt the parameters to
         # the support set.
         n_inner_iter = 1
-        inner_opt = torch.optim.SGD(net.parameters(), lr=4e-1)
+        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
 
         qry_losses = []
         qry_accs = []
@@ -167,6 +168,7 @@ def test(db, net, device, epoch, log):
 
     qry_losses = []
     qry_accs = []
+
 
     for batch_idx in range(n_test_iter):
         x_spt, y_spt, x_qry, y_qry = db.next('test')
@@ -180,8 +182,8 @@ def test(db, net, device, epoch, log):
 
         # TODO: Maybe pull this out into a separate module so it
         # doesn't have to be duplicated between `train` and `test`?
-        n_inner_iter = 3
-        inner_opt = torch.optim.SGD(net.parameters(), lr=4e-1)
+        n_inner_iter = 5
+        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)
 
         for i in range(task_num):
             with higher.innerloop_ctx(net, inner_opt) as (fnet, diffopt):
@@ -195,7 +197,7 @@ def test(db, net, device, epoch, log):
 
                 # The query loss and acc induced by these parameters.
                 qry_logits = fnet(x_qry[i]).detach()
-                qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
+                qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none').detach()
                 qry_losses.append(qry_loss)
                 qry_accs.append(
                     qry_logits.argmax(dim=1) == y_qry[i]
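
Notes on the patch (the sketches below are illustrative reconstructions, not code from the diff).

The architecture hunk swaps strided 3x3 convolutions (ReLU before BatchNorm) for unstrided 3x3 convolutions followed by 2x2 max-pooling, moves BatchNorm ahead of the ReLU, and sets momentum=1, so in training mode the running statistics are overwritten by the current batch's statistics on every forward pass. A minimal standalone sketch of the post-patch backbone, assuming 1x28x28 Omniglot inputs; the conv_block helper and nn.Flatten() (standing in for the example's Flatten) are introduced here, not names from the patched file:

import torch
from torch import nn

n_way = 5  # matches the example's --n_way default

def conv_block(in_ch, out_ch):
    # One '+' block from the diff: 3x3 valid conv -> BatchNorm with
    # momentum=1 (current-batch statistics only) -> ReLU -> 2x2 max-pool.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3),
        nn.BatchNorm2d(out_ch, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
    )

net = nn.Sequential(
    conv_block(1, 64),    # 28x28 -> 26x26 -> 13x13
    conv_block(64, 64),   # 13x13 -> 11x11 -> 5x5
    conv_block(64, 64),   # 5x5   -> 3x3   -> 1x1
    nn.Flatten(),         # 64x1x1 -> 64, hence Linear(64, ...)
    nn.Linear(64, n_way),
)

print(net(torch.randn(2, 1, 28, 28)).shape)  # torch.Size([2, 5])

The shape comments show why nn.Linear(64, args.n_way) still fits: both the old strided stack and the new pooled stack reduce a 28x28 input to a single 64-channel spatial cell.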
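The optimizer hunks drop the inner-loop SGD learning rate from 4e-1 to 1e-1 in both train() and test(), and raise the test-time adaptation steps from 3 to 5. The surrounding pattern, visible in the last two hunks, is higher's innerloop_ctx, which yields a functional copy of the network (fnet) and a differentiable optimizer (diffopt). A minimal sketch of one meta-update with that API, reusing net from the sketch above; the dummy support/query tensors are placeholders for what db.next() returns in the example:

import higher
import torch
import torch.nn.functional as F
from torch import optim

meta_opt = optim.Adam(net.parameters(), lr=1e-3)
inner_opt = optim.SGD(net.parameters(), lr=1e-1)  # post-patch inner lr
n_inner_iter = 1  # post-patch value in train()

# Placeholder data for a single 5-way task (5 support, 25 query images).
x_spt, y_spt = torch.randn(5, 1, 28, 28), torch.randint(0, 5, (5,))
x_qry, y_qry = torch.randn(25, 1, 28, 28), torch.randint(0, 5, (25,))

meta_opt.zero_grad()
# copy_initial_weights=False keeps net's own parameters at the start of
# the tape, so the query loss can backpropagate into them.
with higher.innerloop_ctx(
        net, inner_opt, copy_initial_weights=False) as (fnet, diffopt):
    # Adapt the functional copy on the support set; diffopt.step produces
    # new parameters while keeping the update differentiable.
    for _ in range(n_inner_iter):
        diffopt.step(F.cross_entropy(fnet(x_spt), y_spt))
    # The query loss on the adapted parameters backpropagates through the
    # inner updates into net's initial parameters.
    F.cross_entropy(fnet(x_qry), y_qry).backward()
meta_opt.step()

Because the inner update is unrolled through diffopt, the meta-gradient sees the inner learning rate directly, which is why tuning it (4e-1 -> 1e-1) changes the outer optimization and not just the per-task fit.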
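The final hunk adds .detach() to the query loss in test(). Since qry_logits is already detached, the extra call is defensive; the underlying concern is that tensors appended to qry_losses should not keep the unrolled inner-loop graph alive while they accumulate across test tasks. A short sketch of that evaluation pattern, reusing net, inner_opt, n_inner_iter, and the dummy task tensors from the previous sketch; track_higher_grads=False additionally tells higher not to build the differentiable unroll at all, the cheaper option when no meta-gradient is needed:

# Evaluation-side adaptation: no meta-gradient is needed, so the
# unrolled graph can be dropped and stored tensors detached.
qry_losses = []
with higher.innerloop_ctx(
        net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
    for _ in range(n_inner_iter):
        diffopt.step(F.cross_entropy(fnet(x_spt), y_spt))
    qry_logits = fnet(x_qry).detach()
    # reduction='none' keeps per-sample losses, as in the patch.
    qry_losses.append(
        F.cross_entropy(qry_logits, y_qry, reduction='none').detach())
print(torch.cat(qry_losses).mean().item())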