View difference between Paste ID: wyz9QuBk and zpYiY1Qe
SHOW: | | - or go back to the newest paste.
1
#include <functional>
2
#include <cassert>
3
#include <iterator>
4
#include <iostream>
5
#include <memory>
6
#ifdef _WIN32
7
#   ifdef WIN32_LEAD_AND_MEAN
8
#       include <Windows.h>
9
#   else
10
#       define WIN32_LEAD_AND_MEAN 1
11
#       include <Windows.h>
12
#       undef WIN32_LEAD_AND_MEAN
13
#   endif
14
#elif !defined(__APPLE__)
15-
#   include <memory>
15+
16
#else
17
#   error OSX is not supported
18
#endif
19
20
namespace FiberSpace {
21
    class FiberHelper {
22
        template <typename YieldValueType>
23
        friend class Fiber;
24
25
#if defined(_WIN32) && (!defined(_WIN32_WINNT) || (_WIN32_WINNT) < 0x600)
26
        // Vista之前不支持IsThreadAFiber
27
        /// \brief 表示是否需要线纤互转
28
        static bool bNeedConvert;
29
#endif
30
        /// \brief 存放当前纤程的param指针
31
        static void* paraThis;
32
    };
33
#if defined(_WIN32) && (!defined(_WIN32_WINNT) || (_WIN32_WINNT) < 0x600)
34
    bool FiberHelper::bNeedConvert = true;
35
#endif
36
    void* FiberHelper::paraThis = nullptr;
37
38
    /** \brief 主纤程类
39
     *
40
     * 目前可以嵌套开纤程,但最好只在每个纤程中新开一个纤程
41
     *
42
     * \warning 无线程安全
43
     * \param 子纤程返回类型
44
     */
45-
        static_assert(!std::is_same<YieldValueType, void>::value, "A fiber which return void is unsupported");
45+
    template <typename YieldValueType = void>
46
    class Fiber {
47
        typedef std::function<YieldValueType ()> FuncType;
48
49
        /** \brief 纤程参数结构体
50
         *
51
         * 通过子纤程入口函数参数传入子纤程
52
         */
53
        struct Param {
54
            template <typename Fn>
55
            Param(Fn&& f)
56
                : func(std::forward<Fn>(f))
57
#ifndef _WIN32
58
                // 只能放在堆里,否则SIGSRV,原因不知
59
                , fnew_stack(new uint8_t[SIGSTKSZ])
60
#endif
61
                {};
62
            /// \brief 子纤程返回值
63
            YieldValueType yieldedValue;
64
            /// \brief 存储子纤程抛出的异常
65
            std::exception_ptr eptr;
66
            /// \brief 子纤程是否结束
67
            bool flagFinish;
68
            /// \brief 真子纤程入口
69
            FuncType func;
70
            /// \brief 纤程信息
71
#ifdef _WIN32
72
            PVOID pMainFiber, pNewFiber;
73
            bool isFormerAThread;
74
#else
75
            ucontext_t ctx_main, ctx_fnew;
76
            const std::unique_ptr<uint8_t[]> fnew_stack;
77
#endif
78
        } param;
79
80
    public:
81
        /** \brief 构造函数
82
         *
83
         * 把主线程转化为纤程,并创建一个子纤程
84
         *
85
         * \param 子纤程入口
86
         */
87
        template <typename Fn>
88
        Fiber(Fn&& f): param(std::forward<Fn>(f)) {
89
#ifdef _WIN32
90
#   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600
91
            // 貌似MinGW杯具?
92
            param.isFormerAThread = IsThreadAFiber() == FALSE;
93
#   else
94
            param.isFormerAThread = FiberHelper::bNeedConvert;
95
#   endif
96
            if (param.isFormerAThread) {
97
#   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x502
98
                param.pMainFiber = ::ConvertThreadToFiberEx(nullptr, FIBER_FLAG_FLOAT_SWITCH);
99
#   else
100
#       warning See: msdn.microsoft.com/en-us/library/ms682117
101
                param.pMainFiber = ::ConvertThreadToFiber(nullptr);
102
#   endif
103
#   if !(defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600)
104
                FiberHelper::bNeedConvert = false;
105
#   endif
106
            } else {
107
                param.pMainFiber = ::GetCurrentFiber();
108
            }
109
            // default stack size
110
#   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x502
111
            param.pNewFiber = ::CreateFiberEx(0, 0, FIBER_FLAG_FLOAT_SWITCH, &fEntry, &param);
112
#   else
113
#       warning See: msdn.microsoft.com/en-us/library/ms682406
114
            param.pNewFiber = ::CreateFiber(0, &fEntry, nullptr);
115
#   endif
116
#else
117
            ::getcontext(&param.ctx_fnew);
118
            param.ctx_fnew.uc_stack.ss_sp = param.fnew_stack.get();
119
            param.ctx_fnew.uc_stack.ss_size = SIGSTKSZ;
120
            param.ctx_fnew.uc_link = &param.ctx_main;
121
            ::makecontext(&param.ctx_fnew, &fEntry, 0);
122
#endif
123
            param.flagFinish = false;
124
        }
125
        /** \brief 析构函数
126
         *
127
         * 删除子纤程,并将主纤程转回线程
128
         *
129
         * \warning 主类析构时子纤程必须已经结束(return)
130
         */
131
        ~Fiber() {
132
            if (!isFinished())
133
                std::terminate();
134
#ifdef _WIN32
135
            ::DeleteFiber(param.pNewFiber);
136
            if (param.isFormerAThread) {
137
                ::ConvertFiberToThread();
138
#   if !(defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600)
139
                FiberHelper::bNeedConvert = true;
140
#   endif
141
            }
142
#endif
143
        }
144
145
        /** \brief 调用子纤程
146
         *
147
         * 程序流程转入子线程
148
         *
149
         * \warning 子纤程必须尚未结束
150
         * \return 返回子纤程yield或return的值
151
         */
152
        YieldValueType call() {
153
            assert(isFinished() == false);
154
            void* oldPara = FiberHelper::paraThis;
155
            FiberHelper::paraThis = &param;
156
#ifdef _WIN32
157
            ::SwitchToFiber(param.pNewFiber);
158
#else
159
            ::swapcontext(&param.ctx_main, &param.ctx_fnew);
160
#endif
161
            FiberHelper::paraThis = oldPara;
162
            if (!(param.eptr == std::exception_ptr()))
163
                std::rethrow_exception(param.eptr);
164
            return std::move(param.yieldedValue);
165
        }
166
167
        /** \brief 判断子纤程是否结束
168
         * \return 子纤程已经结束(return)返回true,否则false
169
         */
170
        bool isFinished() { return param.flagFinish; }
171
172
        /** \brief 转回主纤程并输出值
173
         *
174
         * \warning 必须由子纤程调用
175
         *          参数类型必须与子纤程返回值相同,无类型安全
176
         * \param 输出到主纤程的值
177
         */
178
        static void yield(YieldValueType value) {
179
            assert(FiberHelper::paraThis != nullptr && "Fiber::yield() called with no active fiber");
180
            Param& param = *reinterpret_cast<Param *>(FiberHelper::paraThis);
181
            param.yieldedValue = std::move(value);
182
#ifdef _WIN32
183
            ::SwitchToFiber(param.pMainFiber);
184
#else
185
            ::swapcontext(&param.ctx_fnew, &param.ctx_main);
186
#endif
187
        }
188
189
    private:
190
        /// \brief 子纤程入口的warpper
191
192
#ifdef _WIN32
193
        static void WINAPI fEntry(void *) {
194
#else
195
        static void fEntry() {
196
#endif
197
            assert(FiberHelper::paraThis != nullptr);
198
            Param& param = *reinterpret_cast<Param *>(FiberHelper::paraThis);
199
            param.flagFinish = false;
200
            try {
201
                param.yieldedValue = param.func();
202
            } catch (...) {
203
                param.eptr = std::current_exception();
204
            }
205
            param.flagFinish = true;
206
#ifdef _WIN32
207
            ::SwitchToFiber(param.pMainFiber);
208
#endif
209
        }
210
    };
211
212
    // 无返回值特化囧
213
    template <>
214
    class Fiber<void> {
215
        typedef std::function<void ()> FuncType;
216
217
        struct Param {
218
            template <typename Fn>
219
            Param(Fn&& f)
220
                : func(std::forward<Fn>(f))
221
#ifndef _WIN32
222
                , fnew_stack(new uint8_t[SIGSTKSZ])
223
#endif
224
                {};
225
            std::exception_ptr eptr;
226
            bool flagFinish;
227
            FuncType func;
228
#ifdef _WIN32
229
            PVOID pMainFiber, pNewFiber;
230
            bool isFormerAThread;
231
#else
232
            ucontext_t ctx_main, ctx_fnew;
233
            const std::unique_ptr<uint8_t[]> fnew_stack;
234
#endif
235
        } param;
236
237
    public:
238
        template <typename Fn>
239
        Fiber(Fn&& f): param(std::forward<Fn>(f)) {
240
#ifdef _WIN32
241
#   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600
242
            param.isFormerAThread = IsThreadAFiber() == FALSE;
243
#   else
244
            param.isFormerAThread = FiberHelper::bNeedConvert;
245
#   endif
246
            if (param.isFormerAThread) {
247
#   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x502
248
                param.pMainFiber = ::ConvertThreadToFiberEx(nullptr, FIBER_FLAG_FLOAT_SWITCH);
249
#   else
250
                param.pMainFiber = ::ConvertThreadToFiber(nullptr);
251
#   endif
252
#   if !(defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600)
253
                FiberHelper::bNeedConvert = false;
254
#   endif
255
            } else {
256
                param.pMainFiber = ::GetCurrentFiber();
257
            }
258
#   if defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x502
259
            param.pNewFiber = ::CreateFiberEx(0, 0, FIBER_FLAG_FLOAT_SWITCH, &fEntry, &param);
260
#   else
261
            param.pNewFiber = ::CreateFiber(0, &fEntry, nullptr);
262
#   endif
263
#else
264
            ::getcontext(&param.ctx_fnew);
265
            param.ctx_fnew.uc_stack.ss_sp = param.fnew_stack.get();
266
            param.ctx_fnew.uc_stack.ss_size = SIGSTKSZ;
267
            param.ctx_fnew.uc_link = &param.ctx_main;
268
            ::makecontext(&param.ctx_fnew, &fEntry, 0);
269
#endif
270
            param.flagFinish = false;
271
        }
272
        ~Fiber() {
273
            if (!isFinished())
274
                std::terminate();
275
#ifdef _WIN32
276
            ::DeleteFiber(param.pNewFiber);
277
            if (param.isFormerAThread) {
278
                ::ConvertFiberToThread();
279
#   if !(defined(_WIN32_WINNT) && (_WIN32_WINNT) > 0x600)
280
                FiberHelper::bNeedConvert = true;
281
#   endif
282
            }
283
#endif
284
        }
285
        void call() {
286
            assert(isFinished() == false);
287
            void* oldPara = FiberHelper::paraThis;
288
            FiberHelper::paraThis = &param;
289
#ifdef _WIN32
290
            ::SwitchToFiber(param.pNewFiber);
291
#else
292
            ::swapcontext(&param.ctx_main, &param.ctx_fnew);
293
#endif
294
            FiberHelper::paraThis = oldPara;
295
            if (!(param.eptr == std::exception_ptr()))
296
                std::rethrow_exception(param.eptr);
297
        }
298
        bool isFinished() { return param.flagFinish; }
299
        static void yield() {
300
            assert(FiberHelper::paraThis != nullptr && "Fiber::yield() called with no active fiber");
301
            Param& param = *reinterpret_cast<Param *>(FiberHelper::paraThis);
302
#ifdef _WIN32
303
            ::SwitchToFiber(param.pMainFiber);
304
#else
305
            ::swapcontext(&param.ctx_fnew, &param.ctx_main);
306
#endif
307
        }
308
309
    private:
310
#ifdef _WIN32
311
        static void WINAPI fEntry(void *) {
312
#else
313
        static void fEntry() {
314
#endif
315
            assert(FiberHelper::paraThis != nullptr);
316
            Param& param = *reinterpret_cast<Param *>(FiberHelper::paraThis);
317
            param.flagFinish = false;
318
            try {
319-
class DerivedFiber: public Fiber<char> {
319+
                param.func();
320
            } catch (...) {
321
                param.eptr = std::current_exception();
322
            }
323
            param.flagFinish = true;
324-
    char run() {
324+
325
            ::SwitchToFiber(param.pMainFiber);
326
#endif
327
        }
328
    };
329
330-
char fiberFunc() {
330+
331
     *
332-
    Fiber<char>::yield(0);
332+
333
     * 用于 C++11 for (...:...)
334-
    return 0;
334+
335
    template <typename YieldValueType>
336
    struct FiberIterator: std::iterator<std::input_iterator_tag, YieldValueType> {
337
        /// \brief 迭代器尾
338
        FiberIterator(): fiber(nullptr), value() {}
339
        /** \brief 迭代器首
340
         * \param 主线程类的引用
341
         */
342
        FiberIterator(Fiber<YieldValueType>& _f): fiber(&_f), value(_f.call()) {}
343
344
        /// \brief 转入子纤程
345
        FiberIterator& operator ++() {
346
            assert(fiber != nullptr);
347
            if (!fiber->isFinished())
348
                value = fiber->call();
349
            else
350
                fiber = nullptr;
351
            return *this;
352
        }
353
        // 返回临时对象没问题吧-_-!!
354
        FiberIterator operator ++(int) {
355
            FiberIterator tmp(*this);
356-
    unique_ptr<Fiber<char>> composed(new Fiber<char>(&fiberFunc));
356+
357
            return tmp;
358
        }
359
360
        /// \brief 取得返回值
361
        YieldValueType& operator *() {
362
            assert(fiber != nullptr);
363
            return value;
364
        }
365
366
        /** \brief 比较迭代器相等
367
         *
368
         * 通常用于判断迭代是否结束
369
         * 最好别干别的 ;P
370
         */
371
        bool operator ==(const FiberIterator& rhs) {
372
            return fiber == rhs.fiber;
373
        }
374
        bool operator !=(const FiberIterator& rhs) {
375
            return !(*this == rhs);
376
        }
377
378
    private:
379
        Fiber<YieldValueType>* fiber;
380
        YieldValueType value;
381
    };
382
383
    /// \brief 返回迭代器首
384
    template <typename YieldValueType>
385
    FiberIterator<YieldValueType> begin(Fiber<YieldValueType>& fiber) {
386
        return FiberIterator<YieldValueType>(fiber);
387
    }
388
389
    /// \brief 返回迭代器尾
390
    template <typename YieldValueType>
391
    FiberIterator<YieldValueType> end(Fiber<YieldValueType>&) {
392
        return FiberIterator<YieldValueType>();
393
    }
394
}
395
396
using namespace std;
397
using FiberSpace::Fiber;
398
399
int Test(int& i) {
400-
Tested with VS2010(icl), MinGW-g++, VMware-archlinux-g++
400+
401
    // 保留栈,转回主纤程,并输出值
402
    Fiber<int>::yield(++i);
403
    cout << "func, i = " << i << endl;
404
    // 终止迭代,返回主纤程,并输出值
405
    return ++i;
406
}
407
408
int Test2(int beg, int end) {
409
    while (beg !=end)
410
        Fiber<int>::yield(beg++);
411
    return beg;
412
}
413
414
long Test3() {
415
    auto testF = [] () -> int {
416
        Fiber<int>::yield(1);
417
        Fiber<int>::yield(2);
418
        return 3;
419
    };
420
    for (auto i : Fiber<int>(testF))
421
        Fiber<long>::yield(i);
422
    return 4;
423
}
424
425
long TestException() {
426
    auto testF = [] () -> int {
427
        Fiber<int> fiber([] () -> int { throw exception(); return 0; });
428
        return fiber.call();
429
    };
430
    try {
431
        return Fiber<int>(testF).call();
432
    } catch (...) {
433
        cout << "Exception catched in TestException()" << endl;
434
        throw;
435
    }
436
}
437
438
class DerivedFiber: public Fiber<> {
439
public:
440
    DerivedFiber(): Fiber(std::bind(&DerivedFiber::run, this)) {}
441
442
private:
443
    void run() {
444
        puts("Derived fiber running.");
445
    }
446
};
447
448
void fiberFunc() {
449
    puts("Composed fiber running.");
450
    Fiber<>::yield();
451
    puts("Composed fiber running.");
452
}
453
454
int main() {
455
    {
456
    // 测试基本流程转换
457
    int i = 0, t;
458
    cout << "main, i = " << i << endl;
459
    // 把主线程转化为纤程,并创建一个子纤程。参数为子纤程入口
460
    Fiber<int> fiber(std::bind(Test, std::ref(i)));
461
    // 流程转入子线程
462
    t = fiber.call();
463
    cout << "main, Test yield: " << t << endl;
464
    t = fiber.call();
465
    cout << "main, Test return: " << t << endl;
466
    cout << "main, i = " << i << endl;
467
    // 确保fiber正常析构
468
    }
469
    {
470
    // Test from dlang.org
471
    // create instances of each type
472
    unique_ptr<DerivedFiber> derived(new DerivedFiber);
473
    unique_ptr<Fiber<>> composed(new Fiber<void>(&fiberFunc));
474
475
    // call both fibers once
476
    derived->call();
477
    composed->call();
478
    puts("Execution returned to calling context.");
479
    composed->call();
480
481
    // since each fiber has run to completion, each should have state TERM
482
    assert( derived->isFinished() );
483
    assert( composed->isFinished() );
484
    }
485
    // 测试循环yield
486
    for (Fiber<int> fiber(std::bind(Test2, 1, 10)); !fiber.isFinished();)
487
        cout << fiber.call() << ' ';
488
    cout << endl;
489
490
    // 测试返回非基本类型,foreach
491
    // VC10: ╮(╯▽╰)╭
492
    for (const string& s : Fiber<string>([] () -> std::string {
493
        Fiber<string>::yield("Hello");
494
        Fiber<string>::yield("World");
495
        return "!!!!!";
496
    })) {
497
        cout << s << endl;
498
    }
499
500
    // 测试深层调用、返回
501
    for (auto i : Fiber<long>(Test3))
502
        cout << i << ' ';
503
    cout << endl;
504
#if 0 // Test failed
505
    {
506
    unique_ptr<Fiber<char>> child, parent;
507
    child.reset(new Fiber<char>([&] () -> char {
508
        puts("before call parent");
509
        parent->call();
510
        puts("after call parent");
511
        return 0;
512
    }));
513
    parent.reset(new Fiber<char>([&] () -> char {
514
        puts("before call child");
515
        child->call();
516
        puts("after call child");
517
        return 0;
518
    }));
519
    parent->call();
520
    }
521
#endif
522
#if 0   // pass, but you should not write this
523
    {
524
    unique_ptr<Fiber<char>> f;
525
    f.reset(new Fiber<char>([&] () -> char {
526
        puts("f!!!");
527
        Sleep(500);
528
        return f->call();
529
    }));
530
    f->call();
531
    }
532
#endif
533
#if !(defined(__GNUC__) && defined(_WIN32))
534
    // 测试深层调用及异常安全
535
    // Test fail with MinGW
536
    try {
537
        Fiber<long> fiber(TestException);
538
        fiber.call();
539
    } catch (exception&) {
540
        cout << "Exception catched in main()!" << endl;
541
    }
542
#endif
543
}
544
/*
545
Tested with VS2010(icl), MinGW-g++, archlinux-g++
546
Expected Output:
547
main, i = 0
548
func, i = 0
549
main, Test yield: 1
550
func, i = 1
551
main, Test return: 2
552
main, i = 2
553
Derived fiber running.
554
Composed fiber running.
555
Execution returned to calling context.
556
Composed fiber running.
557
1 2 3 4 5 6 7 8 9 10
558
Hello
559
World
560
!!!!!
561
1 2 3 4
562
Exception catched in TestException()
563
Exception catched in main()!
564
*/