Advertisement
zhangsongcui

yield

Jul 27th, 2012
379
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 9.60 KB | None | 0 0
  1. #include <functional>
  2. #include <cassert>
  3. #include <iterator>
  4. #include <iostream>
  5. #ifdef _WIN32
  6. #   define WIN32_LEAD_AND_MEAN
  7. #   include <Windows.h>
  8. #   undef WIN32_LEAD_AND_MEAN
  9. #else
  10. #   include <ucontext.h>
  11. #endif
  12.  
  13. namespace Fiber {
  14.     /** \brief 辅助类
  15.      *
  16.      * 存放各种静态变量
  17.      * 放于Fiber<...>外,确保类型无关
  18.      */
  19.     class FiberHelper {
  20.         template <typename YieldValueType>
  21.         friend class Fiber;
  22. #ifdef _WIN32
  23.         /// \brief 主纤程信息
  24.         static void* pMainFiber;
  25. #else
  26.         /// \brief 主纤程信息
  27.         static ucontext_t ctx_main;
  28.         /// \brief 子纤程信息
  29.         static ucontext_t ctx_fnew;
  30.         static uint8_t fnew_stack[SIGSTKSZ];
  31. #endif
  32.         /// \brief 子纤程是否结束
  33.         static bool flagFinish;
  34.         static std::exception_ptr eptr;
  35.     };
  36. #ifdef _WIN32
  37.     void* FiberHelper::pMainFiber = nullptr;
  38. #else
  39.     ucontext_t FiberHelper::ctx_main, FiberHelper::ctx_fnew;
  40.     uint8_t FiberHelper::fnew_stack[SIGSTKSZ];
  41. #endif
  42.     bool FiberHelper::flagFinish = false;
  43.     std::exception_ptr FiberHelper::eptr;
  44.  
  45.     /** \brief 主纤程类
  46.      *
  47.      * 只支持创建一个实例
  48.      * 各种不安全&不支持,玩玩足矣
  49.      *
  50.      * \todo 改用单例模式?
  51.      */
  52.     template <typename YieldValueType>
  53.     class Fiber {
  54.         typedef std::function<YieldValueType ()> FuncType;
  55.         static YieldValueType yieldedValue;
  56.  
  57.     public:
  58.         /** \brief 构造函数
  59.          *
  60.          * 把主线程转化为纤程,并创建一个子纤程
  61.          *
  62.          * \warning 只支持创建一个Fiber实例
  63.          * \param 子纤程入口
  64.          */
  65.         template <typename Fn>
  66.         Fiber(Fn&& f): func(std::forward<Fn>(f)) {
  67. #ifdef _WIN32
  68.             assert(FiberHelper::pMainFiber == nullptr && "Only one fiber is supported");
  69. #   if _WIN32_WINNT >= 0x0400
  70.             FiberHelper::pMainFiber = ::ConvertThreadToFiberEx(nullptr, FIBER_FLAG_FLOAT_SWITCH);
  71.             // default stack size
  72.             pNewFiber = ::CreateFiberEx(0, 0, FIBER_FLAG_FLOAT_SWITCH, &fEntry, &func);
  73. #   else
  74. #       warning See: msdn.microsoft.com/en-us/library/ms682117.aspx
  75.             FiberHelper::pMainFiber = ::ConvertThreadToFiber(nullptr);
  76.             pNewFiber = ::CreateFiber(0, &fEntry, &func);
  77. #   endif
  78. #else
  79.             ucontext_t* const uctx = &FiberHelper::ctx_fnew;
  80.             ::getcontext(&FiberHelper::ctx_fnew);
  81.             uctx->uc_stack.ss_sp = FiberHelper::fnew_stack;
  82.             uctx->uc_stack.ss_size = sizeof(FiberHelper::fnew_stack);
  83.             uctx->uc_link = &FiberHelper::ctx_main;
  84.             ::makecontext(uctx, reinterpret_cast<void (*)()>(&fEntry), 1, &func);
  85. #endif
  86.             FiberHelper::flagFinish = false;
  87.         }
  88.         /** \brief 析构函数
  89.          *
  90.          * 删除子纤程,并将主纤程转回线程
  91.          *
  92.          * \warning 主类析构时子纤程必须已经结束
  93.          */
  94.         ~Fiber() {
  95.             assert(FiberHelper::flagFinish == true);
  96. #ifdef _WIN32
  97.             ::DeleteFiber(pNewFiber);
  98.             ::ConvertFiberToThread();
  99.             FiberHelper::pMainFiber = nullptr;
  100. #endif
  101.         }
  102.  
  103.         /** \brief 调用子纤程
  104.          *
  105.          * 程序流程转入子线程
  106.          *
  107.          * \warning 子纤程必须尚未结束
  108.          * \return 返回子纤程yield或return的值
  109.          */
  110.         YieldValueType call() {
  111.             assert(FiberHelper::flagFinish == false);
  112. #ifdef _WIN32
  113.             assert(pNewFiber != nullptr);
  114.             ::SwitchToFiber(pNewFiber);
  115. #else
  116.             ::swapcontext(&FiberHelper::ctx_main, &FiberHelper::ctx_fnew);
  117. #endif
  118.             if (!(FiberHelper::eptr == std::exception_ptr()))
  119.                 std::rethrow_exception(FiberHelper::eptr);
  120.             return std::move(yieldedValue);
  121.         }
  122.  
  123.         /** \brief 判断子纤程是否结束
  124.          *
  125.          * \return 子纤程已经结束(return)返回true,否则false
  126.          */
  127.         bool isFinished() { return FiberHelper::flagFinish; }
  128.  
  129.         /** \brief 转回主纤程并输出值
  130.          *
  131.          * \warning 必须由子纤程调用
  132.          *          参数类型必须与子纤程返回值相同,无类型安全
  133.          * \param 输出到主纤程的值
  134.          */
  135.         static void yield(YieldValueType value) {
  136.             yieldedValue = std::move(value);
  137. #ifdef _WIN32
  138.             assert(FiberHelper::pMainFiber != nullptr);
  139.             assert(GetCurrentFiber() != FiberHelper::pMainFiber);
  140.             ::SwitchToFiber(FiberHelper::pMainFiber);
  141. #else
  142.             ::swapcontext(&FiberHelper::ctx_fnew, &FiberHelper::ctx_main);
  143. #endif
  144.         }
  145.  
  146.     private:
  147.         /** \brief 子纤程入口的warpper
  148.          *
  149.          * 不足为外人道也
  150.          */
  151.         static void
  152. #ifdef _WIN32
  153.         WINAPI
  154. #endif
  155.         fEntry(void* param) {
  156.             assert(FiberHelper::flagFinish == false);
  157.             assert(param != nullptr);
  158.             FiberHelper::flagFinish = false;
  159.             FuncType& f = *reinterpret_cast<FuncType *>(param);
  160.             try {
  161.                 yieldedValue = f();
  162.             } catch (...) { // 没finally淡腾
  163.                 FiberHelper::eptr = std::current_exception();
  164.                 FiberHelper::flagFinish = true;
  165. #ifdef _WIN32
  166.                 ::SwitchToFiber(FiberHelper::pMainFiber);
  167. #else
  168.                 return;
  169. #endif
  170.             }
  171.             FiberHelper::flagFinish = true;
  172. #ifdef _WIN32
  173.             ::SwitchToFiber(FiberHelper::pMainFiber);
  174. #endif
  175.         }
  176.  
  177. #ifdef _WIN32
  178.         /// \brief 子纤程信息
  179.         void* pNewFiber;
  180. #endif
  181.         /// \brief 真子纤程入口
  182.         FuncType func;
  183.     };
  184.     template <typename YieldValueType>
  185.     YieldValueType Fiber<YieldValueType>::yieldedValue;
  186.  
  187.     /** \brief 纤程迭代器类
  188.      *
  189.      * 它通过使用 yield 函数对数组或集合类执行自定义迭代。
  190.      * 用于 C++11 for (...:...)
  191.      *
  192.      * \bug C++中,iterator的有效范围为[&front, &back+1)
  193.      *      而纤程需要java/.NET的(&front-1, back]
  194.      *      暂时无解,保留待改
  195.      */
  196.     template <typename YieldValueType>
  197.     struct FiberIterator: std::iterator<std::input_iterator_tag, YieldValueType> {
  198.         /// \brief 迭代器尾
  199.         FiberIterator(): fiber(nullptr), value() {}
  200.         /** \brief 迭代器首
  201.          * \param 主线程类的引用
  202.          */
  203.         FiberIterator(Fiber<YieldValueType>& _f): fiber(&_f), value(_f.call()) {}
  204.  
  205.         /// \brief 转入子纤程
  206.         FiberIterator& operator ++() {
  207.             assert(fiber != nullptr);
  208.             if (!fiber->isFinished())
  209.                 value = fiber->call();
  210.             else
  211.                 fiber = nullptr;
  212.             return *this;
  213.         }
  214.         // 返回临时对象没问题吧-_-!!
  215.         FiberIterator operator ++(int) {
  216.             FiberIterator tmp(*this);
  217.             ++*this;
  218.             return tmp;
  219.         }
  220.  
  221.         /// \brief 取得返回值
  222.         YieldValueType& operator *() {
  223.             assert(fiber != nullptr);
  224.             return value;
  225.         }
  226.  
  227.         /** \brief 比较迭代器相等
  228.          *
  229.          * 通常用于判断迭代是否结束
  230.          *
  231.          * \warning 最好别干别的 ;P
  232.          */
  233.         bool operator ==(const FiberIterator& rhs) {
  234.             return fiber == rhs.fiber;
  235.         }
  236.         bool operator !=(const FiberIterator& rhs) {
  237.             return !(*this == rhs);
  238.         }
  239.  
  240.     private:
  241.         Fiber<YieldValueType>* fiber;
  242.         YieldValueType value;
  243.     };
  244.  
  245.     /// \brief 返回迭代器首
  246.     /// 用于 C++11 for (...:...)
  247.     template <typename YieldValueType>
  248.     FiberIterator<YieldValueType> begin(Fiber<YieldValueType>& fiber) {
  249.         return FiberIterator<YieldValueType>(fiber);
  250.     }
  251.  
  252.     /// \brief 返回迭代器尾
  253.     /// 用于 C++11 for (...:...)
  254.     template <typename YieldValueType>
  255.     FiberIterator<YieldValueType> end(Fiber<YieldValueType>&) {
  256.         return FiberIterator<YieldValueType>();
  257.     }
  258. }
  259.  
  260. using namespace std;
  261. int Test(int& i) {
  262.     cout << "func, i = " << i << endl;
  263.     // 保留栈,转回主纤程,并输出值
  264.     Fiber::Fiber<int>::yield(++i);
  265.     cout << "func, i = " << i << endl;
  266.     // 终止迭代,返回主纤程,并输出值
  267.     return ++i;
  268. }
  269.  
  270. int Test2(int beg, int end) {
  271.     while (beg !=end)
  272.         Fiber::Fiber<int>::yield(beg++);
  273.     return beg;
  274. }
  275.  
  276. int TestException() {
  277.     throw exception();
  278. }
  279.  
  280. int main() {
  281.     {
  282.         int i = 0, t;
  283.         cout << "main, i = " << i << endl;
  284.         // 把主线程转化为纤程,并创建一个子纤程。参数为子纤程入口
  285.         Fiber::Fiber<int> fiber(std::bind(Test, std::ref(i)));
  286.         // 流程转入子线程
  287.         t = fiber.call();
  288.         cout << "main, Test yield: " << t << endl;
  289.         t = fiber.call();
  290.         cout << "main, Test return: " << t << endl;
  291.         cout << "main, i = " << i << endl;
  292.         // 确保fiber正常析构
  293.     }
  294.     {
  295.         // VC10: ╮(╯▽╰)╭
  296.         for (int i : Fiber::Fiber<int>(std::bind(Test2, 1, 10)))
  297.             cout << i << '\t';
  298.         cout << endl;
  299.     }
  300.     {
  301.         try {
  302.             Fiber::Fiber<int> fiber(TestException);
  303.             fiber.call();
  304.         } catch (exception&) {
  305.             cout << "Exception catched!" << endl;
  306.         }
  307.     }
  308. }
  309. /* Expected Output:
  310. main, i = 0
  311. func, i = 0
  312. main, Test yield: 1
  313. func, i = 1
  314. main, Test return: 2
  315. main, i = 2
  316. 0       1       2       3       4       5       6       7       8       9
  317. */
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement