Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include "fiber_tasking_lib/task_scheduler.h"

#include <cassert>
#include <vector>
- struct NumberSubset {
- uint64 start;
- uint64 end;
- uint64 total;
- };
- FTL_TASK_ENTRY_POINT(AddNumberSubset) {
- NumberSubset *subset = reinterpret_cast<NumberSubset *>(arg);
- subset->total = 0;
- while (subset->start != subset->end) {
- subset->total += subset->start;
- ++subset->start;
- }
- subset->total += subset->end;
- }
- /**
- * Calculates the value of a triangle number by dividing the additions up into tasks
- *
- * A triangle number is defined as:
- * Tn = 1 + 2 + 3 + ... + n
- *
- * The code is checked against the numerical solution which is:
- * Tn = n * (n + 1) / 2
- */
- FTL_TASK_ENTRY_POINT(MainTask) {
- // Define the constants to test
- const uint64 triangleNum = 47593243ull;
- const uint64 numAdditionsPerTask = 10000ull;
- const uint64 numTasks = (triangleNum + numAdditionsPerTask - 1ull) / numAdditionsPerTask;
- // Create the tasks
- FiberTaskingLib::Task *tasks = new FiberTaskingLib::Task[numTasks];
- NumberSubset *subsets = new NumberSubset[numTasks];
- uint64 nextNumber = 1ull;
- for (uint64 i = 0ull; i < numTasks; ++i) {
- NumberSubset *subset = &subsets[i];
- subset->start = nextNumber;
- subset->end = nextNumber + numAdditionsPerTask - 1ull;
- if (subset->end > triangleNum) {
- subset->end = triangleNum;
- }
- tasks[i] = {AddNumberSubset, subset};
- nextNumber = subset->end + 1;
- }
- // Schedule the tasks and wait for them to complete
- std::shared_ptr<std::atomic_uint> counter = taskScheduler->AddTasks(numTasks, tasks);
- delete[] tasks;
- taskScheduler->WaitForCounter(counter, 0);
- // Add the results
- uint64 result = 0ull;
- for (uint64 i = 0; i < numTasks; ++i) {
- result += subsets[i].total;
- }
- // Test
- assert(triangleNum * (triangleNum + 1ull) / 2ull == result);
- // Cleanup
- delete[] subsets;
- }
- int main(int argc, char *argv) {
- FiberTaskingLib::TaskScheduler taskScheduler;
- taskScheduler.Run(25, MainTask);
- return 0;
- }
- #pragma once
- #include "fiber_tasking_lib/typedefs.h"
- #include "fiber_tasking_lib/thread_abstraction.h"
- #include "fiber_tasking_lib/fiber.h"
- #include "fiber_tasking_lib/task.h"
- #include "fiber_tasking_lib/wait_free_queue.h"
- #include <atomic>
- #include <vector>
- #include <climits>
- #include <memory>
- namespace FiberTaskingLib {
- /**
- * A class that enables task-based multithreading.
- *
- * Underneath the covers, it uses fibers to allow cores to work on other tasks
- * when the current task is waiting on a synchronization atomic
- */
- class TaskScheduler {
- public:
- TaskScheduler();
- ~TaskScheduler();
- private:
- enum {
- FTL_INVALID_INDEX = UINT_MAX
- };
- std::size_t m_numThreads;
- std::vector<ThreadType> m_threads;
- std::size_t m_fiberPoolSize;
- /* The backing storage for the fiber pool */
- Fiber *m_fibers;
- /**
- * An array of atomics, which signify if a fiber is available to be used. The indices of m_waitingFibers
- * correspond 1 to 1 with m_fibers. So, if m_freeFibers[i] == true, then m_fibers[i] can be used.
- * Each atomic acts as a lock to ensure that threads do not try to use the same fiber at the same time
- */
- std::atomic<bool> *m_freeFibers;
- /**
- * An array of atomic, which signify if a fiber is waiting for a counter. The indices of m_waitingFibers
- * correspond 1 to 1 with m_fibers. So, if m_waitingFibers[i] == true, then m_fibers[i] is waiting for a counter
- */
- std::atomic<bool> *m_waitingFibers;
- /**
- * Holds a Counter that is being waited on. Specifically, until Counter == TargetValue
- */
- struct WaitingBundle {
- std::atomic_uint *Counter;
- uint TargetValue;
- };
- /**
- * An array of WaitingBundles, which correspond 1 to 1 with m_waitingFibers. If m_waitingFiber[i] == true,
- * m_waitingBundles[i] will contain the data for the waiting fiber in m_fibers[i].
- */
- std::vector<WaitingBundle> m_waitingBundles;
- std::atomic_bool m_quit;
- enum class FiberDestination {
- None = 0,
- ToPool = 1,
- ToWaiting = 2,
- };
- /**
- * Holds a task that is ready to to be executed by the worker threads
- * Counter is the counter for the task(group). It will be decremented when the task completes
- */
- struct TaskBundle {
- Task TaskToExecute;
- std::shared_ptr<std::atomic_uint> Counter;
- };
- struct ThreadLocalStorage {
- ThreadLocalStorage()
- : ThreadFiber(),
- CurrentFiberIndex(FTL_INVALID_INDEX),
- OldFiberIndex(FTL_INVALID_INDEX),
- OldFiberDestination(FiberDestination::None),
- TaskQueue(),
- LastSuccessfulSteal(1) {
- }
- /**
- * Boost fibers require that fibers created from threads finish on the same thread where they started
- *
- * To accommodate this, we have save the initial fibers created in each thread, and immediately switch
- * out of them into the general fiber pool. Once the 'mainTask' has finished, we signal all the threads to
- * start quitting. When the receive the signal, they switch back to the ThreadFiber, allowing it to
- * safely clean up.
- */
- Fiber ThreadFiber;
- /* The index of the current fiber in m_fibers */
- std::size_t CurrentFiberIndex;
- /* The index of the previously executed fiber in m_fibers */
- std::size_t OldFiberIndex;
- /* Where OldFiber should be stored when we call CleanUpPoolAndWaiting() */
- FiberDestination OldFiberDestination;
- /* The queue of waiting tasks */
- WaitFreeQueue<TaskBundle> TaskQueue;
- /* The last queue that we successfully stole from. This is an offset index from the current thread index */
- std::size_t LastSuccessfulSteal;
- };
- /**
- * c++ Thread Local Storage is, by definition, static/global. This poses some problems, such as multiple TaskScheduler
- * instances. In addition, with Boost::Context, we have no way of telling the compiler to disable TLS optimizations, so we
- * have to fake TLS anyhow.
- *
- * During initialization of the TaskScheduler, we create one ThreadLocalStorage instance per thread. Threads index into
- * their storage using m_tls[GetCurrentThreadIndex()]
- */
- ThreadLocalStorage *m_tls;
- public:
- /**
- * Initializes the TaskScheduler and then starts executing 'mainTask'
- *
- * NOTE: Run will "block" until 'mainTask' returns. However, it doesn't block in the traditional sense; 'mainTask' is created as a Fiber.
- * Therefore, the current thread will save it's current state, and then switch execution to the the 'mainTask' fiber. When 'mainTask'
- * finishes, the thread will switch back to the saved state, and Run() will return.
- *
- * @param fiberPoolSize The size of the fiber pool. The fiber pool is used to run new tasks when the current task is waiting on a counter
- * @param mainTask The main task to run
- * @param mainTaskArg The argument to pass to 'mainTask'
- * @param threadPoolSize The size of the thread pool to run. 0 corresponds to NumHarewareThreads()
- */
- void Run(uint fiberPoolSize, TaskFunction mainTask, void *mainTaskArg = nullptr, uint threadPoolSize = 0);
- /**
- * Adds a task to the internal queue.
- *
- * @param task The task to queue
- * @return An atomic counter corresponding to this task. Initially it will equal 1. When the task completes, it will be decremented.
- */
- std::shared_ptr<std::atomic_uint> AddTask(Task task);
- /**
- * Adds a group of tasks to the internal queue
- *
- * @param numTasks The number of tasks
- * @param tasks The tasks to queue
- * @return An atomic counter corresponding to the task group as a whole. Initially it will equal numTasks. When each task completes, it will be decremented.
- */
- std::shared_ptr<std::atomic_uint> AddTasks(uint numTasks, Task *tasks);
- /**
- * Yields execution to another task until counter == value
- *
- * @param counter The counter to check
- * @param value The value to wait for
- */
- void WaitForCounter(std::shared_ptr<std::atomic_uint> &counter, uint value);
- private:
- /**
- * Gets the 0-based index of the current thread
- * This is useful for m_tls[GetCurrentThreadIndex()]
- *
- * @return The index of the current thread
- */
- std::size_t GetCurrentThreadIndex();
- /**
- * Pops the next task off the queue into nextTask. If there are no tasks in the
- * the queue, it will return false.
- *
- * @param nextTask If the queue is not empty, will be filled with the next task
- * @return True: Successfully popped a task out of the queue
- */
- bool GetNextTask(TaskBundle *nextTask);
- /**
- * Gets the index of the next available fiber in the pool
- *
- * @return The index of the next available fiber in the pool
- */
- std::size_t GetNextFreeFiberIndex();
- /**
- * If necessary, moves the old fiber to the fiber pool or the waiting list
- * The old fiber is the last fiber to run on the thread before the current fiber
- */
- void CleanUpOldFiber();
- /**
- * The threadProc function for all worker threads
- *
- * @param arg An instance of ThreadStartArgs
- * @return The return status of the thread
- */
- static FTL_THREAD_FUNC_DECL ThreadStart(void *arg);
- /**
- * The fiberProc function that wraps the main fiber procedure given by the user
- *
- * @param arg An instance of TaskScheduler
- */
- static void MainFiberStart(void *arg);
- /**
- * The fiberProc function for all fibers in the fiber pool
- *
- * @param arg An instance of TaskScheduler
- */
- static void FiberStart(void *arg);
- };
- } // End of namespace FiberTaskingLib
- #pragma once
- #include "fiber_tasking_lib/config.h"
- #include <boost_context/fcontext.h>
- #include <cassert>
- #include <cstdlib>
- #include <algorithm>
- #if defined(FTL_VALGRIND)
- #include <valgrind/valgrind.h>
- #endif
- #if defined(FTL_FIBER_STACK_GUARD_PAGES)
- #if defined(FTL_OS_LINUX) || defined(FTL_OS_MAC) || defined(FTL_iOS)
- #include <sys/mman.h>
- #include <unistd.h>
- #elif defined(FTL_OS_WINDOWS)
- #define WIN32_LEAN_AND_MEAN
- #include <Windows.h>
- #endif
- #endif
- namespace FiberTaskingLib {
#if defined(FTL_VALGRIND)
	// Declares the Fiber member that stores the id returned by VALGRIND_STACK_REGISTER
	#define FTL_VALGRIND_ID uint m_stackId
	// Registers the memory range [s, e] as a stack with Valgrind.
	// Must stay on one logical line; a line break without '\' ends the macro early.
	#define FTL_VALGRIND_REGISTER(s, e) m_stackId = VALGRIND_STACK_REGISTER(s, e)
	// Must carry the FTL_ prefix: ~Fiber() invokes FTL_VALGRIND_DEREGISTER()
	#define FTL_VALGRIND_DEREGISTER() VALGRIND_STACK_DEREGISTER(m_stackId)
#else
	// Valgrind support disabled: all hooks expand to nothing
	#define FTL_VALGRIND_ID
	#define FTL_VALGRIND_REGISTER(s, e)
	#define FTL_VALGRIND_DEREGISTER()
#endif
// Platform helpers used by Fiber. Their definitions follow the Fiber class in this header.
inline std::size_t SystemPageSize();
inline std::size_t RoundUp(std::size_t numToRound, std::size_t multiple);
inline void *AlignedAlloc(std::size_t size, std::size_t alignment);
inline void AlignedFree(void *block);
inline void MemoryGuard(void *memory, size_t bytes);
inline void MemoryGuardRelease(void *memory, size_t bytes);

/** Signature of the routine a Fiber starts executing the first time it is switched to */
using FiberStartRoutine = void (*)(void *arg);
- class Fiber {
- public:
- /**
- * Default constructor
- * Nothing is allocated. This can be used as a thread fiber.
- */
- Fiber()
- : m_stack(nullptr),
- m_systemPageSize(0),
- m_stackSize(0),
- m_context(nullptr),
- m_arg(0) {
- }
- /**
- * Allocates a stack and sets it up to start executing 'startRoutine' when first switched to
- *
- * @param stackSize The stack size for the fiber. If guard pages are being used, this will be rounded up to the next multiple of the system page size
- * @param startRoutine The function to run when the fiber first starts
- * @param arg The argument to pass to 'startRoutine'
- */
- Fiber(std::size_t stackSize, FiberStartRoutine startRoutine, void *arg)
- : m_arg(arg) {
- #if defined(FTL_FIBER_STACK_GUARD_PAGES)
- m_systemPageSize = SystemPageSize();
- #else
- m_systemPageSize = 0;
- #endif
- m_stackSize = RoundUp(stackSize, m_systemPageSize);
- // We add a guard page both the top and the bottom of the stack
- m_stack = AlignedAlloc(m_systemPageSize + m_stackSize + m_systemPageSize, m_systemPageSize);
- m_context = boost_context::make_fcontext(static_cast<char *>(m_stack) + m_systemPageSize + stackSize, stackSize, startRoutine);
- FTL_VALGRIND_REGISTER(static_cast<char *>(m_stack) + m_systemPageSize + stackSize, static_cast<char *>(m_stack) + m_systemPageSize);
- #if defined(FTL_FIBER_STACK_GUARD_PAGES)
- MemoryGuard(static_cast<char *>(m_stack), m_systemPageSize);
- MemoryGuard(static_cast<char *>(m_stack) + m_systemPageSize + stackSize, m_systemPageSize);
- #endif
- }
- /**
- * Deleted copy constructor
- * It makes no sense to copy a stack and its corresponding context. Therefore, we explicitly forbid it.
- */
- Fiber(const Fiber &other) = delete;
- /**
- * Move constructor
- * This does a swap() of all the member variables
- *
- * @param other
- *
- * @return
- */
- Fiber(Fiber &&other)
- : Fiber() {
- swap(*this, other);
- }
- /**
- * Move assignment operator
- * This does a swap() of all the member variables
- *
- * @param other The fiber to move
- */
- Fiber &operator=(Fiber &&other) {
- swap(*this, other);
- return *this;
- }
- ~Fiber() {
- if (m_stack != nullptr) {
- if (m_systemPageSize != 0) {
- MemoryGuardRelease(static_cast<char *>(m_stack), m_systemPageSize);
- MemoryGuardRelease(static_cast<char *>(m_stack) + m_systemPageSize + m_stackSize, m_systemPageSize);
- }
- FTL_VALGRIND_DEREGISTER();
- AlignedFree(m_stack);
- }
- }
- private:
- void *m_stack;
- std::size_t m_systemPageSize;
- std::size_t m_stackSize;
- boost_context::fcontext_t m_context;
- void *m_arg;
- FTL_VALGRIND_ID;
- public:
- /**
- * Saves the current stack context and then switches to the given fiber
- * Execution will resume here once another fiber switches to this fiber
- *
- * @param fiber The fiber to switch to
- */
- void SwitchToFiber(Fiber *fiber) {
- boost_context::jump_fcontext(&m_context, fiber->m_context, fiber->m_arg);
- }
- /**
- * Re-initializes the stack with a new startRoutine and arg
- *
- * NOTE: This can NOT be called on a fiber that has m_stack == nullptr || m_stackSize == 0
- * AKA, a default constructed fiber.
- *
- * @param startRoutine The function to run when the fiber is next switched to
- * @param arg The arg for 'startRoutine'
- *
- * @return
- */
- void Reset(FiberStartRoutine startRoutine, void *arg) {
- m_context = boost_context::make_fcontext(static_cast<char *>(m_stack) + m_stackSize, m_stackSize, startRoutine);
- m_arg = arg;
- }
- private:
- /**
- * Helper function for the move operators
- * Swaps all the member variables
- *
- * @param first The first fiber
- * @param second The second fiber
- */
- void swap(Fiber &first, Fiber &second) {
- using std::swap;
- swap(first.m_stack, second.m_stack);
- swap(first.m_systemPageSize, second.m_systemPageSize);
- swap(first.m_stackSize, second.m_stackSize);
- swap(first.m_context, second.m_context);
- swap(first.m_arg, second.m_arg);
- }
- };
#if defined(FTL_FIBER_STACK_GUARD_PAGES)
#if defined(FTL_OS_LINUX) || defined(FTL_OS_MAC) || defined(FTL_iOS)
/** Makes [memory, memory + bytes) inaccessible so any access faults (mprotect expects a page-aligned address) */
inline void MemoryGuard(void *memory, size_t bytes) {
	int result = mprotect(memory, bytes, PROT_NONE);
	if (result) {
		// perror appends ": <error description>" on its own
		perror("mprotect failed");
		assert(!result);
	}
}

/** Restores read/write access to a range previously protected with MemoryGuard() */
inline void MemoryGuardRelease(void *memory, size_t bytes) {
	int result = mprotect(memory, bytes, PROT_READ | PROT_WRITE);
	if (result) {
		perror("mprotect failed");
		assert(!result);
	}
}

inline std::size_t SystemPageSize() {
	return static_cast<std::size_t>(getpagesize());
}

/** Allocates 'size' bytes aligned to 'alignment'; returns nullptr on failure */
inline void *AlignedAlloc(std::size_t size, std::size_t alignment) {
	void *returnPtr = nullptr;
	// posix_memalign reports errors through its return value and leaves returnPtr unspecified on failure
	if (posix_memalign(&returnPtr, alignment, size) != 0) {
		return nullptr;
	}
	return returnPtr;
}

inline void AlignedFree(void *block) {
	free(block);
}
#elif defined(FTL_OS_WINDOWS)
inline void MemoryGuard(void *memory, size_t bytes) {
	DWORD ignored;
	BOOL result = VirtualProtect(memory, bytes, PAGE_NOACCESS, &ignored);
	assert(result);
	(void)result; // only consumed by assert()
}

inline void MemoryGuardRelease(void *memory, size_t bytes) {
	DWORD ignored;
	BOOL result = VirtualProtect(memory, bytes, PAGE_READWRITE, &ignored);
	assert(result);
	(void)result; // only consumed by assert()
}

inline std::size_t SystemPageSize() {
	SYSTEM_INFO sysInfo;
	GetSystemInfo(&sysInfo);
	return sysInfo.dwPageSize;
}

inline void *AlignedAlloc(std::size_t size, std::size_t alignment) {
	return _aligned_malloc(size, alignment);
}

inline void AlignedFree(void *block) {
	_aligned_free(block);
}
#else
	#error "Need a way to protect memory for this platform".
#endif
#else
// Guard pages disabled: protection is a no-op, the page size is reported as 0 and no alignment is applied
inline void MemoryGuard(void *memory, size_t bytes) {
	(void)memory;
	(void)bytes;
}

inline void MemoryGuardRelease(void *memory, size_t bytes) {
	(void)memory;
	(void)bytes;
}

inline std::size_t SystemPageSize() {
	return 0;
}

inline void *AlignedAlloc(std::size_t size, std::size_t alignment) {
	(void)alignment;
	return malloc(size);
}

inline void AlignedFree(void *block) {
	free(block);
}
#endif
/**
 * Rounds 'numToRound' up to the nearest multiple of 'multiple'.
 * A 'multiple' of 0 means "no rounding" and returns the input unchanged.
 */
inline std::size_t RoundUp(std::size_t numToRound, std::size_t multiple) {
	if (multiple == 0) {
		return numToRound;
	}

	const std::size_t excess = numToRound % multiple;
	return (excess == 0) ? numToRound : numToRound + (multiple - excess);
}
- } // End of namespace FiberTaskingLib
- #include "fiber_tasking_lib/task_scheduler.h"
- namespace FiberTaskingLib {
- TaskScheduler::TaskScheduler()
- : m_numThreads(0),
- m_fiberPoolSize(0),
- m_fibers(nullptr),
- m_freeFibers(nullptr),
- m_waitingFibers(nullptr),
- m_tls(nullptr) {
- }
- TaskScheduler::~TaskScheduler() {
- delete[] m_fibers;
- delete[] m_freeFibers;
- delete[] m_waitingFibers;
- delete[] m_tls;
- }
- void TaskScheduler::Run(uint fiberPoolSize, TaskFunction mainTask, void *mainTaskArg, uint threadPoolSize) {
- // Create and populate the fiber pool
- m_fiberPoolSize = fiberPoolSize;
- m_fibers = new Fiber[fiberPoolSize];
- m_freeFibers = new std::atomic<bool>[fiberPoolSize];
- m_waitingFibers = new std::atomic<bool>[fiberPoolSize];
- for (uint i = 0; i < fiberPoolSize; ++i) {
- m_fibers[i] = std::move(Fiber(512000, FiberStart, this));
- m_freeFibers[i].store(true, std::memory_order_release);
- m_waitingFibers[i].store(false, std::memory_order_release);
- }
- m_waitingBundles.resize(fiberPoolSize);
- if (threadPoolSize == 0) {
- // 1 thread for each logical processor
- m_numThreads = GetNumHardwareThreads();
- } else {
- m_numThreads = threadPoolSize;
- }
- // Initialize all the things
- m_quit.store(false, std::memory_order_release);
- m_threads.resize(m_numThreads);
- m_tls = new ThreadLocalStorage[m_numThreads];
- // Set the properties for the current thread
- SetCurrentThreadAffinity(1);
- m_threads[0] = GetCurrentThread();
- // Create the remaining threads
- for (uint i = 1; i < m_numThreads; ++i) {
- ThreadStartArgs *threadArgs = new ThreadStartArgs();
- threadArgs->taskScheduler = this;
- threadArgs->threadIndex = i;
- if (!CreateThread(524288, ThreadStart, threadArgs, i, &m_threads[i])) {
- printf("Error: Failed to create all the worker threads");
- return;
- }
- }
- // Start the main task
- // Get a free fiber
- std::size_t freeFiberIndex = GetNextFreeFiberIndex();
- Fiber *freeFiber = &m_fibers[freeFiberIndex];
- // Repurpose it as the main task fiber and switch to it
- MainFiberStartArgs mainFiberArgs;
- mainFiberArgs.taskScheduler = this;
- mainFiberArgs.MainTask = mainTask;
- mainFiberArgs.Arg = mainTaskArg;
- freeFiber->Reset(MainFiberStart, &mainFiberArgs);
- m_tls[0].CurrentFiberIndex = freeFiberIndex;
- m_tls[0].ThreadFiber.SwitchToFiber(freeFiber);
- // And we're back
- // Wait for the worker threads to finish
- for (std::size_t i = 1; i < m_numThreads; ++i) {
- JoinThread(m_threads[i]);
- }
- return;
- }
- std::shared_ptr<std::atomic_uint> TaskScheduler::AddTask(Task task) {
- std::shared_ptr<std::atomic_uint> counter(new std::atomic_uint());
- counter->store(1);
- TaskBundle bundle = {task, counter};
- ThreadLocalStorage &tls = m_tls[GetCurrentThreadIndex()];
- tls.TaskQueue.Push(bundle);
- return counter;
- }
- std::shared_ptr<std::atomic_uint> TaskScheduler::AddTasks(uint numTasks, Task *tasks) {
- std::shared_ptr<std::atomic_uint> counter(new std::atomic_uint());
- counter->store(numTasks);
- ThreadLocalStorage &tls = m_tls[GetCurrentThreadIndex()];
- for (uint i = 0; i < numTasks; ++i) {
- TaskBundle bundle = {tasks[i], counter};
- tls.TaskQueue.Push(bundle);
- }
- return counter;
- }
- std::size_t TaskScheduler::GetCurrentThreadIndex() {
- #if defined(FTL_WIN32_THREADS)
- DWORD threadId = GetCurrentThreadId();
- for (std::size_t i = 0; i < m_numThreads; ++i) {
- if (m_threads[i].Id == threadId) {
- return i;
- }
- }
- #elif defined(FTL_POSIX_THREADS)
- pthread_t currentThread = pthread_self();
- for (std::size_t i = 0; i < m_numThreads; ++i) {
- if (pthread_equal(currentThread, m_threads[i])) {
- return i;
- }
- }
- #endif
- return FTL_INVALID_INDEX;
- }
- bool TaskScheduler::GetNextTask(TaskBundle *nextTask) {
- std::size_t currentThreadIndex = GetCurrentThreadIndex();
- ThreadLocalStorage &tls = m_tls[currentThreadIndex];
- // Try to pop from our own queue
- if (tls.TaskQueue.Pop(nextTask)) {
- return true;
- }
- // Ours is empty, try to steal from the others'
- std::size_t threadIndex = tls.LastSuccessfulSteal;
- for (std::size_t i = 0; i < m_numThreads; ++i) {
- const std::size_t threadIndexToStealFrom = (threadIndex + i) % m_numThreads;
- if (threadIndexToStealFrom == currentThreadIndex) {
- continue;
- }
- ThreadLocalStorage &otherTLS = m_tls[threadIndexToStealFrom];
- if (otherTLS.TaskQueue.Steal(nextTask)) {
- tls.LastSuccessfulSteal = i;
- return true;
- }
- }
- return false;
- }
- std::size_t TaskScheduler::GetNextFreeFiberIndex() {
- for (uint j = 0; ; ++j) {
- for (std::size_t i = 0; i < m_fiberPoolSize; ++i) {
- // Double lock
- if (!m_freeFibers[i].load(std::memory_order_relaxed)) {
- continue;
- }
- if (!m_freeFibers[i].load(std::memory_order_acquire)) {
- continue;
- }
- bool expected = true;
- if (std::atomic_compare_exchange_weak_explicit(&m_freeFibers[i], &expected, false, std::memory_order_release, std::memory_order_relaxed)) {
- return i;
- }
- }
- if (j > 10) {
- printf("No free fibers in the pool. Possible deadlock");
- }
- }
- }
- void TaskScheduler::CleanUpOldFiber() {
- // Clean up from the last Fiber to run on this thread
- //
- // Explanation:
- // When switching between fibers, there's the innate problem of tracking the fibers.
- // For example, let's say we discover a waiting fiber that's ready. We need to put the currently
- // running fiber back into the fiber pool, and then switch to the waiting fiber. However, we can't
- // just do the equivalent of:
- // m_fibers.Push(currentFiber)
- // currentFiber.SwitchToFiber(waitingFiber)
- // In the time between us adding the current fiber to the fiber pool and switching to the waiting fiber, another
- // thread could come along and pop the current fiber from the fiber pool and try to run it.
- // This leads to stack corruption and/or other undefined behavior.
- //
- // In the previous implementation of TaskScheduler, we used helper fibers to do this work for us.
- // AKA, we stored currentFiber and waitingFiber in TLS, and then switched to the helper fiber. The
- // helper fiber then did:
- // m_fibers.Push(currentFiber)
- // helperFiber.SwitchToFiber(waitingFiber)
- // If we have 1 helper fiber per thread, we can guarantee that currentFiber is free to be executed by any thread
- // once it is added back to the fiber pool
- //
- // This solution works well, however, we actually don't need the helper fibers
- // The code structure guarantees that between any two fiber switches, the code will always end up in WaitForCounter or FibeStart.
- // Therefore, instead of using a helper fiber and immediately pushing the fiber to the fiber pool or waiting list,
- // we defer the push until the next fiber gets to one of those two places
- //
- // Proof:
- // There are only two places where we switch fibers:
- // 1. When we're waiting for a counter, we pull a new fiber from the fiber pool and switch to it.
- // 2. When we found a counter that's ready, we put the current fiber back in the fiber pool, and switch to the waiting fiber.
- //
- // Case 1:
- // A fiber from the pool will always either be completely new or just come back from switching to a waiting fiber
- // The while and the if/else in FiberStart will guarantee the code will call CleanUpOldFiber() before executing any other fiber.
- // QED
- //
- // Case 2:
- // A waiting fiber can do two things:
- // a. Finish the task and return
- // b. Wait on another counter
- // In case a, the while loop and if/else will again guarantee the code will call CleanUpOldFiber() before executing any other fiber.
- // In case b, WaitOnCounter will directly call CleanUpOldFiber(). Any further code is just a recursion.
- // QED
- // In this specific implementation, the fiber pool and waiting list are flat arrays signaled by atomics
- // So in order to "Push" the fiber to the fiber pool or waiting list, we just set their corresponding atomics to true
- ThreadLocalStorage &tls = m_tls[GetCurrentThreadIndex()];
- switch (tls.OldFiberDestination) {
- case FiberDestination::ToPool:
- m_freeFibers[tls.OldFiberIndex].store(true, std::memory_order_release);
- tls.OldFiberDestination = FiberDestination::None;
- tls.OldFiberIndex = FTL_INVALID_INDEX;
- break;
- case FiberDestination::ToWaiting:
- m_waitingFibers[tls.OldFiberIndex].store(true, std::memory_order_release);
- tls.OldFiberDestination = FiberDestination::None;
- tls.OldFiberIndex = FTL_INVALID_INDEX;
- break;
- case FiberDestination::None:
- default:
- break;
- }
- }
- void TaskScheduler::WaitForCounter(std::shared_ptr<std::atomic_uint> &counter, uint value) {
- // Fast out
- if (counter->load(std::memory_order_relaxed) == value) {
- return;
- }
- ThreadLocalStorage &tls = m_tls[GetCurrentThreadIndex()];
- // Fill in the WaitingBundle data
- WaitingBundle &bundle = m_waitingBundles[tls.CurrentFiberIndex];
- bundle.Counter = counter.get();
- bundle.TargetValue = value;
- // Get a free fiber
- std::size_t freeFiberIndex = GetNextFreeFiberIndex();
- // Clean up the old fiber
- CleanUpOldFiber();
- // Fill in tls
- tls.OldFiberIndex = tls.CurrentFiberIndex;
- tls.CurrentFiberIndex = freeFiberIndex;
- tls.OldFiberDestination = FiberDestination::ToWaiting;
- // Switch
- m_fibers[tls.OldFiberIndex].SwitchToFiber(&m_fibers[freeFiberIndex]);
- // And we're back
- }
- } // End of namespace FiberTaskingLib
- struct ThreadStartArgs {
- TaskScheduler *taskScheduler;
- uint threadIndex;
- };
- FTL_THREAD_FUNC_RETURN_TYPE TaskScheduler::ThreadStart(void *arg) {
- ThreadStartArgs *threadArgs = reinterpret_cast<ThreadStartArgs *>(arg);
- TaskScheduler *taskScheduler = threadArgs->taskScheduler;
- uint index = threadArgs->threadIndex;
- // Clean up
- delete threadArgs;
- // Get a free fiber to switch to
- std::size_t freeFiberIndex = taskScheduler->GetNextFreeFiberIndex();
- // Initialize tls
- taskScheduler->m_tls[index].CurrentFiberIndex = freeFiberIndex;
- // Switch
- taskScheduler->m_tls[index].ThreadFiber.SwitchToFiber(&taskScheduler->m_fibers[freeFiberIndex]);
- // And we've returned
- // Cleanup and shutdown
- EndCurrentThread();
- FTL_THREAD_FUNC_END;
- }
- struct MainFiberStartArgs {
- TaskFunction MainTask;
- void *Arg;
- TaskScheduler *taskScheduler;
- };
- void TaskScheduler::FiberStart(void *arg) {
- TaskScheduler *taskScheduler = reinterpret_cast<TaskScheduler *>(arg);
- while (!taskScheduler->m_quit.load(std::memory_order_acquire)) {
- // Clean up from the last fiber to run on this thread
- taskScheduler->CleanUpOldFiber();
- // Check if any of the waiting tasks are ready
- std::size_t waitingFiberIndex = FTL_INVALID_INDEX;
- for (std::size_t i = 0; i < taskScheduler->m_fiberPoolSize; ++i) {
- // Double lock
- if (!taskScheduler->m_waitingFibers[i].load(std::memory_order_relaxed)) {
- continue;
- }
- if (!taskScheduler->m_waitingFibers[i].load(std::memory_order_acquire)) {
- continue;
- }
- // Found a waiting fiber
- // Test if it's ready
- WaitingBundle *bundle = &taskScheduler->m_waitingBundles[i];
- if (bundle->Counter->load(std::memory_order_relaxed) != bundle->TargetValue) {
- continue;
- }
- bool expected = true;
- if (std::atomic_compare_exchange_weak_explicit(&taskScheduler->m_waitingFibers[i], &expected, false, std::memory_order_release, std::memory_order_relaxed)) {
- waitingFiberIndex = i;
- break;
- }
- }
- if (waitingFiberIndex != FTL_INVALID_INDEX) {
- // Found a waiting task that is ready to continue
- ThreadLocalStorage &tls = taskScheduler->m_tls[taskScheduler->GetCurrentThreadIndex()];
- tls.OldFiberIndex = tls.CurrentFiberIndex;
- tls.CurrentFiberIndex = waitingFiberIndex;
- tls.OldFiberDestination = FiberDestination::ToPool;
- // Switch
- taskScheduler->m_fibers[tls.OldFiberIndex].SwitchToFiber(&taskScheduler->m_fibers[tls.CurrentFiberIndex]);
- // And we're back
- } else {
- // Get a new task from the queue, and execute it
- TaskBundle nextTask;
- if (!taskScheduler->GetNextTask(&nextTask)) {
- // Spin
- } else {
- nextTask.TaskToExecute.Function(taskScheduler, nextTask.TaskToExecute.ArgData);
- nextTask.Counter->fetch_sub(1);
- }
- }
- }
- // Start the quit sequence
- // Switch to the thread fibers
- ThreadLocalStorage &tls = taskScheduler->m_tls[taskScheduler->GetCurrentThreadIndex()];
- taskScheduler->m_fibers[tls.CurrentFiberIndex].SwitchToFiber(&tls.ThreadFiber);
- // We should never get here
- printf("Error: FiberStart should never return");
- }
- void TaskScheduler::MainFiberStart(void *arg) {
- MainFiberStartArgs *mainFiberArgs = reinterpret_cast<MainFiberStartArgs *>(arg);
- TaskScheduler *taskScheduler = mainFiberArgs->taskScheduler;
- // Call the main task procedure
- mainFiberArgs->MainTask(taskScheduler, mainFiberArgs->Arg);
- // Request that all the threads quit
- taskScheduler->m_quit.store(true, std::memory_order_release);
- // Switch to the thread fibers
- ThreadLocalStorage &tls = taskScheduler->m_tls[taskScheduler->GetCurrentThreadIndex()];
- taskScheduler->m_fibers[tls.CurrentFiberIndex].SwitchToFiber(&tls.ThreadFiber);
- // We should never get here
- printf("Error: FiberStart should never return");
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement