Advertisement
Guest User

Untitled

a guest
Oct 26th, 2017
288
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.67 KB | None | 0 0
  1. import torch
  2. import time
  3. from torch.autograd import Variable
  4. import numpy as np
  5.  
  6. """
  7. This program gets slower and slower by time without the c = z.detach()...
  8. No noticeable change in memory consumption seen on nvidia-smi.
  9. """
  10.  
  11. usecuda = 1
  12.  
  13. dtype = torch.FloatTensor
  14. N=1
  15. if usecuda:
  16.   N = 100*N
  17.   dtype = torch.cuda.FloatTensor
  18.  
  19. xdim = (100,50,2)
  20.  
  21. x0 = torch.randn(*xdim).type(dtype)
  22. z0 = torch.randn(N, xdim[0]).type(dtype)
  23.  
  24. x = Variable(x0, requires_grad=True)
  25. z = Variable(z0, requires_grad=False)
  26.  
  27.  
  28. t0 = time.clock()
  29. tsum = np.zeros(5)
  30. tsump = np.zeros(5)
  31. i = 0
  32. M=100
  33. times = []
  34. for i in range(M*30):
  35.   t1 = time.clock()
  36.   b = x.repeat(N,1,1,1).view(N, -1, 2).sum(2).view(N, xdim[0], xdim[1])
  37.   c = z.detach().view(N, xdim[0], 1).expand(N, xdim[0], xdim[1])/b
  38.   loss = c.sum()
  39.   t2 = time.clock()
  40.   loss.backward()
  41.   t3 = time.clock()
  42.   z -= 0.1*b.view(N, -1 , xdim[1]).lt(0.1).sum(2).type(dtype)
  43.   t4 = time.clock()
  44.   x.data -= 0.0000000001*x.grad.data
  45.   x.grad.data.zero_()
  46.   t5 = time.clock()
  47.   tsum += np.array((t1 - t0, t2 - t1, t3 - t2, t4 - t3, t5 - t4))
  48.   t0 = t5
  49.   if i%M == M-1:
  50.     times.append((tsum, tsum - tsump))
  51.     print(tsum, tsum - tsump)
  52.     #print("%.2fs +%.1f%%" % (tsum, tsum*100/tsump - 100))
  53.     tsump = tsum
  54.     tsum = np.zeros(5)
  55.   i = i + 1
  56.  
  57. from matplotlib import pyplot as plt
  58. t = np.arange(len(times))
  59. plt.plot(t, np.array(times)[:,0,0], 'r', label='t0..1')
  60. plt.plot(t, np.array(times)[:,0,1], 'g', label='t1..2')
  61. plt.plot(t, np.array(times)[:,0,2], 'b', label='t2..3')
  62. plt.plot(t, np.array(times)[:,0,3], 'y', label='t3..4')
  63. plt.plot(t, np.array(times)[:,0,4], 'w', label='t4..5')
  64. plt.legend(loc=2)
  65. plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement