Guest User

Untitled

a guest
Sep 26th, 2018
103
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.45 KB | None | 0 0
  1. # Code in file nn/two_layer_net_nn.py
  2. import torch
  3.  
  4. D_in = 40
  5. model = torch.load('model.pytorch')
  6. device = torch.device('cpu')
  7. def loss1(y_pred,x):
  8. return (y_pred*(0.5-x.clamp(0,1))).sum()
  9.  
  10. def loss2(y_pred,x):
  11. return (y_pred*(1-x.clamp(0,1))).sum()
  12.  
  13.  
  14. # Predict random input
  15. x = torch.rand(1,D_in, device=device,requires_grad=True)
  16. y_pred = model(x)
  17.  
  18. # Is this
  19. %%timeit
  20. loss = loss1(y_pred,x)
  21. loss.backward(retain_graph=True)
  22.  
  23. 202 µs ± 4.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
  24.  
  25. # Slower than this?
  26. %%timeit
  27. loss = loss2(y_pred,x)
  28. loss.backward(retain_graph=True)
  29.  
  30. 216 µs ± 27.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
  31.  
  32.  
  33. # Are successive backwards calls cheap?
  34. loss = lossX(y_pred,x)
  35. loss.backward(retain_graph=True)
  36.  
  37. import numpy as np
  38. import torch
  39. import torch.nn as nn
  40. import time
  41. import os
  42. import psutil
  43. D_in = 1024
  44. model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 1024))
  45. device = torch.device('cpu')
  46. def loss1(y_pred,x):
  47. return (y_pred*(0.5-x.clamp(0,1))).sum()
  48.  
  49.  
  50. def loss2(y_pred,x):
  51. return (y_pred*(1-x.clamp(0,1))).sum()
  52.  
  53.  
  54. def timeit(func, repetitions):
  55. time_taken = []
  56. mem_used = []
  57. for _ in range(repetitions):
  58. time_start = time.time()
  59. mem_used.append(func())
  60. time_taken.append(time.time() - time_start)
  61. return np.round([np.mean(time_taken), np.min(time_taken), np.max(time_taken),
  62. np.mean(mem_used), np.min(mem_used), np.max(mem_used)], 4).tolist()
  63.  
  64.  
  65. # Predict random input
  66. x = torch.rand(1,D_in, device=device,requires_grad=True)
  67.  
  68. def init():
  69. out = model(x)
  70. loss = loss1(out, x)
  71. loss.backward()
  72.  
  73. def func1():
  74. x = torch.rand(1, D_in, device=device, requires_grad=True)
  75. loss = loss1(model(x),x)
  76. loss.backward()
  77. loss = loss2(model(x),x)
  78. loss.backward()
  79. del x
  80. process = psutil.Process(os.getpid())
  81. return process.memory_info().rss
  82.  
  83. def func2():
  84. x = torch.rand(1, D_in, device=device, requires_grad=True)
  85. loss = loss1(model(x),x) + loss2(model(x),x)
  86. loss.backward()
  87. del x
  88. process = psutil.Process(os.getpid())
  89. return process.memory_info().rss
  90.  
  91.  
  92. def func3():
  93. x = torch.rand(1, D_in, device=device, requires_grad=True)
  94. loss = loss1(model(x),x)
  95. loss.backward(retain_graph=True)
  96. loss = loss2(model(x),x)
  97. loss.backward(retain_graph=True)
  98. del x
  99. process = psutil.Process(os.getpid())
  100. return process.memory_info().rss
  101.  
  102.  
  103. def func4():
  104. x = torch.rand(1, D_in, device=device, requires_grad=True)
  105. loss = loss1(model(x),x) + loss2(model(x),x)
  106. loss.backward(retain_graph=True)
  107. del x
  108. process = psutil.Process(os.getpid())
  109. return process.memory_info().rss
  110.  
  111. init()
  112. print(timeit(func1, 100))
  113. print(timeit(func2, 100))
  114. print(timeit(func3, 100))
  115. print(timeit(func4, 100))
  116.  
  117. # time mean, time min, time max, memory mean, memory min, memory max
  118. [0.1165, 0.1138, 0.1297, 383456419.84, 365731840.0, 384438272.0]
  119. [0.127, 0.1233, 0.1376, 400914759.68, 399638528.0, 434044928.0]
  120. [0.1167, 0.1136, 0.1272, 400424468.48, 399577088.0, 401223680.0]
  121. [0.1263, 0.1226, 0.134, 400815964.16, 399556608.0, 434307072.0]
  122.  
  123. # time mean, time min, time max, memory mean, memory min, memory max
  124. [0.1208, 0.1136, 0.1579, 350157455.36, 349331456.0, 350978048.0]
  125. [0.1297, 0.1232, 0.1499, 393928540.16, 350052352.0, 401854464.0]
  126. [0.1197, 0.1152, 0.1547, 350787338.24, 349982720.0, 351629312.0]
  127. [0.1335, 0.1229, 0.1793, 382819123.2, 349929472.0, 401776640.0]
Add Comment
Please, Sign In to add comment