Advertisement
zhangsongcui

Yield v3.2

Jul 9th, 2017
203
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 11.48 KB | None | 0 0
  1. #ifdef _WIN32
  2. #   ifdef _WIN32_WINNT
  3. #       if _WIN32_WINNT < 0x0601
  4. #           error 需要 Windows 7 以上系统支持
  5. #       endif
  6. #   else
  7. #       define _WIN32_WINNT 0x0601
  8. #   endif
  9. #   ifdef WIN32_LEAD_AND_MEAN
  10. #       include <Windows.h>
  11. #   else
  12. #       define WIN32_LEAD_AND_MEAN 1
  13. #       include <Windows.h>
  14. #       undef WIN32_LEAD_AND_MEAN
  15. #   endif
  16. #else
  17. #   if defined(__APPLE__)
  18. #       define _XOPEN_SOURCE
  19. #   endif
  20. #   include <ucontext.h>
  21. #endif
  22.  
  23. #include <functional>
  24. #include <cassert>
  25. #include <iterator>
  26. #include <iostream>
  27. #include <array>
  28. #include <memory>
  29. #include <algorithm>
  30. #include <experimental/optional>
  31.  
  32. namespace FiberSpace {
  33.     enum class FiberStatus {
  34.         unstarted = -1,
  35.         running = 1,
  36.         suspended = 2,
  37.         closed = 0,
  38.     };
  39.    
  40.    
  41.     /** \brief 主纤程类析构异常类
  42.      *
  43.      * \warning 用户代码吃掉此异常可导致未定义行为。如果捕获到此异常,请转抛出去。
  44.      */
  45.     struct FiberReturn {
  46.         template <typename YieldValueType>
  47.         friend class Fiber;
  48.        
  49.     private:
  50.         FiberReturn() = default;
  51.     };
  52.    
  53.     /** \brief 主纤程类
  54.      *
  55.      * \warning 线程安全(?)
  56.      * \tparam YieldValueType 子纤程返回类型
  57.      */
  58.     template <typename YieldValueType>
  59.     class Fiber {
  60.         Fiber(const Fiber &) = delete;
  61.         Fiber& operator =(const Fiber &) = delete;
  62.  
  63.         typedef std::function<void (Fiber& fiber)> FuncType;
  64.  
  65.         /// \brief 子纤程返回值
  66.         std::experimental::optional<YieldValueType> yieldedValue;
  67.         /// \brief 存储子纤程抛出的异常
  68.         std::exception_ptr eptr = nullptr;
  69.         /// \brief 子纤程是否结束
  70.         FiberStatus status = FiberStatus::unstarted;
  71.         /// \brief 真子纤程入口,第一个参数传入纤程对象的引用
  72.         FuncType func;
  73.        
  74.         /// \brief 纤程信息
  75. #ifdef _WIN32
  76.         PVOID pMainFiber, pNewFiber;
  77.         bool isFormerAThread;
  78. #else
  79.         ucontext_t ctx_main, ctx_fnew;
  80.         const std::unique_ptr<std::array<uint8_t, SIGSTKSZ>> fnew_stack = std::make_unique<std::array<uint8_t, SIGSTKSZ>>();
  81. #endif
  82.  
  83.     public:
  84.         /** \brief 构造函数
  85.          *
  86.          * 把主线程转化为纤程,并创建一个子纤程
  87.          *
  88.          * \param f 子纤程入口
  89.          */
  90.         explicit Fiber(FuncType f) : func(std::move(f)) {
  91. #ifdef _WIN32
  92.             this->isFormerAThread = !IsThreadAFiber();
  93.             if (this->isFormerAThread) {
  94.                 this->pMainFiber = ::ConvertThreadToFiberEx(nullptr, FIBER_FLAG_FLOAT_SWITCH);
  95.             } else {
  96.                 this->pMainFiber = ::GetCurrentFiber();
  97.             }
  98.             // default stack size
  99.             this->pNewFiber = ::CreateFiberEx(0, 0, FIBER_FLAG_FLOAT_SWITCH, (void(*)(void *))&fEntry, this);
  100. #else
  101.             ::getcontext(&this->ctx_fnew);
  102.             this->ctx_fnew.uc_stack.ss_sp = this->fnew_stack.get();
  103.             this->ctx_fnew.uc_stack.ss_size = this->fnew_stack->size();
  104.             this->ctx_fnew.uc_link = &this->ctx_main;
  105.             ::makecontext(&this->ctx_fnew, (void(*)())&fEntry, 1, this);
  106. #endif
  107.         }
  108.        
  109.         template <class Fp, class ...Args,
  110.             class = typename std::enable_if
  111.             <
  112.                 (sizeof...(Args) > 0)
  113.             >::type
  114.         >
  115.         explicit Fiber(Fp&& f, Args&&... args): Fiber(std::bind(std::forward<Fp>(f), std::placeholders::_1, std::forward<Args>(args)...)) {}
  116.        
  117.         /** \brief 析构函数
  118.          *
  119.          * 删除子纤程,并将主纤程转回线程
  120.          *
  121.          * \warning 主类析构时如子纤程尚未结束(return),则会在子纤程中抛出 FiberReturn 来确保子纤程函数内所有对象都被正确析构
  122.          */
  123.         ~Fiber() noexcept {
  124.             if (!isFinished()) {
  125.                 return_();
  126.             }
  127.            
  128. #ifdef _WIN32
  129.             ::DeleteFiber(this->pNewFiber);
  130.             if (this->isFormerAThread) {
  131.                 ::ConvertFiberToThread();
  132.             }
  133. #endif
  134.         }
  135.  
  136.         /** \brief 调用子纤程
  137.          *
  138.          * 程序流程转入子纤程
  139.          *
  140.          * \warning 子纤程必须尚未结束
  141.          * \return 返回子纤程是否尚未结束
  142.          */
  143.         bool next() {
  144.             assert(!isFinished());
  145. #ifdef _WIN32
  146.             assert(GetCurrentFiber() != this->pNewFiber && "如果你想递归自己,请创建一个新纤程");
  147.             ::SwitchToFiber(this->pNewFiber);
  148. #else
  149.             ::swapcontext(&this->ctx_main, &this->ctx_fnew);
  150. #endif
  151.             if (this->eptr) {
  152.                 std::rethrow_exception(std::exchange(this->eptr, nullptr));
  153.             }
  154.  
  155.             return !isFinished();
  156.         }
  157.  
  158.         /** \brief 向子纤程内部抛出异常
  159.          *
  160.          * 程序流程转入子纤程,并在子纤程内部抛出异常
  161.          *
  162.          * \param eptr 需抛出异常的指针(可以通过 std::make_exception_ptr 获取)
  163.          * \warning 子纤程必须尚未结束
  164.          * \return 返回子纤程是否尚未结束
  165.          */
  166.         bool throw_(std::exception_ptr&& eptr) {
  167.             assert(!isFinished());
  168.             this->eptr = std::exchange(eptr, nullptr);
  169.             return next();
  170.         }
  171.        
  172.         /** \brief 强制退出子纤程
  173.          *
  174.          * 向子纤程内部抛出 FiberReturn 异常,以强制退出子纤程,并确保子纤程函数中所有对象都正确析构
  175.          *
  176.          * \warning 子纤程必须尚未结束
  177.          */
  178.         void return_() {
  179.             assert(!isFinished());
  180.             throw_(std::make_exception_ptr(FiberReturn()));
  181.             assert(isFinished() && "请勿吃掉 FiberReturn 异常!!!");
  182.         }
  183.        
  184.         /** \brief 获得子纤程返回的值
  185.          * \return 子纤程返回的值。如果子纤程没有启动,则返回默认构造值
  186.          */
  187.         const YieldValueType& current() const {
  188.             return *this->yieldedValue;
  189.         }
  190.  
  191.         /** \brief 判断子纤程是否结束
  192.         * \return 子纤程已经结束(return)返回true,否则false
  193.         */
  194.         bool isFinished() const noexcept {
  195.             return this->status == FiberStatus::closed;
  196.         }
  197.  
  198.         /** \brief 转回主纤程并输出值
  199.          *
  200.          * \warning 必须由子纤程调用
  201.          *          参数类型必须与子纤程返回值相同,无类型安全
  202.          * \param value 输出到主纤程的值
  203.          */
  204.         void yield(YieldValueType value) {
  205.             assert(!isFinished());
  206.             this->status = FiberStatus::suspended;
  207.             this->yieldedValue = std::move(value);
  208. #ifdef _WIN32
  209.             assert(GetCurrentFiber() != this->pMainFiber && "这虽然是游戏,但绝不是可以随便玩的");
  210.             ::SwitchToFiber(this->pMainFiber);
  211. #else
  212.             ::swapcontext(&this->ctx_fnew, &this->ctx_main);
  213. #endif
  214.             this->status = FiberStatus::running;
  215.            
  216.             if (this->eptr) {
  217.                 std::rethrow_exception(std::exchange(this->eptr, nullptr));
  218.             }
  219.         }
  220.        
  221.         /** \brief 输出子纤程的所有值
  222.          * \param fiber 另一子纤程
  223.          */
  224.         void yieldAll(Fiber& fiber) {
  225.             assert(&fiber != this);
  226.             while (fiber.next()) {
  227.                 this->yield(fiber.current());
  228.             }
  229.         }
  230.        
  231.         void yieldAll(Fiber&& fiber) {
  232.             this->yieldAll(fiber);
  233.         }
  234.  
  235.     private:
  236.         /// \brief 子纤程入口的warpper
  237.  
  238. #ifdef _WIN32
  239.         static void WINAPI fEntry(Fiber *fiber) {
  240. #else
  241.         static void fEntry(Fiber *fiber) {
  242. #endif
  243.             if (!fiber->eptr) {
  244.                 fiber->status = FiberStatus::running;
  245.                 try {
  246.                     fiber->func(*fiber);
  247.                 } catch (FiberReturn &) {
  248.                     // 主 Fiber 对象正在析构
  249.                 } catch (...) {
  250.                     fiber->eptr = std::current_exception();
  251.                 }
  252.             }
  253.             fiber->status = FiberStatus::closed;
  254.             fiber->yieldedValue = std::experimental::nullopt;
  255. #ifdef _WIN32
  256.             ::SwitchToFiber(fiber->pMainFiber);
  257. #endif
  258.         }
  259.     };
  260.  
  261.     /** \brief 纤程迭代器类
  262.     *
  263.     * 它通过使用 yield 函数对数组或集合类执行自定义迭代。
  264.     * 用于 C++11 for (... : ...)
  265.     */
  266.     template <typename YieldValueType>
  267.     struct FiberIterator : std::iterator<std::output_iterator_tag, YieldValueType> {
  268.         /// \brief 迭代器尾
  269.         FiberIterator() noexcept : fiber(nullptr) {}
  270.         /** \brief 迭代器首
  271.         * \param _f 主线程类的引用
  272.         */
  273.         FiberIterator(Fiber<YieldValueType>& _f) : fiber(&_f) {
  274.             next();
  275.         }
  276.  
  277.         /// \brief 转入子纤程
  278.         FiberIterator& operator ++() {
  279.             next();
  280.             return *this;
  281.         }
  282.  
  283.         /// \brief 取得返回值
  284.         const YieldValueType &operator *() const {
  285.             assert(fiber != nullptr);
  286.             return fiber->current();
  287.         }
  288.  
  289.         /** \brief 比较迭代器相等
  290.         *
  291.         * 通常用于判断迭代是否结束
  292.         * 最好别干别的 ;P
  293.         */
  294.         bool operator ==(const FiberIterator& rhs) const noexcept {
  295.             return fiber == rhs.fiber;
  296.         }
  297.         bool operator !=(const FiberIterator& rhs) const noexcept {
  298.             return !(*this == rhs);
  299.         }
  300.  
  301.     private:
  302.         void next() {
  303.             assert(fiber);
  304.             if (!fiber->next()) fiber = nullptr;
  305.         }
  306.    
  307.         Fiber<YieldValueType>* fiber;
  308.     };
  309.  
  310.     /// \brief 返回迭代器首
  311.     template <typename YieldValueType>
  312.     FiberIterator<YieldValueType> begin(Fiber<YieldValueType>& fiber) {
  313.         return FiberIterator<YieldValueType>(fiber);
  314.     }
  315.  
  316.     /// \brief 返回迭代器尾
  317.     template <typename YieldValueType>
  318.     FiberIterator<YieldValueType> end(Fiber<YieldValueType>&) noexcept {
  319.         return FiberIterator<YieldValueType>();
  320.     }
  321. }
  322.  
  323. using namespace std;
  324. using FiberSpace::Fiber;
  325.  
  326. bool destructedFlag = false;
  327.  
  328. struct TestDestruct {
  329.     ~TestDestruct() {
  330.         destructedFlag = true;
  331.     }
  332. };
  333.  
  334. void foo(Fiber<bool>& fiber, int arg) {
  335.     TestDestruct test;
  336.     for (int i = 1; i < 5; i++) {
  337.         printf("goroutine :%d\n", arg+i);
  338.         fiber.yield(false);
  339.     }
  340. }
  341.  
  342. void do_permutation(Fiber<array<int, 4>>& fiber, array<int, 4> arr, int length) {
  343.     if (length) {
  344.         for (auto i = 0; i < length; ++i) {
  345.             array<int, 4> newArr(arr);
  346.             std::copy_n(arr.begin(), i, newArr.begin());
  347.             std::copy_n(arr.begin() + i + 1, arr.size() - i - 1, newArr.begin() + i);
  348.             newArr.back() = arr[i];
  349.             fiber.yieldAll(Fiber<array<int, 4>>(do_permutation, newArr, length - 1));
  350.         }
  351.     } else {
  352.         fiber.yield(arr);
  353.     }
  354. }
  355.  
  356. void permutation(Fiber<array<int, 4>>& fiber, array<int, 4> arr) {
  357.     do_permutation(fiber, arr, arr.size());
  358. }
  359.  
  360. int main() {
  361.     {
  362.         Fiber<bool> arg1Fiber(foo, 0);
  363.         arg1Fiber.next();
  364.     }
  365.     assert(destructedFlag);
  366.    
  367.     for (auto&& result : Fiber<array<int, 4>>(permutation, array<int, 4> { 1, 2, 3, 4 })) {
  368.         copy(result.begin(), result.end(), std::ostream_iterator<int>(cout, ","));
  369.         cout << endl;
  370.     }
  371. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement