Guest User

Untitled

a guest
May 22nd, 2018
97
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.70 KB | None | 0 0
  1. import numpy as np
  2. import tvm
  3. import topi
  4. import torch
  5. from topi.util import get_const_tuple
  6.  
  7. def verify_pool(n, ic, ih, kh, sh, padding, ceil_mode, count_include_pad=True):
  8. iw = ih
  9. kw = kh
  10. sw = sh
  11. A = tvm.placeholder((n, ic, ih, iw), name='A')
  12. B = topi.nn.pool(A, kernel=[kh, kw], stride=[sh, sw], padding=padding,
  13. pool_type='avg', ceil_mode=ceil_mode, count_include_pad=count_include_pad)
  14. B = topi.nn.relu(B)
  15. dtype = A.dtype
  16.  
  17. a_np = np.random.uniform(size=(n, ic, ih, iw)).astype(dtype)
  18.  
  19. avg_pool = torch.nn.AvgPool2d(kernel_size=(kh, kh), stride=(sh, sh), padding=padding, ceil_mode=ceil_mode,
  20. count_include_pad=count_include_pad)
  21. b_torch_np = avg_pool(torch.Tensor(a_np)).numpy()
  22.  
  23. def check_device(device):
  24. ctx = tvm.context(device, 0)
  25. if not ctx.exist:
  26. print("Skip because %s is not enabled" % device)
  27. return
  28. print("Running on target: %s" % device)
  29. with tvm.target.create(device):
  30. s = topi.generic.schedule_pool(B)
  31.  
  32. a = tvm.nd.array(a_np, ctx)
  33. b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
  34. f = tvm.build(s, [A, B], device)
  35. f(a, b)
  36. np.testing.assert_allclose(b.asnumpy(), b_torch_np, rtol=1e-5)
  37.  
  38. for device in ['cuda', 'llvm']:
  39. check_device(device)
  40.  
  41. def test_pool():
  42. verify_pool(1, 256, 32, 2, 2, [0, 0], False, True)
  43. verify_pool(1, 256, 31, 4, 4, [1, 2], False, True)
  44. verify_pool(1, 256, 32, 4, 4, [1, 2], False, False)
  45. verify_pool(1, 256, 31, 6, 6, [3, 3], False, False)
  46. verify_pool(1, 256, 31, 6, 6, [0, 0], False, False)
  47.  
  48. if __name__ == "__main__":
  49. test_pool()
Add Comment
Please sign in to add a comment.