Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- %matplotlib inline
- from matplotlib import pyplot as plt
- from tqdm import tqdm_notebook as tqdm
- import torch
# Synthetic 1-D regression set: 1000 scalar inputs drawn from a standard
# normal and sorted ascending along dim 0, with targets cos(3x) corrupted by
# additive Gaussian noise of standard deviation 0.2.
# Both tensors have shape (1000, 1). The inputs are drawn before the noise.
x = torch.randn(1000, 1).sort(dim=0).values
y = torch.cos(3 * x) + 0.2 * torch.randn(1000, 1)
# Regularisation switch read by the rest of the script.
#   True:  train the MLP with dropout (p=0.2) after each hidden ReLU.
#   False: the dropout layers are no-ops (p=0), and the training loop
#          applies mixup to the data instead.
dropout = False
drop_prob = 0.2 if dropout else 0
width = 128  # hidden-layer size

# 1 -> width -> width -> 1 MLP.
# Layers stay positional (indices 0..6), so parameter order and state_dict
# keys are those of a plain Sequential.
net = torch.nn.Sequential(
    torch.nn.Linear(1, width),
    torch.nn.ReLU(),
    torch.nn.Dropout(drop_prob),
    torch.nn.Linear(width, width),
    torch.nn.ReLU(),
    torch.nn.Dropout(drop_prob),
    torch.nn.Linear(width, 1),
)

# Adam over all network parameters, with PyTorch's default hyper-parameters.
opt = torch.optim.Adam(net.parameters())
def _mixup(inputs, targets):
    """Return a mixup copy of ``(inputs, targets)``.

    Row i becomes ``a_i * row_i + (1 - a_i) * row_perm(i)``, where:
      * ``perm`` is a fresh random permutation of the rows;
      * each ``a_i`` is drawn uniformly from [0, 1).

    Inputs and targets share the same permutation and weights, so the mixed
    pairs stay consistent. The weights have shape (N, 1), which assumes 2-D
    (N, features) tensors such as the (1000, 1) ``x``/``y`` above.
    """
    perm = torch.randperm(inputs.size(0))
    lam = torch.rand(len(inputs), 1)
    mixed_inputs = lam * inputs + (1 - lam) * inputs[perm]
    mixed_targets = lam * targets + (1 - lam) * targets[perm]
    return mixed_inputs, mixed_targets


# Full-batch training: 1000 Adam steps on the mean squared error.
# When dropout is disabled, the model is regularised with mixup instead:
# every step sees a freshly mixed copy of the whole training set.

# Explicit: a previously run evaluation cell may have called net.eval().
net.train()

for _ in tqdm(range(1000)):
    if dropout:
        xx, yy = x, y
    else:
        xx, yy = _mixup(x, y)
    opt.zero_grad()
    (net(xx) - yy).pow(2).mean().backward()
    opt.step()
# Draw 10 predictive samples on a test grid and overlay them as faint lines.
#
# dropout=True:
#   The network stays in training mode, so dropout is active and every
#   forward pass is a different stochastic sample (MC dropout).
#
# dropout=False:
#   Mixup-style inference. Each test input is mixed with one randomly chosen
#   training pair (x_p, y_p) at weight `a`. The target mixing is then undone:
#       f(a*x + (1-a)*x_p) ~ a*y + (1-a)*y_p
#       =>  y ~ (f(...) - (1-a)*y_p) / a
#   Different training partners yield different predictions.

if dropout:
    net.train()  # keep dropout stochastic at test time
else:
    net.eval()

x_test = torch.linspace(-5, 5, 100).view(-1, 1)

# Fixed mixing weight on the test input. It must be non-zero because we
# divide by it. A per-point random weight, torch.rand(len(x_test), 1),
# is a possible variant.
a = 0.7

with torch.no_grad():  # inference only: don't build autograd graphs
    for _ in range(10):
        if dropout:
            preds = net(x_test)
        else:
            p = torch.randint(len(x), ()).item()  # one random training index
            xx = a * x_test + (1 - a) * x[p].view(1, -1)
            preds = (net(xx) - (1 - a) * y[p].view(1, 1)) / a
        plt.plot(x_test.numpy(), preds.numpy(), color="C0", alpha=0.2)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement