here2share

pybind11_tester

Apr 5th, 2021
727
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. ##### CMakeLists.txt
  # Builds the Python extension module `module_name` from main.cpp using pybind11.
  cmake_minimum_required(VERSION 3.4...3.18)

  # Only C++ is compiled here, so do not require (or probe for) a C compiler.
  project(pybindtest LANGUAGES CXX)

  # Expects a pybind11 source checkout in ./pybind11 next to this file.
  add_subdirectory(pybind11)

  # Creates the importable extension target; the name must match the first
  # argument of PYBIND11_MODULE in main.cpp.
  pybind11_add_module(module_name main.cpp)
  6.  
  7.  
  8.  
  9. ##### main.cpp
  #include <chrono>
  #include <cstdint>
  #include <iostream>
  #include <thread>
  #include <vector>

  #include <pybind11/numpy.h>
  #include <pybind11/pybind11.h>
  #include <pybind11/stl.h>
  16.  
  17. namespace py = pybind11;
  18.  
  /// Adds two single-precision values and returns their sum.
  /// Bound to Python as `some_fn_python_name` in the module definition.
  float some_fn(float arg1, float arg2) {
    const float sum = arg1 + arg2;
    return sum;
  }
  22.  
  23. class SomeClass {
  24.   float multiplier;
  25. public:
  26.   SomeClass(float multiplier_) : multiplier(multiplier_) {};
  27.  
  28.   float multiply(float input) {
  29.     return multiplier * input;
  30.   }
  31.  
  32.   std::vector<float> multiply_list(std::vector<float> items) {
  33.     for (auto i = 0; i < items.size(); i++) {
  34.       items[i] = multiply(items.at(i));
  35.     }
  36.     return items;
  37.   }
  38.  
  39.   // py::tuple multiply_two(float one, float two) {
  40.   //   return py::make_tuple(multiply(one), multiply(two));
  41.   // }
  42.  
  43.   std::vector<std::vector<uint8_t>> make_image() {
  44.     auto out = std::vector<std::vector<uint8_t>>();
  45.     for (auto i = 0; i < 128; i++) {
  46.       out.push_back(std::vector<uint8_t>(64));
  47.     }
  48.     for (auto i = 0; i < 30; i++) {
  49.       for (auto j = 0; j < 30; j++) { out[i][j] = 255; }
  50.     }
  51.     return out;
  52.   }
  53.  
  54.   void set_mult(float val) {
  55.     multiplier = val;
  56.   }
  57.  
  58.   float get_mult() {
  59.     return multiplier;
  60.   }
  61.  
  62.   void function_that_takes_a_while() {
  63.     py::gil_scoped_release release;
  64.     std::cout << "starting" << std::endl;
  65.     std::this_thread::sleep_for(std::chrono::milliseconds(2000));
  66.     std::cout << "ended" << std::endl;
  67.  
  68.     // py::gil_scoped_acquire acquire;
  69.     // auto list = py::list();
  70.     // list.append(1);
  71.   }
  72. };
  73.  
  74. SomeClass some_class_factory(float multiplier) {
  75.   return SomeClass(multiplier);
  76. }
  77.  
  78.  
  79. PYBIND11_MODULE(module_name, module_handle) {
  80.   module_handle.doc() = "I'm a docstring hehe";
  81.   module_handle.def("some_fn_python_name", &some_fn);
  82.   module_handle.def("some_class_factory", &some_class_factory);
  83.   py::class_<SomeClass>(
  84.             module_handle, "PySomeClass"
  85.             ).def(py::init<float>())
  86.     .def_property("multiplier", &SomeClass::get_mult, &SomeClass::set_mult)
  87.     .def("multiply", &SomeClass::multiply)
  88.     .def("multiply_list", &SomeClass::multiply_list)
  89.     // .def_property_readonly("image", &SomeClass::make_image)
  90.     .def_property_readonly("image", [](SomeClass &self) {
  91.                       py::array out = py::cast(self.make_image());
  92.                       return out;
  93.                     })
  94.     // .def("multiply_two", &SomeClass::multiply_two)
  95.     .def("multiply_two", [](SomeClass &self, float one, float two) {
  96.                return py::make_tuple(self.multiply(one), self.multiply(two));
  97.              })
  98.     .def("function_that_takes_a_while", &SomeClass::function_that_takes_a_while)
  99.     ;
  100. }
  101.  
  102.  
  103.  
  104. ##### test.py
  105. import time
  106. import traceback
  107. import cv2
  108. from build.module_name import *
  109.  
  110. from concurrent.futures import ThreadPoolExecutor
  111.  
  def call_and_print_exc(fn):
      """Call ``fn()`` with no arguments; on ``Exception`` print the traceback.

      The exception is not re-raised and the return value of ``fn`` is discarded.
      """
      try:
          fn()
      except Exception:
          traceback.print_exc()
      return None
  117.  
  # Smoke test for the compiled extension: exercises construction, scalar and
  # list conversion, tuple return, the image property, the read/write
  # `multiplier` property, and GIL release under a thread pool.
  print(PySomeClass)

  # Both construction paths: the free factory function and the bound constructor.
  m = some_class_factory(10)
  m2 = PySomeClass(10)
  print(m, m2)

  print(m.multiply(20))

  # Presumably rejected by pybind11 with a TypeError (str passed for a float) -
  # left disabled so the rest of the script runs.
  # print(m.multiply("20"))

  arr = m.multiply_list([0.0, 1.0, 2.0, 3.0])
  print(arr)

  print(m.multiply_two(50, 200))

  # Each access to `image` calls make_image() again on the C++ side.
  print(m.image)
  print(m.image.shape)
  cv2.imwrite("/tmp/test.png", m.image)

  print(m.multiplier)
  m.multiplier = 100
  print(m.multiplier)

  # perf_counter is the monotonic clock intended for measuring intervals.
  start = time.perf_counter()

  with ThreadPoolExecutor(4) as ex:
      # Consume the iterator: Executor.map only re-raises a worker's exception
      # when its result is retrieved, so an unconsumed map hides failures.
      list(ex.map(lambda x: m.function_that_takes_a_while(), [None] * 4))

  # The C++ function releases the GIL while sleeping, so the four calls should
  # overlap instead of running back to back.
  print(f"Threaded fun took {time.perf_counter() - start} seconds")
RAW Paste Data

Adblocker detected! Please consider disabling it...

We've detected AdBlock Plus or some other adblocking software preventing Pastebin.com from fully loading.

We don't have any obnoxious sound, or popup ads, we actively block these annoying types of ads!

Please add Pastebin.com to your ad blocker whitelist or disable your adblocking software.

×