Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
- #include <type_traits>
- #include <cassert>
- #include <functional>
/// Equality-comparable wrapper around a nullary callable.
///
/// Only the two partial specializations below are defined:
///   * Functor<ClassType, Res> for a class type ClassType: an object pointer
///     plus a pointer to a non-const, nullary member function
///     `Res (ClassType::*)()`;
///   * Functor<FuncType, Res> for a function type FuncType: a pointer to a
///     free function of that type.
/// Any other first argument (e.g. `int`) selects this declared-but-undefined
/// primary template, i.e. an incomplete type. Functor only stores and
/// compares the callable; it provides no call operator.
///
/// @tparam Type    class type (member form) or function type (free form).
/// @tparam Res     return type of the wrapped callable (for the free form it
///                 is not checked against FuncType).
/// @tparam Enable  SFINAE hook that selects a specialization; leave defaulted.
template <typename Type, typename Res, typename Enable = void>
struct Functor;

/// Member-function form: binds object `c` to member function `fn`.
///
/// Two member-form functors compare equal only when they have the same Res,
/// one member-function pointer type converts implicitly to the other (same
/// class, or one class is an accessible, unambiguous, non-virtual base of the
/// other), and both the object pointers and the member pointers compare equal
/// after that conversion. Every other pairing, including comparison with a
/// free-function Functor, compares unequal.
///
/// NOTE: the C++ standard leaves `==` on a pointer to a *virtual* member
/// function unspecified, so equality of functors bound to virtual members
/// relies on implementation behavior.
template <typename ClassType, typename Res>
struct Functor<ClassType, Res,
               typename std::enable_if<std::is_class<ClassType>::value>::type> {
    /// @param _c   object the member function is bound to (stored, not owned).
    /// @param _fn  nullary, non-const member function of ClassType.
    Functor(ClassType* _c, Res (ClassType::* _fn)()) : c(_c), fn(_fn) {}

    ClassType* c;            ///< Bound object (non-owning).
    Res (ClassType::*fn)();  ///< Bound member function.

    /// Fallback for any Functor whose type is not comparable with this one:
    /// always unequal.
    template <typename CT2, typename R2>
    bool operator ==(const Functor<CT2, R2>&) const {
        return false;
    }

    /// Same-Res comparison with a Functor of a related class.
    ///
    /// Enabled only when one member-function pointer type converts implicitly
    /// to the other, which is exactly what `rhs.fn == this->fn` below needs.
    /// Gating on std::is_base_of instead would also select this overload for
    /// private, ambiguous or virtual bases. That conversion is ill-formed
    /// there, so compilation would fail instead of falling back to the
    /// overload above. Member-pointer convertibility also guarantees the
    /// object-pointer conversion used in `rhs.c == this->c`.
    template <typename CT2>
    typename std::enable_if<
        std::is_convertible<Res (CT2::*)(), Res (ClassType::*)()>::value ||
            std::is_convertible<Res (ClassType::*)(), Res (CT2::*)()>::value,
        bool>::type
    operator ==(const Functor<CT2, Res>& rhs) const {
        return rhs.c == this->c && rhs.fn == this->fn;
    }

    /// Negation of the best-matching operator== above.
    template <typename CT2, typename R2>
    bool operator !=(const Functor<CT2, R2>& rhs) const {
        return !(*this == rhs);
    }
};

/// Free-function form: wraps a pointer to a function of type FuncType.
///
/// Two free-form functors compare equal only when they have the identical
/// Functor type and store the same function pointer. Comparison with any
/// other Functor type is always unequal.
template <typename FuncType, typename Res>
struct Functor<FuncType, Res,
               typename std::enable_if<std::is_function<FuncType>::value>::type> {
    /// @param _f  function to wrap. Not `explicit`, so a FuncType* converts
    ///            implicitly to this Functor.
    Functor(FuncType* _f) : f(_f) {}

    /// Fallback for a different Functor type: always unequal.
    template <typename FT2, typename R2>
    bool operator ==(const Functor<FT2, R2>&) const {
        return false;
    }

    /// Exact-type comparison. As a non-template it is preferred over the
    /// fallback template for an identical Functor type.
    bool operator ==(const Functor<FuncType, Res>& rhs) const {
        return this->f == rhs.f;
    }

    /// Negation of the best-matching operator== above.
    template <typename FT2, typename R2>
    bool operator !=(const Functor<FT2, R2>& rhs) const {
        return !(*this == rhs);
    }

    FuncType* f;  ///< Wrapped function pointer.
};
- template <typename ClassType, typename Res>
- Functor<ClassType, Res> functor(ClassType* _c, Res (ClassType::*_fn)()) {
- return Functor<ClassType, Res>(_c, _fn);
- }
- template <typename Res>
- Functor<Res (), Res> functor(Res func()) {
- return Functor<Res (), Res>(func);
- }
/// Polymorphic test class; Derive inherits from it.
struct Base1 {
    /// Virtual so that deleting a derived object through a Base1* is
    /// well-defined (it is a polymorphic base class).
    virtual ~Base1() = default;

    virtual void fn1() {}
};
/// Polymorphic test class; Derive inherits from it.
struct Base2 {
    /// Virtual so that deleting a derived object through a Base2* is
    /// well-defined (it is a polymorphic base class).
    virtual ~Base2() = default;

    virtual void fn2() {}
};
/// Standalone polymorphic test class, unrelated to Base1, Base2 and Derive.
/// Its two members differ in return type: fn3 returns void, fn5 returns int.
struct C {
    virtual void fn3() {
        // Intentionally empty.
    }

    /// @return always 0.
    virtual int fn5() {
        return int();
    }
};
- struct Derive: Base1, Base2 {
- virtual void fn4() {}
- };
/// Empty nullary free function used as a Functor test fixture.
void fn6() {
    // Intentionally does nothing.
}
/// Empty nullary free function used as a Functor test fixture.
/// NOTE(review): main expects fn7's address to differ from fn6's, although
/// both bodies are identical. Aggressive linker identical-code folding
/// (e.g. MSVC /OPT:ICF, --icf=all) may merge them.
void fn7() {
    // Intentionally does nothing.
}
/// Nullary free function returning int, used as a Functor test fixture.
/// @return always 0.
int fn8() {
    return {};
}
- int main()
- {
- // Unit Test
- Derive d;
- auto f1 = functor<Base2>(&d, &Derive::fn2); //说是无法推导ClassType的类型囧
- assert(f1 == f1);
- auto f2 = functor(static_cast<Base2 *>(&d), &Base2::fn2);
- assert(f1 == f2);
- auto f3 = functor(static_cast<Base1 *>(&d), &Derive::fn1);
- assert(f1 != f3);
- auto f4 = functor(&d, &Derive::fn4);
- assert(f1 != f4);
- Base2 b2;
- auto f5 = functor(&b2, &Base2::fn2);
- assert(f1 != f5);
- C c;
- auto f6 = functor(&c, &C::fn3);
- assert(f1 != f6);
- auto f7 = functor(&c, &C::fn5);
- assert(f6 != f7);
- auto f8 = functor(::fn6);
- auto f9 = functor(::fn7);
- auto f10 = functor(::fn8);
- assert(f8 == f8);
- assert(f8 != f9);
- assert(f8 != f10);
- assert(f1 != f9);
- }
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement