Guest User

Untitled

a guest
Jun 13th, 2018
64
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.25 KB | None | 0 0
  #include "pybind11/numpy.h" // vectorize
  #include "pybind11/pybind11.h"
  #include "pybind11/operators.h"
  #include "numpy/arrayobject.h"
  #include "ndarray/pybind11.h"
  #include <inttypes.h>
  #include <cmath>
  #include <stdexcept>
  #include "xtensor/xarray.hpp"
  #include "xtensor/xtensor.hpp"
  #include "xtensor/xcontainer.hpp"
  #include "xtensor-python/pyarray.hpp"
  #include "xtensor-python/pyvectorize.hpp"
  #include "xtensor/xview.hpp"
  #include "xtensor-blas/xlinalg.hpp"
  #include "xtensor/xnorm.hpp"
  #include "xtensor/xeval.hpp"
  17.  
  18. namespace py = pybind11;
  19. #include <complex>
  20. #include <algorithm>
  21. //#include "apply.hpp"
  22.  
  /// Unary functor that squares its argument.
  ///
  /// For a real scalar this is the squared magnitude. `argument_type` and
  /// `result_type` are both `flt_t`.
  template <typename flt_t>
  struct mag_sqr {
      using argument_type = flt_t;
      using result_type = flt_t;

      /// Returns `x * x`.
      result_type operator()(argument_type x) const {
          return x * x;
      }
  };

  // Short names for the double- and single-precision complex types.
  using complex_t = std::complex<double>;
  using complex64_t = std::complex<float>;

  /// Partial specialization for `std::complex<flt_t>`.
  ///
  /// Takes a complex value and returns the real-valued squared magnitude
  /// `re*re + im*im`, so `result_type` is the underlying `flt_t`.
  template <typename flt_t>
  struct mag_sqr<std::complex<flt_t>> {
      using argument_type = std::complex<flt_t>;
      using result_type = flt_t;

      /// Returns `real(x)^2 + imag(x)^2`.
      result_type operator()(argument_type x) const {
          const flt_t re = x.real();
          const flt_t im = x.imag();
          return re * re + im * im;
      }
  };
  41.  
  42. // template <typename T, int N, int C>
  43. // nd::Array<typename mag_sqr<T>::result_type,N,N> do_mag_sqr_flat (nd::Array<T,N,C> const& in) {
  44. // ndarray::Array<typename boost::remove_const<T>::type,1,1> flat_in = ndarray::flatten<1>(in);
  45. // ndarray::Array<typename mag_sqr<T>::result_type,N,N> out = ndarray::allocate(in.getShape());
  46. // ndarray::Array<typename mag_sqr<T>::result_type,1,1> flat_out = ndarray::flatten<1>(out);
  47. // int size = flat_in.template getSize<0>();
  48. // for (int n=0; n < size; ++n) {
  49. // flat_out[n] = mag_sqr<T>() (flat_in[n]);
  50. // }
  51.  
  52. // return out;
  53. // }
  54.  
  55. // template<typename flt_t>
  56. // flt_t do_mag_sqr (flt_t x) { return x * x; }
  57.  
  58. // template<typename flt_t>
  59. // flt_t do_mag_sqr (std::complex<flt_t> x) { return real(x) * real(x) + imag(x) * imag(x); }
  60.  
  61. // template<typename flt_t>
  62. // auto mag_sqr_vec (py::array_t<std::complex<flt_t>> a) {
  63. // return py::vectorize([](std::complex<flt_t> x) { return real(x) * real(x) + imag(x) * imag(x); })(a);
  64. // }
  65. // template<typename flt_t>
  66. // auto mag_sqr_vec (py::array_t<flt_t> a) {
  67. // return py::vectorize([](flt_t x) { return x * x; })(a);
  68. // }
  69.  
  70. PYBIND11_PLUGIN (mag_sqr) {
  71. if (_import_array() < 0) {
  72. PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import");
  73. return nullptr;
  74. }
  75.  
  76. py::module m("mag_sqr", "pybind11 example plugin");
  77.  
  78. py::object float32_obj = py::module::import("numpy").attr("float32");
  79. py::object complex64_obj = py::module::import("numpy").attr("complex64");
  80.  
  81. m.def("mag_sqr", [](py::array_t<double> a) {
  82. // print ("double");
  83. return py::vectorize([](double x) {return x * x;})(a);
  84. },
  85. py::arg("in").noconvert()
  86. );
  87. m.def("mag_sqr", [](py::array_t<float> a) {
  88. // print ("float");
  89. return py::vectorize([](float x) {return x * x;})(a);
  90. },
  91. py::arg("in").noconvert()
  92. );
  93. m.def("mag_sqr", [](py::array_t<std::complex<double>> a) {
  94. // print ("complex");
  95. return py::vectorize([](std::complex<double> x) {return real(x) * real(x) + imag(x) * imag(x);})(a);
  96. },
  97. py::arg("in").noconvert()
  98. );
  99. m.def("mag_sqr", [](py::array_t<std::complex<float>> a) {
  100. return py::vectorize([](std::complex<float> x) {return real(x) * real(x) + imag(x) * imag(x);})(a);
  101. },
  102. py::arg("in").noconvert()
  103. );
  104. m.def("mag_sqr", [](double x) { return x * x; });
  105. m.def("mag_sqr", [](float x) { return x * x; });
  106. m.def("mag_sqr", [](std::complex<double> x) { return real(x) * real(x) + imag(x) * imag(x);});
  107. m.def("mag_sqr", [](std::complex<float> x) { return real(x) * real(x) + imag(x) * imag(x);});
  108.  
  109. m.def ("xt_norm_2", [](xt::pyarray<std::complex<double>> x) { return xt::norm_l2 (x)(); },
  110. py::arg("in").noconvert()
  111. );
  112. m.def("xt_norm_2", [](xt::pyarray<double> x) { return xt::norm_l2 (x)(); },
  113. py::arg("in").noconvert()
  114. );
  115. m.def("norm_2", [](xt::pyarray<std::complex<double>> x) { return xt::linalg::norm (x, 2); },
  116. py::arg("in").noconvert()
  117. );
  118. m.def("norm_2", [](xt::pyarray<double> x) { return xt::linalg::norm (x, 2); },
  119. py::arg("in").noconvert()
  120. );
  121. m.def("norm_2", [](double x) { return xt::linalg::norm(xt::pyarray<double>(x), 2); });
  122. m.def("norm_2", [](std::complex<double> x) { return xt::linalg::norm(xt::pyarray<std::complex<double>>(x), 2); });
  123. // m.def("norm_2", [](xt::pyarray<double> x) { return xt::linalg::norm (x); });
  124. // using py::print;
  125. // m.def("mag_sqr", [complex64_obj, float32_obj](py::object a) {
  126. // if (py::isinstance<py::array_t<complex_t>>(a) or py::isinstance<complex_t>(a)) {
  127. // print ("complex");
  128. // //return mag_sqr_vec<complex_t>(*static_cast<py::array_t<complex_t>*>(&a));
  129. // return py::vectorize([](complex_t x) { return real(x) * real(x) + imag(x) * imag(x); })(a);
  130. // }
  131. // else if (py::isinstance<py::array_t<double>>(a) or py::isinstance<py::float_>(a)) {
  132. // print ("double");
  133. // return py::vectorize([](double x) { return x * x; })(a);
  134. // }
  135. // else if (py::isinstance<py::array_t<complex64_t>>(a) or PyObject_IsInstance(a.ptr(), complex64_obj.ptr())) {
  136. // print ("complex64");
  137. // return py::vectorize([](complex64_t x) { return real(x) * real(x) + imag(x) * imag(x); })(a);
  138. // }
  139. // else if (py::isinstance<py::array_t<float>>(a) or PyObject_IsInstance(a.ptr(), float32_obj.ptr())) {
  140. // print ("float");
  141. // return py::vectorize([](float x) { return x * x; })(a);
  142. // }
  143. // else
  144. // throw py::type_error("mag_sqr unhandled type");
  145. // });
  146. //m.def ("mag_sqr", &apply_1d<mag_sqr,complex_t>);
  147. // m.def ("mag_sqr", &apply_1d<mag_sqr,complex64_t>);
  148.  
  149. // m.def ("mag_sqr", &apply_1d<mag_sqr,double>);
  150. // m.def ("mag_sqr", &apply_1d<mag_sqr,float>);
  151.  
  152. // m.def ("mag_sqr", &apply_2d<mag_sqr,complex_t>);
  153. // m.def ("mag_sqr", &apply_2d<mag_sqr,complex64_t>);
  154.  
  155. // m.def ("mag_sqr", &apply_2d<mag_sqr,double>);
  156. // m.def ("mag_sqr", &apply_2d<mag_sqr,float>);
  157.  
  158. // m.def ("mag_sqr", &apply_3d<mag_sqr,complex_t>);
  159. // m.def ("mag_sqr", &apply_3d<mag_sqr,complex64_t>);
  160.  
  161. // m.def ("mag_sqr", &apply_3d<mag_sqr,double>);
  162. // m.def ("mag_sqr", &apply_3d<mag_sqr,float>);
  163.  
  164. // m.def ("mag_sqr", &apply_scalar<mag_sqr,double>);
  165. // m.def ("mag_sqr", &apply_scalar<mag_sqr,complex_t>);
  166.  
  167. // m.def ("mag_sqr_flat", &do_mag_sqr_flat<double,1,1>);
  168. // m.def ("mag_sqr_flat", &do_mag_sqr_flat<complex_t,1,1> );
  169.  
  170. // // m.def ("norm_2", [](py::object o) {
  171. // // if (py::isinstance<py::array_t<double>>(o)) {
  172. // // py::print("double");
  173. // // return py::cast<nd::Array<double,1>>(o).asEigen().norm();
  174. // // }
  175.  
  176. // // if (py::isinstance<py::array_t<complex_t>>(o))
  177. // // return py::cast<nd::Array<complex_t,1>>(o).asEigen().norm();
  178. // // else
  179. // // throw py::type_error("norm_2 unhandled type");
  180. // // });
  181.  
  182. // m.def ("norm_2", [](nd::Array<double,1> in) {
  183. // // py::print("double");
  184. // return in.asEigen().norm();
  185. // });
  186.  
  187. // m.def ("norm_2", [](nd::Array<float,1> in) {
  188. // // py::print("float");
  189. // return in.asEigen().norm();
  190. // });
  191.  
  192. // m.def ("norm_2", [](nd::Array<complex_t,1> in) {
  193. // // py::print("complex");
  194. // return in.asEigen().norm();
  195. // });
  196.  
  197. return m.ptr();
  198. }
Add Comment
Please, Sign In to add comment