Advertisement
Guest User

Untitled

a guest
Apr 26th, 2017
86
0
Never
Not a member of Pastebin yet? Sign up — it unlocks many cool features!
text 2.26 KB | None | 0 0
  // numpy/arrayobject.h pulls in Python.h, which must precede any standard header.
  #include <numpy/arrayobject.h>
  #include "pybind11/pybind11.h"
  #include "pybind11/stl.h"
  #include "xtensor/xarray.hpp"
  #include "xtensor/xtensor.hpp"
  #include "xtensor/xcontainer.hpp"
  #include "xtensor/xbroadcast.hpp"
  //#include "xtensor/xbuilder.hpp"
  #include "xtensor/xview.hpp"
  #include "xtensor/xeval.hpp"
  #include "xtensor/xstridedview.hpp"
  #include "xtensor-python/pyarray.hpp"
  #include "xtensor-python/pytensor.hpp"

  #include <algorithm>  // std::find
  #include <cmath>      // std::isfinite
  #include <cstddef>    // std::size_t
  #include <stdexcept>  // std::invalid_argument, std::out_of_range
  #include <vector>

  namespace py = pybind11;
  17.  
  18. template<class E1>
  19. auto logsumexp1 (E1 const& e1) {
  20. using value_type = typename std::decay_t<E1>::value_type;
  21. auto max = xt::amax (e1)();
  22. return std::move (max + xt::log (xt::sum (xt::exp (e1-max))));
  23. }
  24.  
  25. template<class E1, class X>
  26. auto logsumexp2 (const E1& e1, X const& axes) {
  27. using value_type = typename std::decay_t<E1>::value_type;
  28. auto max = xt::eval(xt::amax(e1, axes));
  29. auto sv = xt::slice_vector(max);
  30. for (int i = 0; i < e1.dimension(); i++)
  31. {
  32. if (std::find (axes.begin(), axes.end(), i) != axes.end())
  33. sv.push_back(xt::newaxis());
  34. else
  35. sv.push_back(e1.shape()[i]);
  36. }
  37. auto max2 = xt::dynamic_view(max, sv);
  38. return xt::pyarray<value_type>(max2 + xt::log(xt::sum(xt::exp(e1 - max2), axes)));
  39. }
  40.  
  41. template<class value_type>
  42. auto normalize (xt::pyarray<value_type> const& e1) {
  43. auto shape = std::vector<size_t>{e1.shape().size()-1};
  44. auto ls = logsumexp2 (e1, shape);
  45. auto sv = xt::slice_vector(ls);
  46. for (int i = 0; i < e1.dimension()-1; i++)
  47. sv.push_back (xt::all());
  48. sv.push_back (xt::newaxis());
  49.  
  50. auto ls2 = xt::dynamic_view (ls, sv);
  51. return xt::pyarray<value_type> ((e1 - ls2));
  52. //return ls;
  53. }
  54.  
  55. PYBIND11_PLUGIN (logsumexp) {
  56. if (_import_array() < 0) {
  57. PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import");
  58. return nullptr;
  59. }
  60. py::module m("logsumexp", "pybind11 example plugin");
  61.  
  62. m.def("logsumexp", [](xt::pyarray<double>const& x) {
  63. return xt::pyarray<double> (xt::eval (logsumexp1 (x)));
  64. });
  65.  
  66. m.def("logsumexp", [](xt::pyarray<double>const& x, std::vector<size_t>const& ax) {
  67. //return xt::pyarray<double> ( (logsumexp2 (x, ax)));
  68. return logsumexp2 (x, ax);
  69. });
  70.  
  71. m.def("normalize", [](xt::pyarray<double>const& x) {
  72. return normalize (x);
  73. });
  74.  
  75. return m.ptr();
  76. }
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement