Advertisement
Guest User

Untitled

a guest
Apr 24th, 2019
81
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.33 KB | None | 0 0
  1. %matplotlib inline
  2. from matplotlib import pyplot as plt
  3. from tqdm import tqdm_notebook as tqdm
  4. import torch
  5.  
  6. x = torch.randn(1000, 1)
  7. x = x.sort(0)[0]
  8. y = x.mul(3).cos() + torch.randn(1000, 1) * 0.2
  9.  
  10. dropout = False
  11.  
  12. net = torch.nn.Sequential(
  13. torch.nn.Linear(1, 128),
  14. torch.nn.ReLU(),
  15. torch.nn.Dropout(0.2 if dropout else 0),
  16. torch.nn.Linear(128, 128),
  17. torch.nn.ReLU(),
  18. torch.nn.Dropout(0.2 if dropout else 0),
  19. torch.nn.Linear(128, 1))
  20.  
  21. opt = torch.optim.Adam(net.parameters())
  22.  
  23. for _ in tqdm(range(1000)):
  24. if dropout == False:
  25. p = torch.randperm(x.size(0))
  26. a = torch.rand(len(x), 1)
  27. xx = a * x + (1 - a) * x[p]
  28. yy = a * y + (1 - a) * y[p]
  29. else:
  30. xx = x
  31. yy = y
  32.  
  33. opt.zero_grad()
  34. (net(xx) - yy).pow(2).mean().backward()
  35. opt.step()
  36.  
  37. if dropout == False:
  38. net.eval()
  39.  
  40. x_test = torch.linspace(-5, 5, 100).view(-1, 1)
  41. for _ in range(10):
  42. if dropout:
  43. preds = net(x_test)
  44. else:
  45. a = .7 # torch.rand(len(x_test), 1)
  46. p = torch.randperm(len(x))[0]
  47.  
  48. xx = a * x_test + (1 - a) * x[p].view(1, -1)
  49. preds = (net(xx) - (1 - a) * y[p].view(1, 1)) / a
  50.  
  51. plt.plot(x_test.numpy(), preds.detach().numpy(), color="C0", alpha=0.2)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement