Guest User

Untitled

a guest
Oct 19th, 2018
104
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.00 KB | None | 0 0
  1. #include <torch/extension.h>
  2. #include <cmath>
  3. #include <iostream>
  4. #include <vector>
  5.  
  6.  
  7. at::Tensor ex_forward(
  8. at::Tensor input
  9. ) {
  10. auto n_samples = input.size(0);
  11. auto n_features = input.size(1);
  12. auto G = n_features / 2;
  13. auto M = 2;
  14.  
  15. at::Tensor temp = at::zeros({n_samples, G, 2});
  16. at::Tensor slice1 = input.slice(1, 0, n_features, 2) + input.slice(1, 1, n_features, 2);
  17. at::Tensor slice2 = input.slice(1, 0, n_features, 2) - input.slice(1, 1, n_features, 2);
  18. temp = at::stack({slice1, slice2}, 2);
  19.  
  20. auto res = temp;
  21. for (auto dumb_idx = 0; dumb_idx < std::log2(n_features) + 1; dumb_idx++) {
  22. temp = at::zeros({n_samples, G / 2, M * 2});
  23. slice1 = res.slice(2, 0, M, 2).slice(1, 0, G, 2);
  24. slice2 = res.slice(2, 0, M, 2).slice(1, 1, G, 2);
  25. auto mesh1 = at::meshgrid({at::_cast_Long(at::arange(0, n_samples, 1)), at::_cast_Long(at::arange(0, G/2, 1)), at::_cast_Long(at::arange(0, 2 * M, 4))});
  26. auto mesh2 = at::meshgrid({at::_cast_Long(at::arange(0, n_samples, 1)), at::_cast_Long(at::arange(0, G/2, 1)), at::_cast_Long(at::arange(1, 2 * M, 4))});
  27. temp.index_put_(mesh1, slice1 + slice2);
  28. temp.index_put_(mesh2, slice1 - slice2);
  29. slice1 = res.slice(2, 1, M, 2).slice(1, 0, G, 2);
  30. slice2 = res.slice(2, 1, M, 2).slice(1, 1, G, 2);
  31. mesh1 = at::meshgrid({at::_cast_Long(at::arange(0, n_samples, 1)), at::_cast_Long(at::arange(0, G/2, 1)), at::_cast_Long(at::arange(2, 2 * M, 4))});
  32. mesh2 = at::meshgrid({at::_cast_Long(at::arange(0, n_samples, 1)), at::_cast_Long(at::arange(0, G/2, 1)), at::_cast_Long(at::arange(3, 2 * M, 4))});
  33. temp.index_put_(mesh1, slice1 - slice2);
  34. temp.index_put_(mesh2, slice1 + slice2);
  35. res = temp;
  36. G = G / 2;
  37. M = M * 2;
  38. }
  39. at::Tensor output = temp.select(1, 0); // select index 0 along dim 1
  40. return output * pow(std::sqrt(n_features), -1);
  41. }
  42.  
  43.  
  44. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  45. m.def("forward", &ex_forward, "EX forward");
  46. }
Add Comment
Please, Sign In to add comment