Advertisement
zhangsongcui

yield V2

Jul 28th, 2012
311
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. #include <functional>
  2. #include <cassert>
  3. #include <iterator>
  4. #include <iostream>
  5. #ifdef _WIN32
  6. #   ifdef WIN32_LEAD_AND_MEAN
  7. #       include <Windows.h>
  8. #   else
  9. #       define WIN32_LEAD_AND_MEAN 1
  10. #       include <Windows.h>
  11. #       undef WIN32_LEAD_AND_MEAN
  12. #   endif
  13. #else
  14. #   include <ucontext.h>
  15. #   include <memory>
  16. #endif
  17.  
  18. namespace FiberSpace {
  19.     class FiberHelper {
  20.         template <typename YieldValueType>
  21.         friend class Fiber;
  22.  
  23. #if defined(_WIN32) && (!defined(_WIN32_WINNT) || (_WIN32_WINNT) < 0x600)
  24.         // Vista之前不支持IsThreadAFiber
  25.         /// \brief 表示是否需要线纤互转
  26.         static bool bNeedConvert;
  27. #endif
  28.         /// \brief 存放当前纤程的param指针
  29.         static void* paraThis;
  30.     };
  31. #if defined(_WIN32) && (!defined(_WIN32_WINNT) || (_WIN32_WINNT) < 0x600)
  32.     bool FiberHelper::bNeedConvert = true;
  33. #endif
  34.     void* FiberHelper::paraThis = nullptr;
  35.  
  36.     /** \brief 主纤程类
  37.      *
  38.      * 目前可以嵌套开纤程,但最好只在每个纤程中新开一个纤程
  39.      *
  40.      * \warning 无线程安全
  41.      * \param 子纤程返回类型
  42.      */
  43.     template <typename YieldValueType>
  44.     class Fiber {
  45.         static_assert(!std::is_same<YieldValueType, void>::value, "A fiber which return void is unsupported");
  46.         typedef std::function<YieldValueType ()> FuncType;
  47.  
  48.         /** \brief 纤程参数结构体
  49.          *
  50.          * 通过子纤程入口函数参数传入子纤程
  51.          */
  52.         struct Param {
  53.             template <typename Fn>
  54.             Param(Fn&& f)
  55.                 : func(std::forward<Fn>(f))
  56. #ifndef _WIN32
  57.                 // 只能放在堆里,否则SIGSRV,原因不知
  58.                 , fnew_stack(new uint8_t[SIGSTKSZ])
  59. #endif
  60.                 {};
  61.             /// \brief 子纤程返回值
  62.             YieldValueType yieldedValue;
  63.             /// \brief 存储子纤程抛出的异常
  64.             std::exception_ptr eptr;
  65.             /// \brief 子纤程是否结束
  66.             bool flagFinish;
  67.             /// \brief 真子纤程入口
  68.             FuncType func;
  69.             /// \brief 纤程信息
  70. #ifdef _WIN32
  71.             PVOID pMainFiber, pNewFiber;
  72.             bool isFormerAThread;
  73. #else
  74.             ucontext_t ctx_main, ctx_fnew;
  75.             const std::unique_ptr<uint8_t[]> fnew_stack;
  76. #endif
  77.         } param;
  78.  
  79.     public:
  80.         /** \brief 构造函数
  81.          *
  82.          * 把主线程转化为纤程,并创建一个子纤程
  83.          *
  84.          * \param 子纤程入口
  85.          */
  86.         template <typename Fn>
  87.         Fiber(Fn&& f): param(std::forward<Fn>(f)) {
  88. #ifdef _WIN32
  89. #   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600
  90.             // 貌似MinGW杯具?
  91.             param.isFormerAThread = IsThreadAFiber() == FALSE;
  92. #   else
  93.             param.isFormerAThread = FiberHelper::bNeedConvert;
  94. #   endif
  95.             if (param.isFormerAThread) {
  96. #   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x502
  97.                 param.pMainFiber = ::ConvertThreadToFiberEx(nullptr, FIBER_FLAG_FLOAT_SWITCH);
  98. #   else
  99. #       warning See: msdn.microsoft.com/en-us/library/ms682117
  100.                 param.pMainFiber = ::ConvertThreadToFiber(nullptr);
  101. #   endif
  102. #   if !(defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600)
  103.                 FiberHelper::bNeedConvert = false;
  104. #   endif
  105.             } else {
  106.                 param.pMainFiber = ::GetCurrentFiber();
  107.             }
  108.             // default stack size
  109. #   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x502
  110.             param.pNewFiber = ::CreateFiberEx(0, 0, FIBER_FLAG_FLOAT_SWITCH, &fEntry, &param);
  111. #   else
  112. #       warning See: msdn.microsoft.com/en-us/library/ms682406
  113.             param.pNewFiber = ::CreateFiber(0, &fEntry, nullptr);
  114. #   endif
  115. #else
  116.             ::getcontext(&param.ctx_fnew);
  117.             param.ctx_fnew.uc_stack.ss_sp = param.fnew_stack.get();
  118.             param.ctx_fnew.uc_stack.ss_size = SIGSTKSZ;
  119.             param.ctx_fnew.uc_link = &param.ctx_main;
  120.             ::makecontext(&param.ctx_fnew, &fEntry, 0);
  121. #endif
  122.             param.flagFinish = false;
  123.         }
  124.         /** \brief 析构函数
  125.          *
  126.          * 删除子纤程,并将主纤程转回线程
  127.          *
  128.          * \warning 主类析构时子纤程必须已经结束(return)
  129.          */
  130.         ~Fiber() {
  131.             if (!isFinished())
  132.                 std::terminate();
  133. #ifdef _WIN32
  134.             ::DeleteFiber(param.pNewFiber);
  135.             if (param.isFormerAThread) {
  136.                 ::ConvertFiberToThread();
  137. #   if !(defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600)
  138.                 FiberHelper::bNeedConvert = true;
  139. #   endif
  140.             }
  141. #endif
  142.         }
  143.  
  144.         /** \brief 调用子纤程
  145.          *
  146.          * 程序流程转入子线程
  147.          *
  148.          * \warning 子纤程必须尚未结束
  149.          * \return 返回子纤程yield或return的值
  150.          */
  151.         YieldValueType call() {
  152.             assert(isFinished() == false);
  153.             void* oldPara = FiberHelper::paraThis;
  154.             FiberHelper::paraThis = &param;
  155. #ifdef _WIN32
  156.             ::SwitchToFiber(param.pNewFiber);
  157. #else
  158.             ::swapcontext(&param.ctx_main, &param.ctx_fnew);
  159. #endif
  160.             FiberHelper::paraThis = oldPara;
  161.             if (!(param.eptr == std::exception_ptr()))
  162.                 std::rethrow_exception(param.eptr);
  163.             return std::move(param.yieldedValue);
  164.         }
  165.  
  166.         /** \brief 判断子纤程是否结束
  167.          * \return 子纤程已经结束(return)返回true,否则false
  168.          */
  169.         bool isFinished() { return param.flagFinish; }
  170.  
  171.         /** \brief 转回主纤程并输出值
  172.          *
  173.          * \warning 必须由子纤程调用
  174.          *          参数类型必须与子纤程返回值相同,无类型安全
  175.          * \param 输出到主纤程的值
  176.          */
  177.         static void yield(YieldValueType value) {
  178.             assert(FiberHelper::paraThis != nullptr && "Fiber::yield() called with no active fiber");
  179.             Param& param = *reinterpret_cast<Param *>(FiberHelper::paraThis);
  180.             param.yieldedValue = std::move(value);
  181. #ifdef _WIN32
  182.             ::SwitchToFiber(param.pMainFiber);
  183. #else
  184.             ::swapcontext(&param.ctx_fnew, &param.ctx_main);
  185. #endif
  186.         }
  187.  
  188.     private:
  189.         /// \brief 子纤程入口的warpper
  190.  
  191. #ifdef _WIN32
  192.         static void WINAPI fEntry(void *) {
  193. #else
  194.         static void fEntry() {
  195. #endif
  196.             assert(FiberHelper::paraThis != nullptr);
  197.             Param& param = *reinterpret_cast<Param *>(FiberHelper::paraThis);
  198.             param.flagFinish = false;
  199.             try {
  200.                 param.yieldedValue = param.func();
  201.             } catch (...) {
  202.                 param.eptr = std::current_exception();
  203.             }
  204.             param.flagFinish = true;
  205. #ifdef _WIN32
  206.             ::SwitchToFiber(param.pMainFiber);
  207. #endif
  208.         }
  209.     };
  210.  
  211.     /** \brief 纤程迭代器类
  212.      *
  213.      * 它通过使用 yield 函数对数组或集合类执行自定义迭代。
  214.      * 用于 C++11 for (...:...)
  215.      */
  216.     template <typename YieldValueType>
  217.     struct FiberIterator: std::iterator<std::input_iterator_tag, YieldValueType> {
  218.         /// \brief 迭代器尾
  219.         FiberIterator(): fiber(nullptr), value() {}
  220.         /** \brief 迭代器首
  221.          * \param 主线程类的引用
  222.          */
  223.         FiberIterator(Fiber<YieldValueType>& _f): fiber(&_f), value(_f.call()) {}
  224.  
  225.         /// \brief 转入子纤程
  226.         FiberIterator& operator ++() {
  227.             assert(fiber != nullptr);
  228.             if (!fiber->isFinished())
  229.                 value = fiber->call();
  230.             else
  231.                 fiber = nullptr;
  232.             return *this;
  233.         }
  234.         // 返回临时对象没问题吧-_-!!
  235.         FiberIterator operator ++(int) {
  236.             FiberIterator tmp(*this);
  237.             ++*this;
  238.             return tmp;
  239.         }
  240.  
  241.         /// \brief 取得返回值
  242.         YieldValueType& operator *() {
  243.             assert(fiber != nullptr);
  244.             return value;
  245.         }
  246.  
  247.         /** \brief 比较迭代器相等
  248.          *
  249.          * 通常用于判断迭代是否结束
  250.          * 最好别干别的 ;P
  251.          */
  252.         bool operator ==(const FiberIterator& rhs) {
  253.             return fiber == rhs.fiber;
  254.         }
  255.         bool operator !=(const FiberIterator& rhs) {
  256.             return !(*this == rhs);
  257.         }
  258.  
  259.     private:
  260.         Fiber<YieldValueType>* fiber;
  261.         YieldValueType value;
  262.     };
  263.  
  264.     /// \brief 返回迭代器首
  265.     template <typename YieldValueType>
  266.     FiberIterator<YieldValueType> begin(Fiber<YieldValueType>& fiber) {
  267.         return FiberIterator<YieldValueType>(fiber);
  268.     }
  269.  
  270.     /// \brief 返回迭代器尾
  271.     template <typename YieldValueType>
  272.     FiberIterator<YieldValueType> end(Fiber<YieldValueType>&) {
  273.         return FiberIterator<YieldValueType>();
  274.     }
  275. }
  276.  
  277. using namespace std;
  278. using FiberSpace::Fiber;
  279.  
  280. int Test(int& i) {
  281.     cout << "func, i = " << i << endl;
  282.     // 保留栈,转回主纤程,并输出值
  283.     Fiber<int>::yield(++i);
  284.     cout << "func, i = " << i << endl;
  285.     // 终止迭代,返回主纤程,并输出值
  286.     return ++i;
  287. }
  288.  
  289. int Test2(int beg, int end) {
  290.     while (beg !=end)
  291.         Fiber<int>::yield(beg++);
  292.     return beg;
  293. }
  294.  
  295. long Test3() {
  296.     auto testF = [] () -> int {
  297.         Fiber<int>::yield(1);
  298.         Fiber<int>::yield(2);
  299.         return 3;
  300.     };
  301.     for (auto i : Fiber<int>(testF))
  302.         Fiber<long>::yield(i);
  303.     return 4;
  304. }
  305.  
  306. long TestException() {
  307.     auto testF = [] () -> int {
  308.         Fiber<int> fiber([] () -> int { throw exception(); return 0; });
  309.         return fiber.call();
  310.     };
  311.     try {
  312.         return Fiber<int>(testF).call();
  313.     } catch (...) {
  314.         cout << "Exception catched in TestException()" << endl;
  315.         throw;
  316.     }
  317. }
  318.  
  319. class DerivedFiber: public Fiber<char> {
  320. public:
  321.     DerivedFiber(): Fiber(std::bind(&DerivedFiber::run, this)) {}
  322.  
  323. private:
  324.     char run() {
  325.         puts("Derived fiber running.");
  326.         return 0;
  327.     }
  328. };
  329.  
  330. char fiberFunc() {
  331.     puts("Composed fiber running.");
  332.     Fiber<char>::yield(0);
  333.     puts("Composed fiber running.");
  334.     return 0;
  335. }
  336.  
  337. int main() {
  338.     {
  339.     // 测试基本流程转换
  340.     int i = 0, t;
  341.     cout << "main, i = " << i << endl;
  342.     // 把主线程转化为纤程,并创建一个子纤程。参数为子纤程入口
  343.     Fiber<int> fiber(std::bind(Test, std::ref(i)));
  344.     // 流程转入子线程
  345.     t = fiber.call();
  346.     cout << "main, Test yield: " << t << endl;
  347.     t = fiber.call();
  348.     cout << "main, Test return: " << t << endl;
  349.     cout << "main, i = " << i << endl;
  350.     // 确保fiber正常析构
  351.     }
  352.     {
  353.     // Test from dlang.org
  354.     // create instances of each type
  355.     unique_ptr<DerivedFiber> derived(new DerivedFiber);
  356.     unique_ptr<Fiber<char>> composed(new Fiber<char>(&fiberFunc));
  357.  
  358.     // call both fibers once
  359.     derived->call();
  360.     composed->call();
  361.     puts("Execution returned to calling context.");
  362.     composed->call();
  363.  
  364.     // since each fiber has run to completion, each should have state TERM
  365.     assert( derived->isFinished() );
  366.     assert( composed->isFinished() );
  367.     }
  368.     // 测试循环yield
  369.     for (Fiber<int> fiber(std::bind(Test2, 1, 10)); !fiber.isFinished();)
  370.         cout << fiber.call() << ' ';
  371.     cout << endl;
  372.  
  373.     // 测试返回非基本类型,foreach
  374.     // VC10: ╮(╯▽╰)╭
  375.     for (const string& s : Fiber<string>([] () -> std::string {
  376.         Fiber<string>::yield("Hello");
  377.         Fiber<string>::yield("World");
  378.         return "!!!!!";
  379.     })) {
  380.         cout << s << endl;
  381.     }
  382.  
  383.     // 测试深层调用、返回
  384.     for (auto i : Fiber<long>(Test3))
  385.         cout << i << ' ';
  386.     cout << endl;
  387.  
  388. #if !(defined(__GNUC__) && defined(_WIN32))
  389.     // 测试深层调用及异常安全
  390.     // Test fail with MinGW
  391.     try {
  392.         Fiber<long> fiber(TestException);
  393.         fiber.call();
  394.     } catch (exception&) {
  395.         cout << "Exception catched in main()!" << endl;
  396.     }
  397. #endif
  398. }
  399. /*
  400. Tested with VS2010(icl), MinGW-g++, VMware-archlinux-g++
  401. Expected Output:
  402. main, i = 0
  403. func, i = 0
  404. main, Test yield: 1
  405. func, i = 1
  406. main, Test return: 2
  407. main, i = 2
  408. Derived fiber running.
  409. Composed fiber running.
  410. Execution returned to calling context.
  411. Composed fiber running.
  412. 1 2 3 4 5 6 7 8 9 10
  413. Hello
  414. World
  415. !!!!!
  416. 1 2 3 4
  417. Exception catched in TestException()
  418. Exception catched in main()!
  419. */
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement