Advertisement
zhangsongcui

yield V2.1 (with void)

Jul 29th, 2012
189
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #include <functional>
  2. #include <cassert>
  3. #include <iterator>
  4. #include <iostream>
  5. #include <memory>
  6. #ifdef _WIN32
  7. #   ifdef WIN32_LEAD_AND_MEAN
  8. #       include <Windows.h>
  9. #   else
  10. #       define WIN32_LEAD_AND_MEAN 1
  11. #       include <Windows.h>
  12. #       undef WIN32_LEAD_AND_MEAN
  13. #   endif
  14. #elif !defined(__APPLE__)
  15. #   include <ucontext.h>
  16. #else
  17. #   error OSX is not supported
  18. #endif
  19.  
  20. namespace FiberSpace {
  21.     class FiberHelper {
  22.         template <typename YieldValueType>
  23.         friend class Fiber;
  24.  
  25. #if defined(_WIN32) && (!defined(_WIN32_WINNT) || (_WIN32_WINNT) < 0x600)
  26.         // Vista之前不支持IsThreadAFiber
  27.         /// \brief 表示是否需要线纤互转
  28.         static bool bNeedConvert;
  29. #endif
  30.         /// \brief 存放当前纤程的param指针
  31.         static void* paraThis;
  32.     };
  33. #if defined(_WIN32) && (!defined(_WIN32_WINNT) || (_WIN32_WINNT) < 0x600)
  34.     bool FiberHelper::bNeedConvert = true;
  35. #endif
  36.     void* FiberHelper::paraThis = nullptr;
  37.  
  38.     /** \brief 主纤程类
  39.      *
  40.      * 目前可以嵌套开纤程,但最好只在每个纤程中新开一个纤程
  41.      *
  42.      * \warning 无线程安全
  43.      * \param 子纤程返回类型
  44.      */
  45.     template <typename YieldValueType = void>
  46.     class Fiber {
  47.         typedef std::function<YieldValueType ()> FuncType;
  48.  
  49.         /** \brief 纤程参数结构体
  50.          *
  51.          * 通过子纤程入口函数参数传入子纤程
  52.          */
  53.         struct Param {
  54.             template <typename Fn>
  55.             Param(Fn&& f)
  56.                 : func(std::forward<Fn>(f))
  57. #ifndef _WIN32
  58.                 // 只能放在堆里,否则SIGSRV,原因不知
  59.                 , fnew_stack(new uint8_t[SIGSTKSZ])
  60. #endif
  61.                 {};
  62.             /// \brief 子纤程返回值
  63.             YieldValueType yieldedValue;
  64.             /// \brief 存储子纤程抛出的异常
  65.             std::exception_ptr eptr;
  66.             /// \brief 子纤程是否结束
  67.             bool flagFinish;
  68.             /// \brief 真子纤程入口
  69.             FuncType func;
  70.             /// \brief 纤程信息
  71. #ifdef _WIN32
  72.             PVOID pMainFiber, pNewFiber;
  73.             bool isFormerAThread;
  74. #else
  75.             ucontext_t ctx_main, ctx_fnew;
  76.             const std::unique_ptr<uint8_t[]> fnew_stack;
  77. #endif
  78.         } param;
  79.  
  80.     public:
  81.         /** \brief 构造函数
  82.          *
  83.          * 把主线程转化为纤程,并创建一个子纤程
  84.          *
  85.          * \param 子纤程入口
  86.          */
  87.         template <typename Fn>
  88.         Fiber(Fn&& f): param(std::forward<Fn>(f)) {
  89. #ifdef _WIN32
  90. #   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600
  91.             // 貌似MinGW杯具?
  92.             param.isFormerAThread = IsThreadAFiber() == FALSE;
  93. #   else
  94.             param.isFormerAThread = FiberHelper::bNeedConvert;
  95. #   endif
  96.             if (param.isFormerAThread) {
  97. #   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x502
  98.                 param.pMainFiber = ::ConvertThreadToFiberEx(nullptr, FIBER_FLAG_FLOAT_SWITCH);
  99. #   else
  100. #       warning See: msdn.microsoft.com/en-us/library/ms682117
  101.                 param.pMainFiber = ::ConvertThreadToFiber(nullptr);
  102. #   endif
  103. #   if !(defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600)
  104.                 FiberHelper::bNeedConvert = false;
  105. #   endif
  106.             } else {
  107.                 param.pMainFiber = ::GetCurrentFiber();
  108.             }
  109.             // default stack size
  110. #   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x502
  111.             param.pNewFiber = ::CreateFiberEx(0, 0, FIBER_FLAG_FLOAT_SWITCH, &fEntry, &param);
  112. #   else
  113. #       warning See: msdn.microsoft.com/en-us/library/ms682406
  114.             param.pNewFiber = ::CreateFiber(0, &fEntry, nullptr);
  115. #   endif
  116. #else
  117.             ::getcontext(&param.ctx_fnew);
  118.             param.ctx_fnew.uc_stack.ss_sp = param.fnew_stack.get();
  119.             param.ctx_fnew.uc_stack.ss_size = SIGSTKSZ;
  120.             param.ctx_fnew.uc_link = &param.ctx_main;
  121.             ::makecontext(&param.ctx_fnew, &fEntry, 0);
  122. #endif
  123.             param.flagFinish = false;
  124.         }
  125.         /** \brief 析构函数
  126.          *
  127.          * 删除子纤程,并将主纤程转回线程
  128.          *
  129.          * \warning 主类析构时子纤程必须已经结束(return)
  130.          */
  131.         ~Fiber() {
  132.             if (!isFinished())
  133.                 std::terminate();
  134. #ifdef _WIN32
  135.             ::DeleteFiber(param.pNewFiber);
  136.             if (param.isFormerAThread) {
  137.                 ::ConvertFiberToThread();
  138. #   if !(defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600)
  139.                 FiberHelper::bNeedConvert = true;
  140. #   endif
  141.             }
  142. #endif
  143.         }
  144.  
  145.         /** \brief 调用子纤程
  146.          *
  147.          * 程序流程转入子线程
  148.          *
  149.          * \warning 子纤程必须尚未结束
  150.          * \return 返回子纤程yield或return的值
  151.          */
  152.         YieldValueType call() {
  153.             assert(isFinished() == false);
  154.             void* oldPara = FiberHelper::paraThis;
  155.             FiberHelper::paraThis = &param;
  156. #ifdef _WIN32
  157.             ::SwitchToFiber(param.pNewFiber);
  158. #else
  159.             ::swapcontext(&param.ctx_main, &param.ctx_fnew);
  160. #endif
  161.             FiberHelper::paraThis = oldPara;
  162.             if (!(param.eptr == std::exception_ptr()))
  163.                 std::rethrow_exception(param.eptr);
  164.             return std::move(param.yieldedValue);
  165.         }
  166.  
  167.         /** \brief 判断子纤程是否结束
  168.          * \return 子纤程已经结束(return)返回true,否则false
  169.          */
  170.         bool isFinished() { return param.flagFinish; }
  171.  
  172.         /** \brief 转回主纤程并输出值
  173.          *
  174.          * \warning 必须由子纤程调用
  175.          *          参数类型必须与子纤程返回值相同,无类型安全
  176.          * \param 输出到主纤程的值
  177.          */
  178.         static void yield(YieldValueType value) {
  179.             assert(FiberHelper::paraThis != nullptr && "Fiber::yield() called with no active fiber");
  180.             Param& param = *reinterpret_cast<Param *>(FiberHelper::paraThis);
  181.             param.yieldedValue = std::move(value);
  182. #ifdef _WIN32
  183.             ::SwitchToFiber(param.pMainFiber);
  184. #else
  185.             ::swapcontext(&param.ctx_fnew, &param.ctx_main);
  186. #endif
  187.         }
  188.  
  189.     private:
  190.         /// \brief 子纤程入口的warpper
  191.  
  192. #ifdef _WIN32
  193.         static void WINAPI fEntry(void *) {
  194. #else
  195.         static void fEntry() {
  196. #endif
  197.             assert(FiberHelper::paraThis != nullptr);
  198.             Param& param = *reinterpret_cast<Param *>(FiberHelper::paraThis);
  199.             param.flagFinish = false;
  200.             try {
  201.                 param.yieldedValue = param.func();
  202.             } catch (...) {
  203.                 param.eptr = std::current_exception();
  204.             }
  205.             param.flagFinish = true;
  206. #ifdef _WIN32
  207.             ::SwitchToFiber(param.pMainFiber);
  208. #endif
  209.         }
  210.     };
  211.  
  212.     // 无返回值特化囧
  213.     template <>
  214.     class Fiber<void> {
  215.         typedef std::function<void ()> FuncType;
  216.  
  217.         struct Param {
  218.             template <typename Fn>
  219.             Param(Fn&& f)
  220.                 : func(std::forward<Fn>(f))
  221. #ifndef _WIN32
  222.                 , fnew_stack(new uint8_t[SIGSTKSZ])
  223. #endif
  224.                 {};
  225.             std::exception_ptr eptr;
  226.             bool flagFinish;
  227.             FuncType func;
  228. #ifdef _WIN32
  229.             PVOID pMainFiber, pNewFiber;
  230.             bool isFormerAThread;
  231. #else
  232.             ucontext_t ctx_main, ctx_fnew;
  233.             const std::unique_ptr<uint8_t[]> fnew_stack;
  234. #endif
  235.         } param;
  236.  
  237.     public:
  238.         template <typename Fn>
  239.         Fiber(Fn&& f): param(std::forward<Fn>(f)) {
  240. #ifdef _WIN32
  241. #   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600
  242.             param.isFormerAThread = IsThreadAFiber() == FALSE;
  243. #   else
  244.             param.isFormerAThread = FiberHelper::bNeedConvert;
  245. #   endif
  246.             if (param.isFormerAThread) {
  247. #   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x502
  248.                 param.pMainFiber = ::ConvertThreadToFiberEx(nullptr, FIBER_FLAG_FLOAT_SWITCH);
  249. #   else
  250.                 param.pMainFiber = ::ConvertThreadToFiber(nullptr);
  251. #   endif
  252. #   if !(defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600)
  253.                 FiberHelper::bNeedConvert = false;
  254. #   endif
  255.             } else {
  256.                 param.pMainFiber = ::GetCurrentFiber();
  257.             }
  258. #   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x502
  259.             param.pNewFiber = ::CreateFiberEx(0, 0, FIBER_FLAG_FLOAT_SWITCH, &fEntry, &param);
  260. #   else
  261.             param.pNewFiber = ::CreateFiber(0, &fEntry, nullptr);
  262. #   endif
  263. #else
  264.             ::getcontext(&param.ctx_fnew);
  265.             param.ctx_fnew.uc_stack.ss_sp = param.fnew_stack.get();
  266.             param.ctx_fnew.uc_stack.ss_size = SIGSTKSZ;
  267.             param.ctx_fnew.uc_link = &param.ctx_main;
  268.             ::makecontext(&param.ctx_fnew, &fEntry, 0);
  269. #endif
  270.             param.flagFinish = false;
  271.         }
  272.         ~Fiber() {
  273.             if (!isFinished())
  274.                 std::terminate();
  275. #ifdef _WIN32
  276.             ::DeleteFiber(param.pNewFiber);
  277.             if (param.isFormerAThread) {
  278.                 ::ConvertFiberToThread();
  279. #   if !(defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600)
  280.                 FiberHelper::bNeedConvert = true;
  281. #   endif
  282.             }
  283. #endif
  284.         }
  285.         void call() {
  286.             assert(isFinished() == false);
  287.             void* oldPara = FiberHelper::paraThis;
  288.             FiberHelper::paraThis = &param;
  289. #ifdef _WIN32
  290.             ::SwitchToFiber(param.pNewFiber);
  291. #else
  292.             ::swapcontext(&param.ctx_main, &param.ctx_fnew);
  293. #endif
  294.             FiberHelper::paraThis = oldPara;
  295.             if (!(param.eptr == std::exception_ptr()))
  296.                 std::rethrow_exception(param.eptr);
  297.         }
  298.         bool isFinished() { return param.flagFinish; }
  299.         static void yield() {
  300.             assert(FiberHelper::paraThis != nullptr && "Fiber::yield() called with no active fiber");
  301.             Param& param = *reinterpret_cast<Param *>(FiberHelper::paraThis);
  302. #ifdef _WIN32
  303.             ::SwitchToFiber(param.pMainFiber);
  304. #else
  305.             ::swapcontext(&param.ctx_fnew, &param.ctx_main);
  306. #endif
  307.         }
  308.  
  309.     private:
  310. #ifdef _WIN32
  311.         static void WINAPI fEntry(void *) {
  312. #else
  313.         static void fEntry() {
  314. #endif
  315.             assert(FiberHelper::paraThis != nullptr);
  316.             Param& param = *reinterpret_cast<Param *>(FiberHelper::paraThis);
  317.             param.flagFinish = false;
  318.             try {
  319.                 param.func();
  320.             } catch (...) {
  321.                 param.eptr = std::current_exception();
  322.             }
  323.             param.flagFinish = true;
  324. #ifdef _WIN32
  325.             ::SwitchToFiber(param.pMainFiber);
  326. #endif
  327.         }
  328.     };
  329.  
  330.     /** \brief 纤程迭代器类
  331.      *
  332.      * 它通过使用 yield 函数对数组或集合类执行自定义迭代。
  333.      * 用于 C++11 for (...:...)
  334.      */
  335.     template <typename YieldValueType>
  336.     struct FiberIterator: std::iterator<std::input_iterator_tag, YieldValueType> {
  337.         /// \brief 迭代器尾
  338.         FiberIterator(): fiber(nullptr), value() {}
  339.         /** \brief 迭代器首
  340.          * \param 主线程类的引用
  341.          */
  342.         FiberIterator(Fiber<YieldValueType>& _f): fiber(&_f), value(_f.call()) {}
  343.  
  344.         /// \brief 转入子纤程
  345.         FiberIterator& operator ++() {
  346.             assert(fiber != nullptr);
  347.             if (!fiber->isFinished())
  348.                 value = fiber->call();
  349.             else
  350.                 fiber = nullptr;
  351.             return *this;
  352.         }
  353.         // 返回临时对象没问题吧-_-!!
  354.         FiberIterator operator ++(int) {
  355.             FiberIterator tmp(*this);
  356.             ++*this;
  357.             return tmp;
  358.         }
  359.  
  360.         /// \brief 取得返回值
  361.         YieldValueType& operator *() {
  362.             assert(fiber != nullptr);
  363.             return value;
  364.         }
  365.  
  366.         /** \brief 比较迭代器相等
  367.          *
  368.          * 通常用于判断迭代是否结束
  369.          * 最好别干别的 ;P
  370.          */
  371.         bool operator ==(const FiberIterator& rhs) {
  372.             return fiber == rhs.fiber;
  373.         }
  374.         bool operator !=(const FiberIterator& rhs) {
  375.             return !(*this == rhs);
  376.         }
  377.  
  378.     private:
  379.         Fiber<YieldValueType>* fiber;
  380.         YieldValueType value;
  381.     };
  382.  
  383.     /// \brief 返回迭代器首
  384.     template <typename YieldValueType>
  385.     FiberIterator<YieldValueType> begin(Fiber<YieldValueType>& fiber) {
  386.         return FiberIterator<YieldValueType>(fiber);
  387.     }
  388.  
  389.     /// \brief 返回迭代器尾
  390.     template <typename YieldValueType>
  391.     FiberIterator<YieldValueType> end(Fiber<YieldValueType>&) {
  392.         return FiberIterator<YieldValueType>();
  393.     }
  394. }
  395.  
  396. using namespace std;
  397. using FiberSpace::Fiber;
  398.  
  399. int Test(int& i) {
  400.     cout << "func, i = " << i << endl;
  401.     // 保留栈,转回主纤程,并输出值
  402.     Fiber<int>::yield(++i);
  403.     cout << "func, i = " << i << endl;
  404.     // 终止迭代,返回主纤程,并输出值
  405.     return ++i;
  406. }
  407.  
  408. int Test2(int beg, int end) {
  409.     while (beg !=end)
  410.         Fiber<int>::yield(beg++);
  411.     return beg;
  412. }
  413.  
  414. long Test3() {
  415.     auto testF = [] () -> int {
  416.         Fiber<int>::yield(1);
  417.         Fiber<int>::yield(2);
  418.         return 3;
  419.     };
  420.     for (auto i : Fiber<int>(testF))
  421.         Fiber<long>::yield(i);
  422.     return 4;
  423. }
  424.  
  425. long TestException() {
  426.     auto testF = [] () -> int {
  427.         Fiber<int> fiber([] () -> int { throw exception(); return 0; });
  428.         return fiber.call();
  429.     };
  430.     try {
  431.         return Fiber<int>(testF).call();
  432.     } catch (...) {
  433.         cout << "Exception catched in TestException()" << endl;
  434.         throw;
  435.     }
  436. }
  437.  
  438. class DerivedFiber: public Fiber<> {
  439. public:
  440.     DerivedFiber(): Fiber(std::bind(&DerivedFiber::run, this)) {}
  441.  
  442. private:
  443.     void run() {
  444.         puts("Derived fiber running.");
  445.     }
  446. };
  447.  
  448. void fiberFunc() {
  449.     puts("Composed fiber running.");
  450.     Fiber<>::yield();
  451.     puts("Composed fiber running.");
  452. }
  453.  
  454. int main() {
  455.     {
  456.     // 测试基本流程转换
  457.     int i = 0, t;
  458.     cout << "main, i = " << i << endl;
  459.     // 把主线程转化为纤程,并创建一个子纤程。参数为子纤程入口
  460.     Fiber<int> fiber(std::bind(Test, std::ref(i)));
  461.     // 流程转入子线程
  462.     t = fiber.call();
  463.     cout << "main, Test yield: " << t << endl;
  464.     t = fiber.call();
  465.     cout << "main, Test return: " << t << endl;
  466.     cout << "main, i = " << i << endl;
  467.     // 确保fiber正常析构
  468.     }
  469.     {
  470.     // Test from dlang.org
  471.     // create instances of each type
  472.     unique_ptr<DerivedFiber> derived(new DerivedFiber);
  473.     unique_ptr<Fiber<>> composed(new Fiber<void>(&fiberFunc));
  474.  
  475.     // call both fibers once
  476.     derived->call();
  477.     composed->call();
  478.     puts("Execution returned to calling context.");
  479.     composed->call();
  480.  
  481.     // since each fiber has run to completion, each should have state TERM
  482.     assert( derived->isFinished() );
  483.     assert( composed->isFinished() );
  484.     }
  485.     // 测试循环yield
  486.     for (Fiber<int> fiber(std::bind(Test2, 1, 10)); !fiber.isFinished();)
  487.         cout << fiber.call() << ' ';
  488.     cout << endl;
  489.  
  490.     // 测试返回非基本类型,foreach
  491.     // VC10: ╮(╯▽╰)╭
  492.     for (const string& s : Fiber<string>([] () -> std::string {
  493.         Fiber<string>::yield("Hello");
  494.         Fiber<string>::yield("World");
  495.         return "!!!!!";
  496.     })) {
  497.         cout << s << endl;
  498.     }
  499.  
  500.     // 测试深层调用、返回
  501.     for (auto i : Fiber<long>(Test3))
  502.         cout << i << ' ';
  503.     cout << endl;
  504. #if 0 // Test failed
  505.     {
  506.     unique_ptr<Fiber<char>> child, parent;
  507.     child.reset(new Fiber<char>([&] () -> char {
  508.         puts("before call parent");
  509.         parent->call();
  510.         puts("after call parent");
  511.         return 0;
  512.     }));
  513.     parent.reset(new Fiber<char>([&] () -> char {
  514.         puts("before call child");
  515.         child->call();
  516.         puts("after call child");
  517.         return 0;
  518.     }));
  519.     parent->call();
  520.     }
  521. #endif
  522. #if 0   // pass, but you should not write this
  523.     {
  524.     unique_ptr<Fiber<char>> f;
  525.     f.reset(new Fiber<char>([&] () -> char {
  526.         puts("f!!!");
  527.         Sleep(500);
  528.         return f->call();
  529.     }));
  530.     f->call();
  531.     }
  532. #endif
  533. #if !(defined(__GNUC__) && defined(_WIN32))
  534.     // 测试深层调用及异常安全
  535.     // Test fail with MinGW
  536.     try {
  537.         Fiber<long> fiber(TestException);
  538.         fiber.call();
  539.     } catch (exception&) {
  540.         cout << "Exception catched in main()!" << endl;
  541.     }
  542. #endif
  543. }
  544. /*
  545. Tested with VS2010(icl), MinGW-g++, archlinux-g++
  546. Expected Output:
  547. main, i = 0
  548. func, i = 0
  549. main, Test yield: 1
  550. func, i = 1
  551. main, Test return: 2
  552. main, i = 2
  553. Derived fiber running.
  554. Composed fiber running.
  555. Execution returned to calling context.
  556. Composed fiber running.
  557. 1 2 3 4 5 6 7 8 9 10
  558. Hello
  559. World
  560. !!!!!
  561. 1 2 3 4
  562. Exception catched in TestException()
  563. Exception catched in main()!
  564. */
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement