zhangsongcui

Yield 3.1

Jul 9th, 2017
269
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 11.09 KB | None | 0 0
  1. #ifdef _WIN32
  2. #   ifdef _WIN32_WINNT
  3. #       if _WIN32_WINNT < 0x0601
  4. #           error 需要 Windows 7 以上系统支持
  5. #       endif
  6. #   else
  7. #       define _WIN32_WINNT 0x0601
  8. #   endif
  9. #   ifdef WIN32_LEAD_AND_MEAN
  10. #       include <Windows.h>
  11. #   else
  12. #       define WIN32_LEAD_AND_MEAN 1
  13. #       include <Windows.h>
  14. #       undef WIN32_LEAD_AND_MEAN
  15. #   endif
  16. #else
  17. #   if defined(__APPLE__)
  18. #       define _XOPEN_SOURCE
  19. #   endif
  20. #   include <ucontext.h>
  21. #endif
  22.  
  23. #include <functional>
  24. #include <cassert>
  25. #include <iterator>
  26. #include <iostream>
  27. #include <array>
  28. #include <memory>
  29. #include <algorithm>
  30.  
  31. namespace FiberSpace {
  32.     enum class FiberStatus {
  33.         unstarted = -1,
  34.         running = 1,
  35.         suspended = 2,
  36.         closed = 0,
  37.     };
  38.    
  39.    
  40.     /** \brief 主纤程类析构异常类
  41.      *
  42.      * \warning 用户代码吃掉此异常可导致未定义行为。如果捕获到此异常,请转抛出去。
  43.      */
  44.     struct FiberReturn {
  45.         template <typename YieldValueType>
  46.         friend class Fiber;
  47.        
  48.     private:
  49.         FiberReturn() = default;
  50.     };
  51.    
  52.     /** \brief 主纤程类
  53.      *
  54.      * \warning 线程安全(?)
  55.      * \param YieldValueType 子纤程返回类型
  56.      */
  57.     template <typename YieldValueType = void>
  58.     class Fiber {
  59.         Fiber(const Fiber &) = delete;
  60.         Fiber& operator =(const Fiber &) = delete;
  61.  
  62.         typedef std::function<void (Fiber& fiber)> FuncType;
  63.  
  64.         /// \brief 子纤程返回值
  65.         YieldValueType yieldedValue = YieldValueType();
  66.         /// \brief 存储子纤程抛出的异常
  67.         std::exception_ptr eptr = nullptr;
  68.         /// \brief 子纤程是否结束
  69.         FiberStatus status = FiberStatus::unstarted;
  70.         /// \brief 真子纤程入口,第一个参数传入纤程对象的引用
  71.         FuncType func;
  72.         /// \brief 纤程信息
  73. #ifdef _WIN32
  74.         PVOID pMainFiber, pNewFiber;
  75.         bool isFormerAThread;
  76. #else
  77.         ucontext_t ctx_main, ctx_fnew;
  78.         const std::unique_ptr<std::array<uint8_t, SIGSTKSZ>> fnew_stack = std::make_unique<std::array<uint8_t, SIGSTKSZ>>();
  79. #endif
  80.  
  81.     public:
  82.         /** \brief 构造函数
  83.          *
  84.          * 把主线程转化为纤程,并创建一个子纤程
  85.          *
  86.          * \param f 子纤程入口
  87.          */
  88.         Fiber(FuncType f) : func(std::move(f)) {
  89. #ifdef _WIN32
  90.             this->isFormerAThread = !IsThreadAFiber();
  91.             if (this->isFormerAThread) {
  92.                 this->pMainFiber = ::ConvertThreadToFiberEx(nullptr, FIBER_FLAG_FLOAT_SWITCH);
  93.             } else {
  94.                 this->pMainFiber = ::GetCurrentFiber();
  95.             }
  96.             // default stack size
  97.             this->pNewFiber = ::CreateFiberEx(0, 0, FIBER_FLAG_FLOAT_SWITCH, (void(*)(void *))&fEntry, this);
  98. #else
  99.             ::getcontext(&this->ctx_fnew);
  100.             this->ctx_fnew.uc_stack.ss_sp = this->fnew_stack.get();
  101.             this->ctx_fnew.uc_stack.ss_size = this->fnew_stack->size();
  102.             this->ctx_fnew.uc_link = &this->ctx_main;
  103.             ::makecontext(&this->ctx_fnew, (void(*)())&fEntry, 1, this);
  104. #endif
  105.  
  106. //        template<typename Fn, typename Args...>
  107. //        Fiber(Fn f,)
  108.  
  109.         }
  110.         /** \brief 析构函数
  111.          *
  112.          * 删除子纤程,并将主纤程转回线程
  113.          *
  114.          * \warning 主类析构时如子纤程尚未结束(return),则会在子纤程中抛出 FiberReturn 来确保子纤程函数内所有对象都被正确析构
  115.          */
  116.         ~Fiber() noexcept {
  117.             if (!isFinished()) {
  118.                 return_();
  119.             }
  120.            
  121. #ifdef _WIN32
  122.             ::DeleteFiber(this->pNewFiber);
  123.             if (this->isFormerAThread) {
  124.                 ::ConvertFiberToThread();
  125.             }
  126. #endif
  127.         }
  128.  
  129.         /** \brief 调用子纤程
  130.          *
  131.          * 程序流程转入子纤程
  132.          *
  133.          * \warning 子纤程必须尚未结束
  134.          * \return 返回子纤程是否尚未结束
  135.          */
  136.         bool next() {
  137.             assert(!isFinished());
  138. #ifdef _WIN32
  139.             assert(GetCurrentFiber() != this->pNewFiber && "如果你想递归自己,请创建一个新纤程");
  140.             ::SwitchToFiber(this->pNewFiber);
  141. #else
  142.             ::swapcontext(&this->ctx_main, &this->ctx_fnew);
  143. #endif
  144.             if (this->eptr) {
  145.                 std::rethrow_exception(std::exchange(this->eptr, nullptr));
  146.             }
  147.  
  148.             return !isFinished();
  149.         }
  150.  
  151.         /** \brief 向子纤程内部抛出异常
  152.          *
  153.          * 程序流程转入子纤程,并在子纤程内部抛出异常
  154.          *
  155.          * \param eptr 需抛出异常的指针(可以通过 std::make_exception_ptr 获取)
  156.          * \warning 子纤程必须尚未结束
  157.          * \return 返回子纤程是否尚未结束
  158.          */
  159.         bool throw_(std::exception_ptr&& eptr) {
  160.             assert(!isFinished());
  161.             this->eptr = std::exchange(eptr, nullptr);
  162.             return next();
  163.         }
  164.        
  165.         /** \brief 强制退出子纤程
  166.          *
  167.          * 向子纤程内部抛出 FiberReturn 异常,以强制退出子纤程,并确保子纤程函数中所有对象都正确析构
  168.          *
  169.          * \warning 子纤程必须尚未结束
  170.          */
  171.         void return_() {
  172.             assert(!isFinished());
  173.             throw_(std::make_exception_ptr(FiberReturn()));
  174.             assert(isFinished() && "请勿吃掉 FiberReturn 异常!!!");
  175.         }
  176.        
  177.         /** \brief 获得子纤程返回的值
  178.          * \return 子纤程返回的值。如果子纤程没有启动,则返回默认构造值
  179.          */
  180.         YieldValueType current() const {
  181.             return this->yieldedValue;
  182.         }
  183.  
  184.         /** \brief 判断子纤程是否结束
  185.         * \return 子纤程已经结束(return)返回true,否则false
  186.         */
  187.         bool isFinished() const noexcept {
  188.             return this->status == FiberStatus::closed;
  189.         }
  190.  
  191.         /** \brief 转回主纤程并输出值
  192.          *
  193.          * \warning 必须由子纤程调用
  194.          *          参数类型必须与子纤程返回值相同,无类型安全
  195.          * \param value 输出到主纤程的值
  196.          */
  197.         void yield(YieldValueType value) {
  198.             assert(!isFinished());
  199.             this->status = FiberStatus::suspended;
  200.             this->yieldedValue = std::move(value);
  201. #ifdef _WIN32
  202.             assert(GetCurrentFiber() != this->pMainFiber && "这虽然是游戏,但绝不是可以随便玩的");
  203.             ::SwitchToFiber(this->pMainFiber);
  204. #else
  205.             ::swapcontext(&this->ctx_fnew, &this->ctx_main);
  206. #endif
  207.             this->status = FiberStatus::running;
  208.            
  209.             if (this->eptr) {
  210.                 std::rethrow_exception(std::exchange(this->eptr, nullptr));
  211.             }
  212.         }
  213.        
  214.         /** \brief 输出子纤程的所有值
  215.          * \param fiber 另一子纤程
  216.          */
  217.         void yieldAll(Fiber& fiber) {
  218.             assert(&fiber != this);
  219.             while (fiber.next()) {
  220.                 this->yield(fiber.current());
  221.             }
  222.         }
  223.        
  224.         void yieldAll(Fiber&& fiber) {
  225.             this->yieldAll(fiber);
  226.         }
  227.  
  228.     private:
  229.         /// \brief 子纤程入口的warpper
  230.  
  231. #ifdef _WIN32
  232.         static void WINAPI fEntry(Fiber *fiber) {
  233. #else
  234.         static void fEntry(Fiber *fiber) {
  235. #endif
  236.             if (!fiber->eptr) {
  237.                 fiber->status = FiberStatus::running;
  238.                 try {
  239.                     fiber->func(*fiber);
  240.                 } catch (FiberReturn &) {
  241.                     // 主 Fiber 对象正在析构
  242.                 } catch (...) {
  243.                     fiber->eptr = std::current_exception();
  244.                 }
  245.             }
  246.             fiber->status = FiberStatus::closed;
  247. #ifdef _WIN32
  248.             ::SwitchToFiber(fiber->pMainFiber);
  249. #endif
  250.         }
  251.     };
  252.  
  253.     /** \brief 纤程迭代器类
  254.     *
  255.     * 它通过使用 yield 函数对数组或集合类执行自定义迭代。
  256.     * 用于 C++11 for (... : ...)
  257.     */
  258.     template <typename YieldValueType>
  259.     struct FiberIterator : std::iterator<std::input_iterator_tag, YieldValueType> {
  260.         /// \brief 迭代器尾
  261.         FiberIterator() noexcept : fiber(nullptr) {}
  262.         /** \brief 迭代器首
  263.         * \param _f 主线程类的引用
  264.         */
  265.         FiberIterator(Fiber<YieldValueType>& _f) : fiber(&_f) {
  266.             next();
  267.         }
  268.  
  269.         /// \brief 转入子纤程
  270.         FiberIterator& operator ++() {
  271.             next();
  272.             return *this;
  273.         }
  274.  
  275.         /// \brief 取得返回值
  276.         YieldValueType operator *() const {
  277.             assert(fiber != nullptr);
  278.             return std::move(fiber->current());
  279.         }
  280.  
  281.         /** \brief 比较迭代器相等
  282.         *
  283.         * 通常用于判断迭代是否结束
  284.         * 最好别干别的 ;P
  285.         */
  286.         bool operator ==(const FiberIterator& rhs) const noexcept {
  287.             return fiber == rhs.fiber;
  288.         }
  289.         bool operator !=(const FiberIterator& rhs) const noexcept {
  290.             return !(*this == rhs);
  291.         }
  292.  
  293.     private:
  294.         void next() {
  295.             assert(fiber);
  296.             if (!fiber->next()) fiber = nullptr;
  297.         }
  298.    
  299.         Fiber<YieldValueType>* fiber;
  300.     };
  301.  
  302.     /// \brief 返回迭代器首
  303.     template <typename YieldValueType>
  304.     FiberIterator<YieldValueType> begin(Fiber<YieldValueType>& fiber) {
  305.         return FiberIterator<YieldValueType>(fiber);
  306.     }
  307.  
  308.     /// \brief 返回迭代器尾
  309.     template <typename YieldValueType>
  310.     FiberIterator<YieldValueType> end(Fiber<YieldValueType>&) noexcept {
  311.         return FiberIterator<YieldValueType>();
  312.     }
  313. }
  314.  
  315. using namespace std;
  316. using FiberSpace::Fiber;
  317.  
  318. bool destructedFlag = false;
  319.  
  320. struct TestDestruct {
  321.     ~TestDestruct() {
  322.         destructedFlag = true;
  323.     }
  324. };
  325.  
  326. void foo(Fiber<bool>& fiber, int arg) {
  327.     TestDestruct test;
  328.     for (int i = 1; i < 5; i++) {
  329.         printf("goroutine :%d\n", arg+i);
  330.         fiber.yield(false);
  331.     }
  332. }
  333.  
  334. void permutation(Fiber<array<int, 4>>& fiber, array<int, 4> arr, int length) {
  335.     if (length) {
  336.         for (auto i = 0; i < length; ++i) {
  337.             array<int, 4> newArr(arr);
  338.             std::copy_n(arr.begin(), i, newArr.begin());
  339.             std::copy_n(arr.begin() + i + 1, arr.size() - i - 1, newArr.begin() + i);
  340.             newArr.back() = arr[i];
  341.             fiber.yieldAll(Fiber<array<int, 4>>(bind(permutation, placeholders::_1, newArr, length - 1)));
  342.         }
  343.     } else {
  344.         fiber.yield(arr);
  345.     }
  346. }
  347.  
  348. int main() {
  349.     {
  350.         Fiber<bool> arg1Fiber(bind(foo, placeholders::_1, 0));
  351.         arg1Fiber.next();
  352.     }
  353.     assert(destructedFlag);
  354.  
  355.     Fiber<array<int, 4>> fiber(
  356.         bind(permutation, placeholders::_1, array<int, 4> {1, 2, 3, 4}, 4)
  357.     );
  358.     for (auto&& result : fiber) {
  359.         copy(result.begin(), result.end(), std::ostream_iterator<int>(cout, ","));
  360.         cout << endl;
  361.     }
  362. }
Advertisement
Add Comment
Please, Sign In to add comment