Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include <numpy/arrayobject.h>
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "xtensor/xarray.hpp"
#include "xtensor/xtensor.hpp"
#include "xtensor/xcontainer.hpp"
#include "xtensor/xbroadcast.hpp"
//#include "xtensor/xbuilder.hpp"
#include "xtensor/xview.hpp"
#include "xtensor/xeval.hpp"
#include "xtensor/xstridedview.hpp"
#include "xtensor-python/pyarray.hpp"
#include "xtensor-python/pytensor.hpp"
#include <algorithm>   // std::find
#include <cstddef>     // std::size_t
#include <stdexcept>   // std::invalid_argument
#include <type_traits> // std::decay_t
#include <vector>      // std::vector
- namespace py = pybind11;
- template<class E1>
- auto logsumexp1 (E1 const& e1) {
- using value_type = typename std::decay_t<E1>::value_type;
- auto max = xt::amax (e1)();
- return std::move (max + xt::log (xt::sum (xt::exp (e1-max))));
- }
- template<class E1, class X>
- auto logsumexp2 (const E1& e1, X const& axes) {
- using value_type = typename std::decay_t<E1>::value_type;
- auto max = xt::eval(xt::amax(e1, axes));
- auto sv = xt::slice_vector(max);
- for (int i = 0; i < e1.dimension(); i++)
- {
- if (std::find (axes.begin(), axes.end(), i) != axes.end())
- sv.push_back(xt::newaxis());
- else
- sv.push_back(e1.shape()[i]);
- }
- auto max2 = xt::dynamic_view(max, sv);
- return xt::pyarray<value_type>(max2 + xt::log(xt::sum(xt::exp(e1 - max2), axes)));
- }
- template<class value_type>
- auto normalize (xt::pyarray<value_type> const& e1) {
- auto shape = std::vector<size_t>{e1.shape().size()-1};
- auto ls = logsumexp2 (e1, shape);
- auto sv = xt::slice_vector(ls);
- for (int i = 0; i < e1.dimension()-1; i++)
- sv.push_back (xt::all());
- sv.push_back (xt::newaxis());
- auto ls2 = xt::dynamic_view (ls, sv);
- return xt::pyarray<value_type> ((e1 - ls2));
- //return ls;
- }
- PYBIND11_PLUGIN (logsumexp) {
- if (_import_array() < 0) {
- PyErr_SetString(PyExc_ImportError, "numpy.core.multiarray failed to import");
- return nullptr;
- }
- py::module m("logsumexp", "pybind11 example plugin");
- m.def("logsumexp", [](xt::pyarray<double>const& x) {
- return xt::pyarray<double> (xt::eval (logsumexp1 (x)));
- });
- m.def("logsumexp", [](xt::pyarray<double>const& x, std::vector<size_t>const& ax) {
- //return xt::pyarray<double> ( (logsumexp2 (x, ax)));
- return logsumexp2 (x, ax);
- });
- m.def("normalize", [](xt::pyarray<double>const& x) {
- return normalize (x);
- });
- return m.ptr();
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement