Tranvick

shared_ptr / weak_ptr

Dec 11th, 2014
257
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 7.34 KB | None | 0 0
  #include <algorithm>  // std::swap before C++11
  #include <cstddef>    // size_t, NULL
  #include <exception>  // std::exception
  #include <iostream>
  #include <utility>    // std::swap since C++11
  using std::cout;
  using std::endl;
  4. +
  /// Thrown when a shared_ptr is built from a weak_ptr whose object has
  /// already been destroyed (its strong count has dropped to zero).
  struct bad_weak_ptr : std::exception {};
  6. +
  /// Deletion policy for a single object allocated with scalar `new`.
  template <class T>
  struct DefaultDeletePolicy {
      /// Destroys and deallocates `target`; a null pointer is a no-op.
      static void freeMemory(T * target) { delete target; }
  };
  13. +
  /// Deletion policy for storage allocated with array `new[]`.
  template <class T>
  struct ArrayDeletePolicy {
      /// Destroys every element and deallocates `block`; null is a no-op.
      static void freeMemory(T * block) { delete [] block; }
  };
  20. +
  21. +template<typename T, typename DeletePolicy = DefaultDeletePolicy<T> >
  22. +class weak_ptr;
  23. +
  /// Control block shared by every shared_ptr / weak_ptr that refers to one
  /// managed object.
  template <class T>
  struct Pointer {
      T * pointer;          // the managed object
      size_t refCount;      // strong references (owning shared_ptr instances)
      size_t weakRefCount;  // weak references (observing weak_ptr instances)

      /// Starts with a single strong owner and no weak observers.
      explicit Pointer(T * managed)
          : pointer(managed), refCount(1), weakRefCount(0) {}
  };
  32. +
  33. +
  34. +template <typename T, typename DeletePolicy = DefaultDeletePolicy<T> >
  35. +class shared_ptr {
  36. +
  37. +    friend class weak_ptr<T>;
  38. +
  39. +private:
  40. +
  41. +    Pointer<T> * objPointer;
  42. +
  43. +    void decrementRefCount() {
  44. +        if (objPointer) {
  45. +            Pointer<T> & object = *objPointer;
  46. +            --object.refCount;
  47. +            if (object.refCount == 0) {
  48. +                DeletePolicy::freeMemory(object.pointer);
  49. +                if (object.weakRefCount == 0) {
  50. +                    delete objPointer;
  51. +                }
  52. +            }
  53. +        }
  54. +    }
  55. +
  56. +    void incrementRefCount() {
  57. +        if (objPointer) {
  58. +            ++objPointer->refCount;
  59. +        }
  60. +    }
  61. +
  62. +    void assignPointer(T * pointer) {
  63. +        if (pointer) {
  64. +            objPointer = new Pointer<T>(pointer);
  65. +        } else {
  66. +            objPointer = NULL;
  67. +        }
  68. +    }
  69. +
  70. +public:
  71. +
  72. +    shared_ptr(T * pointer = NULL) {
  73. +        assignPointer(pointer);
  74. +    }
  75. +
  76. +    shared_ptr(const shared_ptr<T> & sharedPointer) {
  77. +        objPointer = sharedPointer.objPointer;
  78. +        incrementRefCount();
  79. +    }
  80. +
  81. +    shared_ptr(const weak_ptr<T> & weakPointer) {
  82. +        objPointer = weakPointer.objPointer;
  83. +        if (objPointer && objPointer -> refCount == 0) {
  84. +            throw bad_weak_ptr();
  85. +        }
  86. +        incrementRefCount();
  87. +    }
  88. +
  89. +    T & operator* () const {
  90. +        return *(objPointer->pointer);
  91. +    }
  92. +
  93. +    T * operator-> () const {
  94. +        return objPointer->pointer;
  95. +    }
  96. +
  97. +    T & operator[] (int index) {
  98. +        return *(objPointer->pointer + index);
  99. +    }
  100. +
  101. +    shared_ptr<T> & operator= (T * pointer) {
  102. +        decrementRefCount();
  103. +        assignPointer(pointer);
  104. +        return *this;
  105. +    }
  106. +
  107. +    shared_ptr<T> & operator= (const shared_ptr<T> & sharedPointer) {
  108. +        if (this != &pointer) {
  109. +            decrementRefCount();
  110. +            objPointer = pointer.objPointer;
  111. +            incrementRefCount();
  112. +        }
  113. +        return *this;
  114. +    }
  115. +
  116. +    bool operator== (const shared_ptr<T> & sharedPointer) const {
  117. +        return objPointer == sharedPointer.objPointer;
  118. +    }
  119. +
  120. +    bool operator!= (const shared_ptr<T> & sharedPointer) const {
  121. +        return !(*this == sharedPointer);
  122. +    }
  123. +
  124. +    operator bool () const {
  125. +        return objPointer && objPointer->pointer;
  126. +    }
  127. +
  128. +    ~shared_ptr() {
  129. +        decrementRefCount();
  130. +    }
  131. +
  132. +    int use_count() const {
  133. +        if (objPointer && objPointer->pointer) {
  134. +            return objPointer->refCount;
  135. +        }
  136. +        return 0;
  137. +    }
  138. +
  139. +    void reset(T * pointer = 0) {
  140. +        decrementRefCount();
  141. +        assignPointer(pointer);
  142. +    }
  143. +
  144. +    void swap(shared_ptr<T> & sharedPointer) {
  145. +        std::swap(objPointer, sharedPointer.objPointer);
  146. +    }
  147. +
  148. +};
  149. +
  150. +template<typename T, typename DeletePolicy>
  151. +class weak_ptr {
  152. +
  153. +    friend class shared_ptr<T>;
  154. +
  155. +private:
  156. +    Pointer<T> * objPointer;
  157. +
  158. +    void incrementWeakRefCount() {
  159. +        if (objPointer) {
  160. +            ++(objPointer->weakRefCount);
  161. +        }
  162. +    }
  163. +
  164. +    void decrementWeakRefCount() {
  165. +        if (objPointer) {
  166. +            Pointer<T> & object = *objPointer;
  167. +            --object.weakRefCount;
  168. +            if (object.refCount == 0 && object.weakRefCount == 0) {
  169. +                delete objPointer;
  170. +            }
  171. +        }
  172. +    }
  173. +
  174. +public:
  175. +
  176. +    weak_ptr() {
  177. +        objPointer = NULL;
  178. +    }
  179. +
  180. +    weak_ptr(const weak_ptr<T> & weakPointer) {
  181. +        objPointer = weakPointer.objPointer;
  182. +        incrementWeakRefCount();
  183. +    }
  184. +
  185. +    weak_ptr(const shared_ptr<T> & sharedPointer) {
  186. +        objPointer = sharedPointer.objPointer;
  187. +        incrementWeakRefCount();
  188. +    }
  189. +
  190. +    ~weak_ptr() {
  191. +        decrementWeakRefCount();
  192. +    }
  193. +
  194. +    int use_count() const {
  195. +        if (objPointer) {
  196. +            return objPointer->refCount;
  197. +        } else {
  198. +            return 0;
  199. +        }
  200. +    }
  201. +
  202. +    bool expired() const {
  203. +        return (use_count() == 0);
  204. +    }
  205. +
  206. +    shared_ptr<T> lock() const {
  207. +        if (objPointer->refCount == 0) {
  208. +            return shared_ptr<T>();
  209. +        } else {
  210. +            return shared_ptr<T>(*this);
  211. +        }
  212. +    }
  213. +
  214. +    weak_ptr<T> & operator= (const weak_ptr<T> & weakPointer) {
  215. +        if (this != &weakPointer) {
  216. +            decrementWeakRefCount();
  217. +            objPointer = weakPointer.objPointer;
  218. +            incrementWeakRefCount();
  219. +        }
  220. +        return *this;
  221. +    }
  222. +
  223. +   weak_ptr<T> & operator= (const shared_ptr<T> & sharedPointer) {
  224. +        decrementWeakRefCount();
  225. +        objPointer = sharedPointer.objPointer;
  226. +        incrementWeakRefCount();
  227. +        return *this;
  228. +    }
  229. +
  230. +    void reset() {
  231. +        decrementWeakRefCount();
  232. +        objPointer = NULL;
  233. +    }
  234. +
  235. +    void swap(weak_ptr<T> & weakPointer) {
  236. +        std::swap(objPointer, weakPointer.objPointer);
  237. +    }
  238. +
  239. +};
  240. +
  /// Small demo payload used to exercise shared_ptr::operator->.
  struct MyStruct {
      /// Writes a greeting line to standard output (flushed via endl).
      void sayHello() {
          static const char greeting[] = "Hello, world!";
          cout << greeting << endl;
      }
  };
  246. +
  247. +template<typename T>
  248. +void test(shared_ptr<T> sharedPointer) {
  249. +    cout << sharedPointer.use_count() << endl;
  250. +}
  251. +
  252. +weak_ptr<int> weakPointer;
  253. +
  254. +void weakTest() {
  255. +    cout << "use count: " << weakPointer.use_count() << endl;
  256. +    shared_ptr<int> shared = weakPointer.lock();
  257. +    if (shared) {
  258. +        cout << *shared << endl;
  259. +    } else {
  260. +        cout << "Expired!\n";
  261. +    }
  262. +}
  263. +
  264. +int main() {
  265. +    shared_ptr<int> intPtr = new int (3);
  266. +    shared_ptr<int> intPtrCopy = intPtr;
  267. +    cout << *intPtrCopy << endl;
  268. +    *intPtr = 10;
  269. +    cout << *intPtrCopy << endl << endl;
  270. +    shared_ptr<int, ArrayDeletePolicy<int> > arrayPtr(new int [10]);
  271. +    for (int i = 0; i < 10; ++i) {
  272. +        arrayPtr[i] = i;
  273. +    }
  274. +    cout << arrayPtr[3] << " " << arrayPtr[5] << endl << endl;
  275. +    cout << (bool)arrayPtr << " " << bool(shared_ptr<int>()) << endl;
  276. +    cout << (intPtr == intPtrCopy) << " " << (intPtr != intPtrCopy) << endl << endl;
  277. +
  278. +    shared_ptr<MyStruct> myPointer(new MyStruct);
  279. +    myPointer -> sayHello();
  280. +
  281. +    cout << intPtrCopy.use_count() << endl;
  282. +    test(intPtrCopy);
  283. +    cout << intPtrCopy.use_count() << endl;
  284. +    intPtr.reset();
  285. +    cout << intPtrCopy.use_count() << endl << endl;
  286. +
  287. +    {
  288. +        shared_ptr<int> sharedPointer = new int(42);
  289. +        weakPointer = sharedPointer;
  290. +        weakTest();
  291. +    }
  292. +    weakTest();
  293. +
  294. +    shared_ptr<int> a, b;
  295. +    a = new int (20);
  296. +    b = new int (10);
  297. +    cout << "\nBefore swap: " << *a << " " << *b << endl;
  298. +    a.swap(b);
  299. +    cout << " After swap: " << *a << " " << *b << endl;
  300. +
  301. +    return 0;
  302. +}
Advertisement
Add Comment
Please, Sign In to add comment