Advertisement
Guest User

Untitled

a guest
Apr 21st, 2018
61
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 6.53 KB | None | 0 0
  1. #pragma once
  2.  
  3. #include <tpcc/stdlike/atomic.hpp>
  4. #include <tpcc/stdlike/condition_variable.hpp>
  5. #include <tpcc/stdlike/mutex.hpp>
  6.  
  7. #include <tpcc/support/compiler.hpp>
  8.  
  9. #include <algorithm>
  10. #include <forward_list>
  11. #include <functional>
  12. #include <shared_mutex>
  13. #include <vector>
  14. #include <utility>
  15.  
  16. #include <iostream>
  17.  
  18. namespace tpcc {
  19.     namespace solutions {
  20.  
  21. ////////////////////////////////////////////////////////////////////////////////
  22.  
  23. // implement writer-priority rwlock
  24.         class ReaderWriterLock {
  25.         public:
  26.             // reader section / shared ownership
  27.  
  28.             void lock_shared() {
  29.                 // to be implemented
  30.             }
  31.  
  32.             void unlock_shared() {
  33.                 // to be implemented
  34.             }
  35.  
  36.             // writer section / exclusive ownership
  37.  
  38.             void lock() {
  39.                 // to be implemented
  40.             }
  41.  
  42.             void unlock() {
  43.                 // to be implemented
  44.             }
  45.  
  46.         private:
  47.             // use mutex / condition_variable from tpcc namespace
  48.         };
  49.  
  50. ////////////////////////////////////////////////////////////////////////////////
  51.  
  52.         template <typename T, class HashFunction = std::hash<T>>
  53.         class StripedHashSet {
  54.         private:
  55.             using RWLock = std::shared_timed_mutex;  // std::shared_timed_mutex
  56.  
  57.             using ReaderLocker = std::shared_lock<RWLock>;
  58.             using WriterLocker = std::unique_lock<RWLock>;
  59. //            using ReaderWriterLock = std::shared_timed_mutex;
  60.  
  61.             using Bucket = std::forward_list<T>;
  62.             using Buckets = std::vector<Bucket>;
  63.  
  64.         public:
  65.             explicit StripedHashSet(const size_t concurrency_level = 4,
  66.                                     const size_t growth_factor = 2,
  67.                                     const double max_load_factor = 0.8)
  68.                     : concurrency_level_(concurrency_level),
  69.                       growth_factor_(growth_factor),
  70.                       max_load_factor_(max_load_factor),
  71.                       hash_func_(),
  72.                       table_size_(0),
  73.                       buckets_(concurrency_level_),
  74.                       rw_locks_(concurrency_level) {
  75.             }
  76.  
  77.             bool Insert(T element) {
  78.                 const auto hash = hash_func_(element);
  79.                 auto lock = LockStripe<WriterLocker>(hash);
  80.                 if (ContainsHelper(element, hash)) {
  81.                     return false;
  82.                 }
  83.                 auto& bucket = GetBucket(hash);
  84.                 bucket.push_front(std::move(element));
  85.                 table_size_.fetch_add(1);
  86.  
  87.                 if(MaxLoadFactorExceeded()) {
  88.                     const auto count = buckets_.size();
  89.                     lock.unlock();
  90.                     TryExpandTable(count);
  91.                 }
  92.                 return true;
  93.             }
  94.  
  95.             bool Remove(const T& element) {
  96.                 const auto hash = hash_func_(element);
  97.                 auto lock = LockStripe<WriterLocker>(hash);
  98.  
  99.                 if (!ContainsHelper(element, hash)) {
  100.                     return false;
  101.                 }
  102.  
  103.                 auto& bucket = GetBucket(hash);
  104.                 bucket.remove(element);
  105.                 table_size_.fetch_sub(1);
  106.                 return true;
  107.             }
  108.  
  109.             bool Contains(const T& element) const {
  110.                 const auto hash = hash_func_(element);
  111.                 LockStripe<ReaderLocker>(hash);
  112.                 return ContainsHelper(element, hash);
  113.             }
  114.  
  115.             size_t GetSize() const {
  116.                 return table_size_.load();
  117.             }
  118.  
  119.             size_t GetBucketCount() const {
  120.                 LockStripe<ReaderLocker>(0);
  121.                 return buckets_.size();
  122.             }
  123.  
  124.         private:
  125.             bool ContainsHelper(const T& element, const size_t hash) const {
  126.                 auto& buck = GetBucket(hash);
  127.                 auto result = std::find(buck.begin(), buck.end(), element);
  128.                 return result != buck.end();
  129.             }
  130.  
  131.             size_t GetStripeIndex(const size_t hash_value) const {
  132.                 return hash_value % concurrency_level_;
  133.             }
  134.  
  135.             template <class Locker>
  136.             Locker LockStripe(const size_t hash_value) const {
  137.                 return Locker(rw_locks_[GetStripeIndex(hash_value)]);
  138.             }
  139.  
  140.             size_t GetBucketIndex(const size_t hash_value) const {
  141.                 return hash_value % buckets_.size();
  142.             }
  143.  
  144.             Bucket& GetBucket(const size_t hash_value) {
  145.                 return buckets_[GetBucketIndex(hash_value)];
  146.             }
  147.  
  148.             const Bucket& GetBucket(const size_t hash_value) const {
  149.                 return buckets_[GetBucketIndex(hash_value)];
  150.             }
  151.  
  152.             bool MaxLoadFactorExceeded() const {
  153.                 return table_size_.load() >= max_load_factor_ * buckets_.size();
  154.             }
  155.  
  156.             void TryExpandTable(const size_t expected_bucket_count) {
  157.                 std::vector<WriterLocker> w_locks;
  158.                 w_locks.emplace_back(LockStripe<WriterLocker>(0));
  159.                 if (expected_bucket_count != buckets_.size()) {
  160.                     return;
  161.                 }
  162.  
  163.                 for (size_t i = 1; i < concurrency_level_; i++) {
  164.                     w_locks.emplace_back(LockStripe<WriterLocker>(i));
  165.                 }
  166.  
  167.                 Buckets new_buckets(buckets_.size() * growth_factor_);
  168.                 for (auto& bucket : buckets_) {
  169.                     for (auto& elem : bucket) {
  170.                         new_buckets[hash_func_(elem) % new_buckets.size()].push_front(std::move(elem));
  171.                     }
  172.                 }
  173.                 std::swap(buckets_, new_buckets);
  174.  
  175. //                Buckets new_buckets(buckets_.size() * growth_factor_);
  176. //                std::swap(buckets_, new_buckets);
  177. //                for (auto& bucket : new_buckets) {
  178. //                    for (auto& elem : bucket) {
  179. //                        buckets_[GetBucketIndex(hash_func_(elem))].push_front(std::move(elem));
  180. //                    }
  181. //                }
  182.             }
  183.  
  184.         private:
  185.             const size_t concurrency_level_;
  186.             const size_t growth_factor_;
  187.             const double max_load_factor_;
  188.             const HashFunction hash_func_;
  189.             tpcc::atomic<size_t> table_size_;
  190.             Buckets buckets_;
  191.             mutable std::vector<RWLock> rw_locks_;
  192.         };
  193.  
  194.     }  // namespace solutions
  195. }  // namespace tpcc
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement