Advertisement
Guest User

Untitled

a guest
Oct 26th, 2017
214
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import torch
  2. import time
  3. from torch.autograd import Variable
  4. import numpy as np
  5.  
  6. """
  7. This program gets slower and slower by time without the c = z.detach()...
  8. No noticeable change in memory consumption seen on nvidia-smi.
  9. """
  10.  
  11. usecuda = 1
  12.  
  13. dtype = torch.FloatTensor
  14. N=1
  15. if usecuda:
  16.   N = 100*N
  17.   dtype = torch.cuda.FloatTensor
  18.  
  19. xdim = (100,50,2)
  20.  
  21. x0 = torch.randn(*xdim).type(dtype)
  22. z0 = torch.randn(N, xdim[0]).type(dtype)
  23.  
  24. x = Variable(x0, requires_grad=True)
  25. z = Variable(z0, requires_grad=False)
  26.  
  27.  
  28. t0 = time.clock()
  29. tsum = np.zeros(5)
  30. tsump = np.zeros(5)
  31. i = 0
  32. M=100
  33. times = []
  34. for i in range(M*30):
  35.   t1 = time.clock()
  36.   b = x.repeat(N,1,1,1).view(N, -1, 2).sum(2).view(N, xdim[0], xdim[1])
  37.   c = z.detach().view(N, xdim[0], 1).expand(N, xdim[0], xdim[1])/b
  38.   loss = c.sum()
  39.   t2 = time.clock()
  40.   loss.backward()
  41.   t3 = time.clock()
  42.   z -= 0.1*b.view(N, -1 , xdim[1]).lt(0.1).sum(2).type(dtype)
  43.   t4 = time.clock()
  44.   x.data -= 0.0000000001*x.grad.data
  45.   x.grad.data.zero_()
  46.   t5 = time.clock()
  47.   tsum += np.array((t1 - t0, t2 - t1, t3 - t2, t4 - t3, t5 - t4))
  48.   t0 = t5
  49.   if i%M == M-1:
  50.     times.append((tsum, tsum - tsump))
  51.     print(tsum, tsum - tsump)
  52.     #print("%.2fs +%.1f%%" % (tsum, tsum*100/tsump - 100))
  53.     tsump = tsum
  54.     tsum = np.zeros(5)
  55.   i = i + 1
  56.  
  57. from matplotlib import pyplot as plt
  58. t = np.arange(len(times))
  59. plt.plot(t, np.array(times)[:,0,0], 'r', label='t0..1')
  60. plt.plot(t, np.array(times)[:,0,1], 'g', label='t1..2')
  61. plt.plot(t, np.array(times)[:,0,2], 'b', label='t2..3')
  62. plt.plot(t, np.array(times)[:,0,3], 'y', label='t3..4')
  63. plt.plot(t, np.array(times)[:,0,4], 'w', label='t4..5')
  64. plt.legend(loc=2)
  65. plt.show()
Advertisement
RAW Paste Data Copied
Advertisement