Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # Code in file nn/two_layer_net_nn.py
- import torch
# Number of features in the random input row created further down.
D_in = 40
# Defined before the model is loaded so the checkpoint can be mapped onto the
# same device that the input tensor is allocated on.
device = torch.device('cpu')
# NOTE(security): torch.load unpickles the file and can therefore execute
# arbitrary code; only load 'model.pytorch' if it comes from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled modules -- confirm against the installed version.
model = torch.load('model.pytorch', map_location=device)
def loss1(y_pred, x):
    """Return the sum of ``y_pred`` weighted by ``0.5 - clamp(x, 0, 1)``.

    Args:
        y_pred: Tensor of model outputs; must broadcast against ``x``.
        x: Input tensor; its values are clamped to ``[0, 1]`` before use.

    Returns:
        A zero-dimensional tensor holding the summed, weighted outputs.
    """
    weights = 0.5 - x.clamp(0, 1)
    return (y_pred * weights).sum()
def loss2(y_pred, x):
    """Return the sum of ``y_pred`` weighted by ``1 - clamp(x, 0, 1)``.

    Args:
        y_pred: Tensor of model outputs; must broadcast against ``x``.
        x: Input tensor; its values are clamped to ``[0, 1]`` before use.

    Returns:
        A zero-dimensional tensor holding the summed, weighted outputs.
    """
    weights = 1 - x.clamp(0, 1)
    return (y_pred * weights).sum()
- # Predict random input
- x = torch.rand(1,D_in, device=device,requires_grad=True)
- y_pred = model(x)
- # Is this
- %%timeit
- loss = loss1(y_pred,x)
- loss.backward(retain_graph=True)
- 202 µs ± 4.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
- # Slower than this?
- %%timeit
- loss = loss2(y_pred,x)
- loss.backward(retain_graph=True)
- 216 µs ± 27.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
- # Are successive backwards calls cheap?
- loss = lossX(y_pred,x)
- loss.backward(retain_graph=True)
- import numpy as np
- import torch
- import torch.nn as nn
- import time
- import os
- import psutil
# Number of features in each random input row.
D_in = 1024
# The first and last layer sizes are tied to D_in: the loss functions multiply
# the model output elementwise with the (1, D_in) input, so both ends of the
# network have to follow D_in if it is ever changed.
model = nn.Sequential(
    nn.Linear(D_in, 4096),
    nn.ReLU(),
    nn.Linear(4096, 4096),
    nn.ReLU(),
    nn.Linear(4096, D_in),
)
device = torch.device('cpu')
def loss1(y_pred, x):
    """Return the sum of ``y_pred`` weighted by ``0.5 - clamp(x, 0, 1)``.

    Args:
        y_pred: Tensor of model outputs; must broadcast against ``x``.
        x: Input tensor; its values are clamped to ``[0, 1]`` before use.

    Returns:
        A zero-dimensional tensor holding the summed, weighted outputs.
    """
    weights = 0.5 - x.clamp(0, 1)
    return (y_pred * weights).sum()
def loss2(y_pred, x):
    """Return the sum of ``y_pred`` weighted by ``1 - clamp(x, 0, 1)``.

    Args:
        y_pred: Tensor of model outputs; must broadcast against ``x``.
        x: Input tensor; its values are clamped to ``[0, 1]`` before use.

    Returns:
        A zero-dimensional tensor holding the summed, weighted outputs.
    """
    weights = 1 - x.clamp(0, 1)
    return (y_pred * weights).sum()
def timeit(func, repetitions: int) -> list:
    """Call ``func`` repeatedly and summarise its run time and return value.

    Args:
        func: Zero-argument callable returning a number (the benchmarks in
            this file return the process RSS in bytes).
        repetitions: How many times to call ``func``; must be at least 1.

    Returns:
        A list of six floats, each rounded to 4 decimals:
        ``[time mean, time min, time max, value mean, value min, value max]``
        with the times in seconds.

    Raises:
        ValueError: If ``repetitions`` is less than 1.
    """
    if repetitions < 1:
        # np.min/np.max cannot reduce an empty sequence; fail early and clearly.
        raise ValueError("repetitions must be at least 1")
    time_taken = []
    mem_used = []
    for _ in range(repetitions):
        # perf_counter is monotonic and high resolution, which makes it a
        # better fit for measuring short intervals than wall-clock time.time().
        time_start = time.perf_counter()
        mem_used.append(func())
        time_taken.append(time.perf_counter() - time_start)
    stats = [
        np.mean(time_taken), np.min(time_taken), np.max(time_taken),
        np.mean(mem_used), np.min(mem_used), np.max(mem_used),
    ]
    return np.round(stats, 4).tolist()
# Predict random input: a single row of D_in uniform random values that
# tracks gradients; init() below reads this module-level tensor.
x = torch.rand(1, D_in, device=device, requires_grad=True)
def init():
    """Run one forward and backward pass on the module-level ``x``.

    Called once before the timed benchmarks. Uses the module-level ``model``
    and ``x``; the backward pass accumulates gradients into both as a side
    effect. Returns nothing.
    """
    out = model(x)
    loss = loss1(out, x)
    loss.backward()
def func1():
    """Benchmark case: two separate forward/backward passes, default backward.

    A fresh random input is pushed through ``model`` twice; ``loss1`` and
    ``loss2`` are each back-propagated on their own without ``retain_graph``.

    Returns:
        The resident set size (RSS) of the current process, in bytes,
        sampled after both backward passes.
    """
    x = torch.rand(1, D_in, device=device, requires_grad=True)
    loss = loss1(model(x), x)
    loss.backward()
    loss = loss2(model(x), x)
    loss.backward()
    # Drop the local reference to the input before sampling memory.
    del x
    process = psutil.Process(os.getpid())
    return process.memory_info().rss
def func2():
    """Benchmark case: summed loss with a single default backward pass.

    A fresh random input is pushed through ``model`` twice; ``loss1`` and
    ``loss2`` are added together and back-propagated once without
    ``retain_graph``.

    Returns:
        The resident set size (RSS) of the current process, in bytes,
        sampled after the backward pass.
    """
    x = torch.rand(1, D_in, device=device, requires_grad=True)
    loss = loss1(model(x), x) + loss2(model(x), x)
    loss.backward()
    # Drop the local reference to the input before sampling memory.
    del x
    process = psutil.Process(os.getpid())
    return process.memory_info().rss
def func3():
    """Benchmark case: two separate forward/backward passes, graphs retained.

    Same sequence as ``func1`` except that both backward calls pass
    ``retain_graph=True``.

    Returns:
        The resident set size (RSS) of the current process, in bytes,
        sampled after both backward passes.
    """
    x = torch.rand(1, D_in, device=device, requires_grad=True)
    loss = loss1(model(x), x)
    loss.backward(retain_graph=True)
    loss = loss2(model(x), x)
    loss.backward(retain_graph=True)
    # Drop the local reference to the input before sampling memory.
    del x
    process = psutil.Process(os.getpid())
    return process.memory_info().rss
def func4():
    """Benchmark case: summed loss with a single backward pass, graph retained.

    Same sequence as ``func2`` except that the backward call passes
    ``retain_graph=True``.

    Returns:
        The resident set size (RSS) of the current process, in bytes,
        sampled after the backward pass.
    """
    x = torch.rand(1, D_in, device=device, requires_grad=True)
    loss = loss1(model(x), x) + loss2(model(x), x)
    loss.backward(retain_graph=True)
    # Drop the local reference to the input before sampling memory.
    del x
    process = psutil.Process(os.getpid())
    return process.memory_info().rss
# One untimed forward/backward pass before measuring anything.
init()
# Run every benchmark 100 times and print one summary row per benchmark:
# [time mean, time min, time max, memory mean, memory min, memory max].
for benchmark in (func1, func2, func3, func4):
    print(timeit(benchmark, 100))
- # Columns: time mean, time min, time max (seconds); memory mean, memory min, memory max (RSS in bytes)
- [0.1165, 0.1138, 0.1297, 383456419.84, 365731840.0, 384438272.0]
- [0.127, 0.1233, 0.1376, 400914759.68, 399638528.0, 434044928.0]
- [0.1167, 0.1136, 0.1272, 400424468.48, 399577088.0, 401223680.0]
- [0.1263, 0.1226, 0.134, 400815964.16, 399556608.0, 434307072.0]
- # Columns: time mean, time min, time max (seconds); memory mean, memory min, memory max (RSS in bytes)
- [0.1208, 0.1136, 0.1579, 350157455.36, 349331456.0, 350978048.0]  # the max time 0.1579 was emphasised in the original paste
- [0.1297, 0.1232, 0.1499, 393928540.16, 350052352.0, 401854464.0]
- [0.1197, 0.1152, 0.1547, 350787338.24, 349982720.0, 351629312.0]
- [0.1335, 0.1229, 0.1793, 382819123.2, 349929472.0, 401776640.0]
Add Comment
Please, Sign In to add comment