Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#pragma once

#include <algorithm>
#include <condition_variable>
#include <cstddef>
#include <forward_list>
#include <functional>
#include <iostream>
#include <mutex>
#include <shared_mutex>
#include <utility>
#include <vector>

#include <tpcc/stdlike/atomic.hpp>
#include <tpcc/stdlike/condition_variable.hpp>
#include <tpcc/stdlike/mutex.hpp>
#include <tpcc/support/compiler.hpp>
- namespace tpcc {
- namespace solutions {
- ////////////////////////////////////////////////////////////////////////////////
- // implement writer-priority rwlock
/// Writer-priority reader-writer lock.
///
/// Any number of readers may hold the lock concurrently, or exactly one
/// writer. As soon as a writer is waiting, newly arriving readers block, so a
/// steady stream of readers cannot starve writers.
///
/// The class provides lock()/unlock() and lock_shared()/unlock_shared(), so it
/// can be used with std::unique_lock / std::lock_guard and std::shared_lock.
///
/// Not reentrant: a thread that already holds a shared lock and requests
/// another one while a writer is waiting will block (and deadlock).
class ReaderWriterLock {
 public:
  // reader section / shared ownership

  /// Acquires shared ownership. Blocks while a writer holds the lock OR any
  /// writer is waiting for it (this is what gives writers priority).
  void lock_shared() {
    std::unique_lock<std::mutex> guard(state_mutex_);
    readers_cv_.wait(guard, [this] {
      return !writer_active_ && waiting_writers_ == 0;
    });
    ++active_readers_;
  }

  /// Releases shared ownership. The last reader to leave wakes one waiting
  /// writer, if there is one.
  void unlock_shared() {
    std::unique_lock<std::mutex> guard(state_mutex_);
    --active_readers_;
    if (active_readers_ == 0 && waiting_writers_ > 0) {
      writers_cv_.notify_one();
    }
  }

  // writer section / exclusive ownership

  /// Acquires exclusive ownership. Registers as a waiting writer first, so
  /// readers that arrive afterwards are held back until this writer is done.
  void lock() {
    std::unique_lock<std::mutex> guard(state_mutex_);
    ++waiting_writers_;
    writers_cv_.wait(guard, [this] {
      return !writer_active_ && active_readers_ == 0;
    });
    --waiting_writers_;
    writer_active_ = true;
  }

  /// Releases exclusive ownership. Hands the lock to another waiting writer
  /// if any (writer priority); otherwise wakes all blocked readers.
  void unlock() {
    std::unique_lock<std::mutex> guard(state_mutex_);
    writer_active_ = false;
    if (waiting_writers_ > 0) {
      writers_cv_.notify_one();
    } else {
      readers_cv_.notify_all();
    }
  }

 private:
  // NOTE(review): the task template asks for the tpcc:: mutex /
  // condition_variable (presumably fault-injecting wrappers). std primitives
  // are used here; swap in the tpcc:: ones if they share the std interface.
  std::mutex state_mutex_;              // guards every field below
  std::condition_variable readers_cv_;  // readers wait here
  std::condition_variable writers_cv_;  // writers wait here
  std::size_t active_readers_ = 0;      // readers currently holding the lock
  std::size_t waiting_writers_ = 0;     // writers blocked inside lock()
  bool writer_active_ = false;          // a writer currently holds the lock
};
- ////////////////////////////////////////////////////////////////////////////////
- template <typename T, class HashFunction = std::hash<T>>
- class StripedHashSet {
- private:
- using RWLock = std::shared_timed_mutex; // std::shared_timed_mutex
- using ReaderLocker = std::shared_lock<RWLock>;
- using WriterLocker = std::unique_lock<RWLock>;
- // using ReaderWriterLock = std::shared_timed_mutex;
- using Bucket = std::forward_list<T>;
- using Buckets = std::vector<Bucket>;
- public:
- explicit StripedHashSet(const size_t concurrency_level = 4,
- const size_t growth_factor = 2,
- const double max_load_factor = 0.8)
- : concurrency_level_(concurrency_level),
- growth_factor_(growth_factor),
- max_load_factor_(max_load_factor),
- hash_func_(),
- table_size_(0),
- buckets_(concurrency_level_),
- rw_locks_(concurrency_level) {
- }
- bool Insert(T element) {
- const auto hash = hash_func_(element);
- auto lock = LockStripe<WriterLocker>(hash);
- if (ContainsHelper(element, hash)) {
- return false;
- }
- auto& bucket = GetBucket(hash);
- bucket.push_front(std::move(element));
- table_size_.fetch_add(1);
- if(MaxLoadFactorExceeded()) {
- const auto count = buckets_.size();
- lock.unlock();
- TryExpandTable(count);
- }
- return true;
- }
- bool Remove(const T& element) {
- const auto hash = hash_func_(element);
- auto lock = LockStripe<WriterLocker>(hash);
- if (!ContainsHelper(element, hash)) {
- return false;
- }
- auto& bucket = GetBucket(hash);
- bucket.remove(element);
- table_size_.fetch_sub(1);
- return true;
- }
- bool Contains(const T& element) const {
- const auto hash = hash_func_(element);
- LockStripe<ReaderLocker>(hash);
- return ContainsHelper(element, hash);
- }
- size_t GetSize() const {
- return table_size_.load();
- }
- size_t GetBucketCount() const {
- LockStripe<ReaderLocker>(0);
- return buckets_.size();
- }
- private:
- bool ContainsHelper(const T& element, const size_t hash) const {
- auto& buck = GetBucket(hash);
- auto result = std::find(buck.begin(), buck.end(), element);
- return result != buck.end();
- }
- size_t GetStripeIndex(const size_t hash_value) const {
- return hash_value % concurrency_level_;
- }
- template <class Locker>
- Locker LockStripe(const size_t hash_value) const {
- return Locker(rw_locks_[GetStripeIndex(hash_value)]);
- }
- size_t GetBucketIndex(const size_t hash_value) const {
- return hash_value % buckets_.size();
- }
- Bucket& GetBucket(const size_t hash_value) {
- return buckets_[GetBucketIndex(hash_value)];
- }
- const Bucket& GetBucket(const size_t hash_value) const {
- return buckets_[GetBucketIndex(hash_value)];
- }
- bool MaxLoadFactorExceeded() const {
- return table_size_.load() >= max_load_factor_ * buckets_.size();
- }
- void TryExpandTable(const size_t expected_bucket_count) {
- std::vector<WriterLocker> w_locks;
- w_locks.emplace_back(LockStripe<WriterLocker>(0));
- if (expected_bucket_count != buckets_.size()) {
- return;
- }
- for (size_t i = 1; i < concurrency_level_; i++) {
- w_locks.emplace_back(LockStripe<WriterLocker>(i));
- }
- Buckets new_buckets(buckets_.size() * growth_factor_);
- for (auto& bucket : buckets_) {
- for (auto& elem : bucket) {
- new_buckets[hash_func_(elem) % new_buckets.size()].push_front(std::move(elem));
- }
- }
- std::swap(buckets_, new_buckets);
- // Buckets new_buckets(buckets_.size() * growth_factor_);
- // std::swap(buckets_, new_buckets);
- // for (auto& bucket : new_buckets) {
- // for (auto& elem : bucket) {
- // buckets_[GetBucketIndex(hash_func_(elem))].push_front(std::move(elem));
- // }
- // }
- }
- private:
- const size_t concurrency_level_;
- const size_t growth_factor_;
- const double max_load_factor_;
- const HashFunction hash_func_;
- tpcc::atomic<size_t> table_size_;
- Buckets buckets_;
- mutable std::vector<RWLock> rw_locks_;
- };
- } // namespace solutions
- } // namespace tpcc
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement