#include "fiber_tasking_lib/task_scheduler.h"

#include <cassert>


struct NumberSubset {
    uint64 start;
    uint64 end;

    uint64 total;
};


FTL_TASK_ENTRY_POINT(AddNumberSubset) {
    NumberSubset *subset = reinterpret_cast<NumberSubset *>(arg);

    subset->total = 0;

    while (subset->start != subset->end) {
        subset->total += subset->start;
        ++subset->start;
    }

    subset->total += subset->end;
}
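
// Note: task.h and the FTL_TASK_ENTRY_POINT macro are not part of this paste. Judging from how they are
// used here and in FiberStart() below, FTL_TASK_ENTRY_POINT(name) presumably expands to a signature along
// the lines of `void name(FiberTaskingLib::TaskScheduler *taskScheduler, void *arg)`, and Task is presumably
// a small struct holding a TaskFunction pointer plus a void * argument (Function / ArgData).
// Treat this as an inference from the surrounding code, not as the actual contents of task.h.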


/**
 * Calculates the value of a triangle number by dividing the additions up into tasks
 *
 * A triangle number is defined as:
 * Tn = 1 + 2 + 3 + ... + n
 *
 * The code is checked against the numerical solution which is:
 * Tn = n * (n + 1) / 2
 */
FTL_TASK_ENTRY_POINT(MainTask) {
    // Define the constants to test
    const uint64 triangleNum = 47593243ull;
    const uint64 numAdditionsPerTask = 10000ull;
    const uint64 numTasks = (triangleNum + numAdditionsPerTask - 1ull) / numAdditionsPerTask;

    // Create the tasks
    FiberTaskingLib::Task *tasks = new FiberTaskingLib::Task[numTasks];
    NumberSubset *subsets = new NumberSubset[numTasks];
    uint64 nextNumber = 1ull;

    for (uint64 i = 0ull; i < numTasks; ++i) {
        NumberSubset *subset = &subsets[i];

        subset->start = nextNumber;
        subset->end = nextNumber + numAdditionsPerTask - 1ull;
        if (subset->end > triangleNum) {
            subset->end = triangleNum;
        }

        tasks[i] = {AddNumberSubset, subset};

        nextNumber = subset->end + 1;
    }

    // Schedule the tasks and wait for them to complete
    std::shared_ptr<std::atomic_uint> counter = taskScheduler->AddTasks(numTasks, tasks);
    delete[] tasks;

    taskScheduler->WaitForCounter(counter, 0);


    // Add the results
    uint64 result = 0ull;
    for (uint64 i = 0; i < numTasks; ++i) {
        result += subsets[i].total;
    }

    // Test
    assert(triangleNum * (triangleNum + 1ull) / 2ull == result);

    // Cleanup
    delete[] subsets;
}

int main(int argc, char **argv) {
    FiberTaskingLib::TaskScheduler taskScheduler;
    taskScheduler.Run(25, MainTask);

    return 0;
}

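// ------------------------------------------------------------------------------------------------------
// fiber_tasking_lib/task_scheduler.h
// The paste concatenates several files; boundaries marked like this are inferred from the
// #pragma once / #include lines, not from the original post.
// ------------------------------------------------------------------------------------------------------
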
#pragma once

#include "fiber_tasking_lib/typedefs.h"
#include "fiber_tasking_lib/thread_abstraction.h"
#include "fiber_tasking_lib/fiber.h"
#include "fiber_tasking_lib/task.h"
#include "fiber_tasking_lib/wait_free_queue.h"

#include <atomic>
#include <vector>
#include <climits>
#include <memory>


namespace FiberTaskingLib {

/**
 * A class that enables task-based multithreading.
 *
 * Underneath the covers, it uses fibers to allow cores to work on other tasks
 * when the current task is waiting on a synchronization atomic
 */
class TaskScheduler {
public:
    TaskScheduler();
    ~TaskScheduler();

private:
    enum {
        FTL_INVALID_INDEX = UINT_MAX
    };

    std::size_t m_numThreads;
    std::vector<ThreadType> m_threads;

    std::size_t m_fiberPoolSize;
    /* The backing storage for the fiber pool */
    Fiber *m_fibers;
    /**
     * An array of atomics, which signify if a fiber is available to be used. The indices of m_freeFibers
     * correspond 1 to 1 with m_fibers. So, if m_freeFibers[i] == true, then m_fibers[i] can be used.
     * Each atomic acts as a lock to ensure that threads do not try to use the same fiber at the same time
     */
    std::atomic<bool> *m_freeFibers;
    /**
     * An array of atomics, which signify if a fiber is waiting for a counter. The indices of m_waitingFibers
     * correspond 1 to 1 with m_fibers. So, if m_waitingFibers[i] == true, then m_fibers[i] is waiting for a counter
     */
    std::atomic<bool> *m_waitingFibers;

    /**
     * Holds a Counter that is being waited on. Specifically, until Counter == TargetValue
     */
    struct WaitingBundle {
        std::atomic_uint *Counter;
        uint TargetValue;
    };
    /**
     * An array of WaitingBundles, which correspond 1 to 1 with m_waitingFibers. If m_waitingFibers[i] == true,
     * m_waitingBundles[i] will contain the data for the waiting fiber in m_fibers[i].
     */
    std::vector<WaitingBundle> m_waitingBundles;

    std::atomic_bool m_quit;

    enum class FiberDestination {
        None = 0,
        ToPool = 1,
        ToWaiting = 2,
    };

    /**
     * Holds a task that is ready to be executed by the worker threads
     * Counter is the counter for the task(group). It will be decremented when the task completes
     */
    struct TaskBundle {
        Task TaskToExecute;
        std::shared_ptr<std::atomic_uint> Counter;
    };

    struct ThreadLocalStorage {
        ThreadLocalStorage()
            : ThreadFiber(),
              CurrentFiberIndex(FTL_INVALID_INDEX),
              OldFiberIndex(FTL_INVALID_INDEX),
              OldFiberDestination(FiberDestination::None),
              TaskQueue(),
              LastSuccessfulSteal(1) {
        }

        /**
         * Boost fibers require that fibers created from threads finish on the same thread where they started
         *
         * To accommodate this, we save the initial fibers created in each thread, and immediately switch
         * out of them into the general fiber pool. Once the 'mainTask' has finished, we signal all the threads to
         * start quitting. When they receive the signal, they switch back to the ThreadFiber, allowing it to
         * safely clean up.
         */
        Fiber ThreadFiber;
        /* The index of the current fiber in m_fibers */
        std::size_t CurrentFiberIndex;
        /* The index of the previously executed fiber in m_fibers */
        std::size_t OldFiberIndex;
        /* Where OldFiber should be stored when we call CleanUpPoolAndWaiting() */
        FiberDestination OldFiberDestination;
        /* The queue of waiting tasks */
        WaitFreeQueue<TaskBundle> TaskQueue;
        /* The last queue that we successfully stole from. This is an offset index from the current thread index */
        std::size_t LastSuccessfulSteal;
    };
    /**
     * C++ thread local storage is, by definition, static/global. This poses some problems, such as multiple TaskScheduler
     * instances. In addition, with Boost::Context, we have no way of telling the compiler to disable TLS optimizations, so we
     * have to fake TLS anyhow.
     *
     * During initialization of the TaskScheduler, we create one ThreadLocalStorage instance per thread. Threads index into
     * their storage using m_tls[GetCurrentThreadIndex()]
     */
    ThreadLocalStorage *m_tls;


public:
    /**
     * Initializes the TaskScheduler and then starts executing 'mainTask'
     *
     * NOTE: Run will "block" until 'mainTask' returns. However, it doesn't block in the traditional sense; 'mainTask' is created as a Fiber.
     * Therefore, the current thread will save its current state, and then switch execution to the 'mainTask' fiber. When 'mainTask'
     * finishes, the thread will switch back to the saved state, and Run() will return.
     *
     * @param fiberPoolSize The size of the fiber pool. The fiber pool is used to run new tasks when the current task is waiting on a counter
     * @param mainTask The main task to run
     * @param mainTaskArg The argument to pass to 'mainTask'
     * @param threadPoolSize The size of the thread pool to run. 0 corresponds to GetNumHardwareThreads()
     */
    void Run(uint fiberPoolSize, TaskFunction mainTask, void *mainTaskArg = nullptr, uint threadPoolSize = 0);

    /**
     * Adds a task to the internal queue.
     *
     * @param task The task to queue
     * @return An atomic counter corresponding to this task. Initially it will equal 1. When the task completes, it will be decremented.
     */
    std::shared_ptr<std::atomic_uint> AddTask(Task task);
    /**
     * Adds a group of tasks to the internal queue
     *
     * @param numTasks The number of tasks
     * @param tasks The tasks to queue
     * @return An atomic counter corresponding to the task group as a whole. Initially it will equal numTasks. When each task completes, it will be decremented.
     */
    std::shared_ptr<std::atomic_uint> AddTasks(uint numTasks, Task *tasks);

    /**
     * Yields execution to another task until counter == value
     *
     * @param counter The counter to check
     * @param value The value to wait for
     */
    void WaitForCounter(std::shared_ptr<std::atomic_uint> &counter, uint value);

private:
    /**
     * Gets the 0-based index of the current thread
     * This is useful for m_tls[GetCurrentThreadIndex()]
     *
     * @return The index of the current thread
     */
    std::size_t GetCurrentThreadIndex();
    /**
     * Pops the next task off the queue into nextTask. If there are no tasks in
     * the queue, it will return false.
     *
     * @param nextTask If the queue is not empty, will be filled with the next task
     * @return True: Successfully popped a task out of the queue
     */
    bool GetNextTask(TaskBundle *nextTask);
    /**
     * Gets the index of the next available fiber in the pool
     *
     * @return The index of the next available fiber in the pool
     */
    std::size_t GetNextFreeFiberIndex();
    /**
     * If necessary, moves the old fiber to the fiber pool or the waiting list
     * The old fiber is the last fiber to run on the thread before the current fiber
     */
    void CleanUpOldFiber();

    /**
     * The threadProc function for all worker threads
     *
     * @param arg An instance of ThreadStartArgs
     * @return The return status of the thread
     */
    static FTL_THREAD_FUNC_DECL ThreadStart(void *arg);
    /**
     * The fiberProc function that wraps the main fiber procedure given by the user
     *
     * @param arg An instance of MainFiberStartArgs
     */
    static void MainFiberStart(void *arg);
    /**
     * The fiberProc function for all fibers in the fiber pool
     *
     * @param arg An instance of TaskScheduler
     */
    static void FiberStart(void *arg);
};

} // End of namespace FiberTaskingLib

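// ------------------------------------------------------------------------------------------------------
// fiber_tasking_lib/fiber.h (inferred file boundary, see the note above)
// ------------------------------------------------------------------------------------------------------
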
#pragma once

#include "fiber_tasking_lib/config.h"

#include <boost_context/fcontext.h>

#include <cassert>
#include <cstdlib>
#include <algorithm>

#if defined(FTL_VALGRIND)
#include <valgrind/valgrind.h>
#endif

#if defined(FTL_FIBER_STACK_GUARD_PAGES)
#if defined(FTL_OS_LINUX) || defined(FTL_OS_MAC) || defined(FTL_iOS)
#include <sys/mman.h>
#include <unistd.h>
#elif defined(FTL_OS_WINDOWS)
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#endif
#endif


namespace FiberTaskingLib {

#if defined(FTL_VALGRIND)
#define FTL_VALGRIND_ID uint m_stackId

#define FTL_VALGRIND_REGISTER(s, e) \
    m_stackId = VALGRIND_STACK_REGISTER(s, e)

#define FTL_VALGRIND_DEREGISTER() VALGRIND_STACK_DEREGISTER(m_stackId)
#else
#define FTL_VALGRIND_ID
#define FTL_VALGRIND_REGISTER(s, e)
#define FTL_VALGRIND_DEREGISTER()
#endif


inline void MemoryGuard(void *memory, size_t bytes);
inline void MemoryGuardRelease(void *memory, size_t bytes);
inline std::size_t SystemPageSize();
inline void *AlignedAlloc(std::size_t size, std::size_t alignment);
inline void AlignedFree(void *block);
inline std::size_t RoundUp(std::size_t numToRound, std::size_t multiple);

typedef void (*FiberStartRoutine)(void *arg);


class Fiber {
public:
    /**
     * Default constructor
     * Nothing is allocated. This can be used as a thread fiber.
     */
    Fiber()
        : m_stack(nullptr),
          m_systemPageSize(0),
          m_stackSize(0),
          m_context(nullptr),
          m_arg(0) {
    }
    /**
     * Allocates a stack and sets it up to start executing 'startRoutine' when first switched to
     *
     * @param stackSize The stack size for the fiber. If guard pages are being used, this will be rounded up to the next multiple of the system page size
     * @param startRoutine The function to run when the fiber first starts
     * @param arg The argument to pass to 'startRoutine'
     */
    Fiber(std::size_t stackSize, FiberStartRoutine startRoutine, void *arg)
        : m_arg(arg) {
#if defined(FTL_FIBER_STACK_GUARD_PAGES)
        m_systemPageSize = SystemPageSize();
#else
        m_systemPageSize = 0;
#endif

        m_stackSize = RoundUp(stackSize, m_systemPageSize);
        // We add a guard page at both the top and the bottom of the stack
        m_stack = AlignedAlloc(m_systemPageSize + m_stackSize + m_systemPageSize, m_systemPageSize);
        m_context = boost_context::make_fcontext(static_cast<char *>(m_stack) + m_systemPageSize + stackSize, stackSize, startRoutine);

        FTL_VALGRIND_REGISTER(static_cast<char *>(m_stack) + m_systemPageSize + stackSize, static_cast<char *>(m_stack) + m_systemPageSize);
#if defined(FTL_FIBER_STACK_GUARD_PAGES)
        MemoryGuard(static_cast<char *>(m_stack), m_systemPageSize);
        MemoryGuard(static_cast<char *>(m_stack) + m_systemPageSize + stackSize, m_systemPageSize);
#endif
    }

    /**
     * Deleted copy constructor
     * It makes no sense to copy a stack and its corresponding context. Therefore, we explicitly forbid it.
     */
    Fiber(const Fiber &other) = delete;
    /**
     * Move constructor
     * This does a swap() of all the member variables
     *
     * @param other The fiber to move from
     */
    Fiber(Fiber &&other)
        : Fiber() {
        swap(*this, other);
    }

    /**
     * Move assignment operator
     * This does a swap() of all the member variables
     *
     * @param other The fiber to move from
     */
    Fiber &operator=(Fiber &&other) {
        swap(*this, other);

        return *this;
    }
    ~Fiber() {
        if (m_stack != nullptr) {
            if (m_systemPageSize != 0) {
                MemoryGuardRelease(static_cast<char *>(m_stack), m_systemPageSize);
                MemoryGuardRelease(static_cast<char *>(m_stack) + m_systemPageSize + m_stackSize, m_systemPageSize);
            }
            FTL_VALGRIND_DEREGISTER();

            AlignedFree(m_stack);
        }
    }

private:
    void *m_stack;
    std::size_t m_systemPageSize;
    std::size_t m_stackSize;
    boost_context::fcontext_t m_context;
    void *m_arg;
    FTL_VALGRIND_ID;

public:
    /**
     * Saves the current stack context and then switches to the given fiber
     * Execution will resume here once another fiber switches to this fiber
     *
     * @param fiber The fiber to switch to
     */
    void SwitchToFiber(Fiber *fiber) {
        boost_context::jump_fcontext(&m_context, fiber->m_context, fiber->m_arg);
    }
    /**
     * Re-initializes the stack with a new startRoutine and arg
     *
     * NOTE: This can NOT be called on a fiber that has m_stack == nullptr || m_stackSize == 0;
     * i.e., a default-constructed fiber.
     *
     * @param startRoutine The function to run when the fiber is next switched to
     * @param arg The arg for 'startRoutine'
     */
    void Reset(FiberStartRoutine startRoutine, void *arg) {
        m_context = boost_context::make_fcontext(static_cast<char *>(m_stack) + m_stackSize, m_stackSize, startRoutine);
        m_arg = arg;
    }

private:
    /**
     * Helper function for the move operators
     * Swaps all the member variables
     *
     * @param first The first fiber
     * @param second The second fiber
     */
    void swap(Fiber &first, Fiber &second) {
        using std::swap;

        swap(first.m_stack, second.m_stack);
        swap(first.m_systemPageSize, second.m_systemPageSize);
        swap(first.m_stackSize, second.m_stackSize);
        swap(first.m_context, second.m_context);
        swap(first.m_arg, second.m_arg);
    }
};

#if defined(FTL_FIBER_STACK_GUARD_PAGES)
#if defined(FTL_OS_LINUX) || defined(FTL_OS_MAC) || defined(FTL_iOS)
inline void MemoryGuard(void *memory, size_t bytes) {
    int result = mprotect(memory, bytes, PROT_NONE);
    if (result) {
        perror("mprotect failed with error:");
        assert(!result);
    }
}

inline void MemoryGuardRelease(void *memory, size_t bytes) {
    int result = mprotect(memory, bytes, PROT_READ | PROT_WRITE);
    if (result) {
        perror("mprotect failed with error:");
        assert(!result);
    }
}

inline std::size_t SystemPageSize() {
    int pageSize = getpagesize();
    return pageSize;
}

inline void *AlignedAlloc(std::size_t size, std::size_t alignment) {
    void *returnPtr;
    posix_memalign(&returnPtr, alignment, size);

    return returnPtr;
}

inline void AlignedFree(void *block) {
    free(block);
}
#elif defined(FTL_OS_WINDOWS)
inline void MemoryGuard(void *memory, size_t bytes) {
    DWORD ignored;

    BOOL result = VirtualProtect(memory, bytes, PAGE_NOACCESS, &ignored);
    assert(result);
}

inline void MemoryGuardRelease(void *memory, size_t bytes) {
    DWORD ignored;

    BOOL result = VirtualProtect(memory, bytes, PAGE_READWRITE, &ignored);
    assert(result);
}

inline std::size_t SystemPageSize() {
    SYSTEM_INFO sysInfo;
    GetSystemInfo(&sysInfo);
    return sysInfo.dwPageSize;
}

inline void *AlignedAlloc(std::size_t size, std::size_t alignment) {
    return _aligned_malloc(size, alignment);
}

inline void AlignedFree(void *block) {
    _aligned_free(block);
}
#else
#error "Need a way to protect memory for this platform."
#endif
#else
inline void MemoryGuard(void *memory, size_t bytes) {
    (void)memory;
    (void)bytes;
}

inline void MemoryGuardRelease(void *memory, size_t bytes) {
    (void)memory;
    (void)bytes;
}

inline std::size_t SystemPageSize() {
    return 0;
}

inline void *AlignedAlloc(std::size_t size, std::size_t alignment) {
    return malloc(size);
}

inline void AlignedFree(void *block) {
    free(block);
}
#endif

inline std::size_t RoundUp(std::size_t numToRound, std::size_t multiple) {
    if (multiple == 0) {
        return numToRound;
    }

    std::size_t remainder = numToRound % multiple;
    if (remainder == 0) {
        return numToRound;
    }

    return numToRound + multiple - remainder;
}
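
// For example: RoundUp(5000, 4096) == 8192 and RoundUp(8192, 4096) == 8192, while RoundUp(n, 0) == n,
// which is why a zero page size (guard pages disabled) leaves the requested stack size untouched.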

} // End of namespace FiberTaskingLib

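// ------------------------------------------------------------------------------------------------------
// task_scheduler.cpp (inferred file boundary, see the note above)
// ------------------------------------------------------------------------------------------------------
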
#include "fiber_tasking_lib/task_scheduler.h"


namespace FiberTaskingLib {

TaskScheduler::TaskScheduler()
    : m_numThreads(0),
      m_fiberPoolSize(0),
      m_fibers(nullptr),
      m_freeFibers(nullptr),
      m_waitingFibers(nullptr),
      m_tls(nullptr) {
}

TaskScheduler::~TaskScheduler() {
    delete[] m_fibers;
    delete[] m_freeFibers;
    delete[] m_waitingFibers;
    delete[] m_tls;
}

void TaskScheduler::Run(uint fiberPoolSize, TaskFunction mainTask, void *mainTaskArg, uint threadPoolSize) {
    // Create and populate the fiber pool
    m_fiberPoolSize = fiberPoolSize;
    m_fibers = new Fiber[fiberPoolSize];
    m_freeFibers = new std::atomic<bool>[fiberPoolSize];
    m_waitingFibers = new std::atomic<bool>[fiberPoolSize];

    for (uint i = 0; i < fiberPoolSize; ++i) {
        m_fibers[i] = Fiber(512000, FiberStart, this);
        m_freeFibers[i].store(true, std::memory_order_release);
        m_waitingFibers[i].store(false, std::memory_order_release);
    }
    m_waitingBundles.resize(fiberPoolSize);

    if (threadPoolSize == 0) {
        // 1 thread for each logical processor
        m_numThreads = GetNumHardwareThreads();
    } else {
        m_numThreads = threadPoolSize;
    }

    // Initialize all the things
    m_quit.store(false, std::memory_order_release);
    m_threads.resize(m_numThreads);
    m_tls = new ThreadLocalStorage[m_numThreads];

    // Set the properties for the current thread
    SetCurrentThreadAffinity(1);
    m_threads[0] = GetCurrentThread();

    // Create the remaining threads
    for (uint i = 1; i < m_numThreads; ++i) {
        ThreadStartArgs *threadArgs = new ThreadStartArgs();
        threadArgs->taskScheduler = this;
        threadArgs->threadIndex = i;

        if (!CreateThread(524288, ThreadStart, threadArgs, i, &m_threads[i])) {
            printf("Error: Failed to create all the worker threads\n");
            return;
        }
    }


    // Start the main task

    // Get a free fiber
    std::size_t freeFiberIndex = GetNextFreeFiberIndex();
    Fiber *freeFiber = &m_fibers[freeFiberIndex];

    // Repurpose it as the main task fiber and switch to it
    MainFiberStartArgs mainFiberArgs;
    mainFiberArgs.taskScheduler = this;
    mainFiberArgs.MainTask = mainTask;
    mainFiberArgs.Arg = mainTaskArg;

    freeFiber->Reset(MainFiberStart, &mainFiberArgs);
    m_tls[0].CurrentFiberIndex = freeFiberIndex;
    m_tls[0].ThreadFiber.SwitchToFiber(freeFiber);


    // And we're back
    // Wait for the worker threads to finish
    for (std::size_t i = 1; i < m_numThreads; ++i) {
        JoinThread(m_threads[i]);
    }

    return;
}

std::shared_ptr<std::atomic_uint> TaskScheduler::AddTask(Task task) {
    std::shared_ptr<std::atomic_uint> counter(new std::atomic_uint());
    counter->store(1);

    TaskBundle bundle = {task, counter};
    ThreadLocalStorage &tls = m_tls[GetCurrentThreadIndex()];
    tls.TaskQueue.Push(bundle);

    return counter;
}

std::shared_ptr<std::atomic_uint> TaskScheduler::AddTasks(uint numTasks, Task *tasks) {
    std::shared_ptr<std::atomic_uint> counter(new std::atomic_uint());
    counter->store(numTasks);

    ThreadLocalStorage &tls = m_tls[GetCurrentThreadIndex()];
    for (uint i = 0; i < numTasks; ++i) {
        TaskBundle bundle = {tasks[i], counter};
        tls.TaskQueue.Push(bundle);
    }

    return counter;
}

std::size_t TaskScheduler::GetCurrentThreadIndex() {
#if defined(FTL_WIN32_THREADS)
    DWORD threadId = GetCurrentThreadId();
    for (std::size_t i = 0; i < m_numThreads; ++i) {
        if (m_threads[i].Id == threadId) {
            return i;
        }
    }
#elif defined(FTL_POSIX_THREADS)
    pthread_t currentThread = pthread_self();
    for (std::size_t i = 0; i < m_numThreads; ++i) {
        if (pthread_equal(currentThread, m_threads[i])) {
            return i;
        }
    }
#endif

    return FTL_INVALID_INDEX;
}

bool TaskScheduler::GetNextTask(TaskBundle *nextTask) {
    std::size_t currentThreadIndex = GetCurrentThreadIndex();
    ThreadLocalStorage &tls = m_tls[currentThreadIndex];

    // Try to pop from our own queue
    if (tls.TaskQueue.Pop(nextTask)) {
        return true;
    }

    // Ours is empty, try to steal from the others'
    std::size_t threadIndex = tls.LastSuccessfulSteal;
    for (std::size_t i = 0; i < m_numThreads; ++i) {
        const std::size_t threadIndexToStealFrom = (threadIndex + i) % m_numThreads;
        if (threadIndexToStealFrom == currentThreadIndex) {
            continue;
        }
        ThreadLocalStorage &otherTLS = m_tls[threadIndexToStealFrom];
        if (otherTLS.TaskQueue.Steal(nextTask)) {
            tls.LastSuccessfulSteal = i;
            return true;
        }
    }

    return false;
}

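// A quick note on the pattern used below (and again in FiberStart): each flag is checked with a cheap relaxed
// load first, then re-checked with an acquire load, and only then claimed with a compare-exchange from
// true -> false. The CAS is what actually takes ownership of the slot; the two loads just avoid hammering
// the cache line with CAS attempts on slots that are obviously unavailable.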
std::size_t TaskScheduler::GetNextFreeFiberIndex() {
    for (uint j = 0; ; ++j) {
        for (std::size_t i = 0; i < m_fiberPoolSize; ++i) {
            // Double lock
            if (!m_freeFibers[i].load(std::memory_order_relaxed)) {
                continue;
            }

            if (!m_freeFibers[i].load(std::memory_order_acquire)) {
                continue;
            }

            bool expected = true;
            if (std::atomic_compare_exchange_weak_explicit(&m_freeFibers[i], &expected, false, std::memory_order_release, std::memory_order_relaxed)) {
                return i;
            }
        }

        if (j > 10) {
            printf("No free fibers in the pool. Possible deadlock\n");
        }
    }
}

void TaskScheduler::CleanUpOldFiber() {
    // Clean up from the last Fiber to run on this thread
    //
    // Explanation:
    // When switching between fibers, there's the innate problem of tracking the fibers.
    // For example, let's say we discover a waiting fiber that's ready. We need to put the currently
    // running fiber back into the fiber pool, and then switch to the waiting fiber. However, we can't
    // just do the equivalent of:
    //     m_fibers.Push(currentFiber)
    //     currentFiber.SwitchToFiber(waitingFiber)
    // In the time between us adding the current fiber to the fiber pool and switching to the waiting fiber, another
    // thread could come along and pop the current fiber from the fiber pool and try to run it.
    // This leads to stack corruption and/or other undefined behavior.
    //
    // In the previous implementation of TaskScheduler, we used helper fibers to do this work for us.
    // AKA, we stored currentFiber and waitingFiber in TLS, and then switched to the helper fiber. The
    // helper fiber then did:
    //     m_fibers.Push(currentFiber)
    //     helperFiber.SwitchToFiber(waitingFiber)
    // If we have 1 helper fiber per thread, we can guarantee that currentFiber is free to be executed by any thread
    // once it is added back to the fiber pool
    //
    // This solution works well; however, we actually don't need the helper fibers.
    // The code structure guarantees that between any two fiber switches, the code will always end up in WaitForCounter or FiberStart.
    // Therefore, instead of using a helper fiber and immediately pushing the fiber to the fiber pool or waiting list,
    // we defer the push until the next fiber gets to one of those two places
    //
    // Proof:
    // There are only two places where we switch fibers:
    // 1. When we're waiting for a counter, we pull a new fiber from the fiber pool and switch to it.
    // 2. When we find a counter that's ready, we put the current fiber back in the fiber pool, and switch to the waiting fiber.
    //
    // Case 1:
    // A fiber from the pool will always either be completely new or have just come back from switching to a waiting fiber
    // The while and the if/else in FiberStart will guarantee the code will call CleanUpOldFiber() before executing any other fiber.
    // QED
    //
    // Case 2:
    // A waiting fiber can do two things:
    //     a. Finish the task and return
    //     b. Wait on another counter
    // In case a, the while loop and if/else will again guarantee the code will call CleanUpOldFiber() before executing any other fiber.
    // In case b, WaitForCounter will directly call CleanUpOldFiber(). Any further code is just a recursion.
    // QED

    // In this specific implementation, the fiber pool and waiting list are flat arrays signaled by atomics
    // So in order to "Push" the fiber to the fiber pool or waiting list, we just set their corresponding atomics to true
    ThreadLocalStorage &tls = m_tls[GetCurrentThreadIndex()];
    switch (tls.OldFiberDestination) {
    case FiberDestination::ToPool:
        m_freeFibers[tls.OldFiberIndex].store(true, std::memory_order_release);
        tls.OldFiberDestination = FiberDestination::None;
        tls.OldFiberIndex = FTL_INVALID_INDEX;
        break;
    case FiberDestination::ToWaiting:
        m_waitingFibers[tls.OldFiberIndex].store(true, std::memory_order_release);
        tls.OldFiberDestination = FiberDestination::None;
        tls.OldFiberIndex = FTL_INVALID_INDEX;
        break;
    case FiberDestination::None:
    default:
        break;
    }
}

void TaskScheduler::WaitForCounter(std::shared_ptr<std::atomic_uint> &counter, uint value) {
    // Fast out
    if (counter->load(std::memory_order_relaxed) == value) {
        return;
    }

    ThreadLocalStorage &tls = m_tls[GetCurrentThreadIndex()];

    // Fill in the WaitingBundle data
    WaitingBundle &bundle = m_waitingBundles[tls.CurrentFiberIndex];
    bundle.Counter = counter.get();
    bundle.TargetValue = value;

    // Get a free fiber
    std::size_t freeFiberIndex = GetNextFreeFiberIndex();

    // Clean up the old fiber
    CleanUpOldFiber();

    // Fill in tls
    tls.OldFiberIndex = tls.CurrentFiberIndex;
    tls.CurrentFiberIndex = freeFiberIndex;
    tls.OldFiberDestination = FiberDestination::ToWaiting;

    // Switch
    m_fibers[tls.OldFiberIndex].SwitchToFiber(&m_fibers[freeFiberIndex]);

    // And we're back
}

} // End of namespace FiberTaskingLib

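// Note: the definitions below continue task_scheduler.cpp. For Run() above to compile, ThreadStartArgs and
// MainFiberStartArgs have to be visible before it, so in the real file they presumably appear earlier; they
// are shown after Run() here for exposition. The namespace is reopened so the member functions below still
// live in FiberTaskingLib.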
namespace FiberTaskingLib {

struct ThreadStartArgs {
    TaskScheduler *taskScheduler;
    uint threadIndex;
};

FTL_THREAD_FUNC_RETURN_TYPE TaskScheduler::ThreadStart(void *arg) {
    ThreadStartArgs *threadArgs = reinterpret_cast<ThreadStartArgs *>(arg);
    TaskScheduler *taskScheduler = threadArgs->taskScheduler;
    uint index = threadArgs->threadIndex;

    // Clean up
    delete threadArgs;


    // Get a free fiber to switch to
    std::size_t freeFiberIndex = taskScheduler->GetNextFreeFiberIndex();

    // Initialize tls
    taskScheduler->m_tls[index].CurrentFiberIndex = freeFiberIndex;
    // Switch
    taskScheduler->m_tls[index].ThreadFiber.SwitchToFiber(&taskScheduler->m_fibers[freeFiberIndex]);


    // And we've returned

    // Cleanup and shutdown
    EndCurrentThread();
    FTL_THREAD_FUNC_END;
}

struct MainFiberStartArgs {
    TaskFunction MainTask;
    void *Arg;
    TaskScheduler *taskScheduler;
};

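// FiberStart is the body that every pooled fiber executes: each pass through the loop first publishes the
// previous fiber (CleanUpOldFiber), then prefers resuming a waiting fiber whose counter has reached its
// target, and otherwise pops or steals a TaskBundle and runs it, decrementing the bundle's counter when the
// task finishes.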
void TaskScheduler::FiberStart(void *arg) {
    TaskScheduler *taskScheduler = reinterpret_cast<TaskScheduler *>(arg);

    while (!taskScheduler->m_quit.load(std::memory_order_acquire)) {
        // Clean up from the last fiber to run on this thread
        taskScheduler->CleanUpOldFiber();

        // Check if any of the waiting fibers are ready
        std::size_t waitingFiberIndex = FTL_INVALID_INDEX;

        for (std::size_t i = 0; i < taskScheduler->m_fiberPoolSize; ++i) {
            // Double lock
            if (!taskScheduler->m_waitingFibers[i].load(std::memory_order_relaxed)) {
                continue;
            }

            if (!taskScheduler->m_waitingFibers[i].load(std::memory_order_acquire)) {
                continue;
            }

            // Found a waiting fiber
            // Test if it's ready
            WaitingBundle *bundle = &taskScheduler->m_waitingBundles[i];
            if (bundle->Counter->load(std::memory_order_relaxed) != bundle->TargetValue) {
                continue;
            }

            bool expected = true;
            if (std::atomic_compare_exchange_weak_explicit(&taskScheduler->m_waitingFibers[i], &expected, false, std::memory_order_release, std::memory_order_relaxed)) {
                waitingFiberIndex = i;
                break;
            }
        }

        if (waitingFiberIndex != FTL_INVALID_INDEX) {
            // Found a waiting task that is ready to continue
            ThreadLocalStorage &tls = taskScheduler->m_tls[taskScheduler->GetCurrentThreadIndex()];

            tls.OldFiberIndex = tls.CurrentFiberIndex;
            tls.CurrentFiberIndex = waitingFiberIndex;
            tls.OldFiberDestination = FiberDestination::ToPool;

            // Switch
            taskScheduler->m_fibers[tls.OldFiberIndex].SwitchToFiber(&taskScheduler->m_fibers[tls.CurrentFiberIndex]);

            // And we're back
        } else {
            // Get a new task from the queue, and execute it
            TaskBundle nextTask;
            if (!taskScheduler->GetNextTask(&nextTask)) {
                // Spin
            } else {
                nextTask.TaskToExecute.Function(taskScheduler, nextTask.TaskToExecute.ArgData);
                nextTask.Counter->fetch_sub(1);
            }
        }
    }


    // Start the quit sequence

    // Switch back to the thread fiber
    ThreadLocalStorage &tls = taskScheduler->m_tls[taskScheduler->GetCurrentThreadIndex()];
    taskScheduler->m_fibers[tls.CurrentFiberIndex].SwitchToFiber(&tls.ThreadFiber);


    // We should never get here
    printf("Error: FiberStart should never return\n");
}

void TaskScheduler::MainFiberStart(void *arg) {
    MainFiberStartArgs *mainFiberArgs = reinterpret_cast<MainFiberStartArgs *>(arg);
    TaskScheduler *taskScheduler = mainFiberArgs->taskScheduler;

    // Call the main task procedure
    mainFiberArgs->MainTask(taskScheduler, mainFiberArgs->Arg);


    // Request that all the threads quit
    taskScheduler->m_quit.store(true, std::memory_order_release);

    // Switch back to the thread fiber
    ThreadLocalStorage &tls = taskScheduler->m_tls[taskScheduler->GetCurrentThreadIndex()];
    taskScheduler->m_fibers[tls.CurrentFiberIndex].SwitchToFiber(&tls.ThreadFiber);


    // We should never get here
    printf("Error: MainFiberStart should never return\n");
}

} // End of namespace FiberTaskingLib