Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- diff --git a/caffe2/core/common_cudnn.h b/caffe2/core/common_cudnn.h
- index f93f333..5655b33 100644
- --- a/caffe2/core/common_cudnn.h
- +++ b/caffe2/core/common_cudnn.h
- @@ -31,6 +31,12 @@
- static_assert(
- CUDNN_VERSION >= 5000,
- "Caffe2 requires cudnn version 5.0 or above.");
- +
- +#if CUDNN_VERSION < 6000
- +#pragma message "CUDNN version under 6.0 is supported at best effort."
- +#pragma message "We strongly encourage you to move to 6.0 and above."
- +#pragma message "This message is intended to annoy you enough to update."
- +#endif // CUDNN_VERSION < 6000
- #define CUDNN_VERSION_MIN(major, minor, patch) \
- (CUDNN_VERSION >= ((major) * 1000 + (minor) * 100 + (patch)))
- @@ -136,6 +142,7 @@ class cudnnTypeWrapper<float> {
- }
- };
- +#if CUDNN_VERSION_MIN(6, 0, 0)
- template <>
- class cudnnTypeWrapper<int> {
- public:
- @@ -151,6 +158,7 @@ class cudnnTypeWrapper<int> {
- return &v;
- }
- };
- +#endif // CUDNN_VERSION_MIN(6, 0, 0)
- template <>
- class cudnnTypeWrapper<double> {
- diff --git a/caffe2/operators/pool_op_cudnn.cu b/caffe2/operators/pool_op_cudnn.cu
- index 5c18c4a..bfe491d 100644
- --- a/caffe2/operators/pool_op_cudnn.cu
- +++ b/caffe2/operators/pool_op_cudnn.cu
- @@ -134,8 +134,11 @@ class CuDNNPoolOp : public ConvPoolOpBase<CUDAContext> {
- CUDNN_ENFORCE(cudnnCreatePoolingDescriptor(&pooling_desc_));
- // Figure out the pooling descriptor.
- if (operator_def.type().substr(0, 7) == "MaxPool") {
- -#if CUDNN_VERSION_MIN(6,0,0)
- - mode_ = CUDNN_POOLING_MAX_DETERMINISTIC;
- + bool deterministic =
- + OperatorBase::GetSingleArgument<bool>("deterministic", false);
- +#if CUDNN_VERSION_MIN(6, 0, 0)
- + mode_ =
- + deterministic ? CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX;
- #else
- mode_ = CUDNN_POOLING_MAX;
- #endif
- @@ -253,15 +256,17 @@ class CuDNNPoolOp : public ConvPoolOpBase<CUDAContext> {
- }
- }
- // Carry out the pooling computation.
- + const T* Xdata = X.template data<T>();
- + T* Ydata = Y->template mutable_data<T>();
- CUDNN_ENFORCE(cudnnPoolingForward(
- cudnn_wrapper_.inline_cudnn_handle(),
- pooling_desc_,
- cudnnTypeWrapper<T>::kOne(),
- bottom_desc_,
- - X.template data<T>(),
- + Xdata,
- cudnnTypeWrapper<T>::kZero(),
- top_desc_,
- - Y->template mutable_data<T>()));
- + Ydata));
- return true;
- }
- @@ -382,8 +387,12 @@ class CuDNNPoolGradientOp : public ConvPoolOpBase<CUDAContext> {
- dX->mutable_data<float>());
- return true;
- }
- +#if CUDNN_VERSION_MIN(6, 0, 0)
- if (mode_ == CUDNN_POOLING_MAX ||
- mode_ == CUDNN_POOLING_MAX_DETERMINISTIC) {
- +#else
- + if (mode_ == CUDNN_POOLING_MAX) {
- +#endif
- global_maxpool_backward_NCHW<float>
- <<<CAFFE_GET_BLOCKS(dX->size()),
- CAFFE_CUDA_NUM_THREADS,
- @@ -449,19 +458,24 @@ class CuDNNPoolGradientOp : public ConvPoolOpBase<CUDAContext> {
- }
- }
- // Carry out the pooling computation.
- + const T* Xdata = X.template data<T>();
- + const T* Ydata = Y.template data<T>();
- + const T* dYdata = dY.template data<T>();
- + T* dXdata = dX->template mutable_data<T>();
- +
- CUDNN_ENFORCE(cudnnPoolingBackward(
- cudnn_wrapper_.inline_cudnn_handle(),
- pooling_desc_,
- cudnnTypeWrapper<T>::kOne(),
- top_desc_,
- - Y.template data<T>(),
- + Ydata,
- top_desc_,
- - dY.template data<T>(),
- + dYdata,
- bottom_desc_,
- - X.template data<T>(),
- + Xdata,
- cudnnTypeWrapper<T>::kZero(),
- bottom_desc_,
- - dX->template mutable_data<T>()));
- + dXdata));
- return true;
- }
- @@ -493,7 +507,7 @@ class CuDNNPoolGradientOp : public ConvPoolOpBase<CUDAContext> {
- // Input: X, Y, dY
- // Output: dX
- - INPUT_TAGS(IN, OUT, OUT_GRAD);
- + // INPUT_TAGS(IN, OUT, OUT_GRAD);
- };
- namespace {
- diff --git a/caffe2/utils/GpuBitonicSort.cuh b/caffe2/utils/GpuBitonicSort.cuh
- index f52bb50..45cb298 100644
- --- a/caffe2/utils/GpuBitonicSort.cuh
- +++ b/caffe2/utils/GpuBitonicSort.cuh
- @@ -6,6 +6,19 @@
- namespace caffe2 {
- +// Returns true if the integer value v is a power of 2; zero yields false.
- +// Note(jiayq): Windows builds reported an error, see
- +// https://github.com/caffe2/caffe2/issues/997
- +// so under MSVC this is defined as a macro instead of a constexpr template.
- +#ifdef _MSC_VER
- +#define integerIsPowerOf2(v) ((v) && !((v) & ((v) - 1)))
- +#else // _MSC_VER
- +template <typename T>
- +constexpr bool integerIsPowerOf2(T v) {
- + return (v && !(v & (v - 1)));
- +}
- +#endif // _MSC_VER
- +
- /// The maximum in-block bitonic sort we support
- constexpr int kMaxBitonicSortSize = 4096;
- @@ -39,9 +52,9 @@ __device__ inline void bitonicSort(K* keys,
- // Assume the sort is taking place in shared memory
- // static_assert(Power2SortSize * (sizeof(K) + sizeof(V)) < 32768,
- // "sort data too large (>32768 bytes)");
- - static_assert(math::integerIsPowerOf2(Power2SortSize),
- + static_assert(integerIsPowerOf2(Power2SortSize),
- "sort size must be power of 2");
- - static_assert(math::integerIsPowerOf2(ThreadsPerBlock),
- + static_assert(integerIsPowerOf2(ThreadsPerBlock),
- "threads in block must be power of 2");
- // If what we are sorting is too small, then not all threads
- @@ -107,7 +120,7 @@ __device__ inline void warpBitonicSort(K* keys,
- // Smaller sorts should use a warp shuffle sort
- static_assert(Power2SortSize > kWarpSize,
- "sort not large enough");
- - static_assert(math::integerIsPowerOf2(Power2SortSize),
- + static_assert(integerIsPowerOf2(Power2SortSize),
- "sort size must be power of 2");
- static_assert(Power2SortSize <= kMaxBitonicSortSize,
- "sort size <= 4096 only supported");
- diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h
- index 487a77d..3da68cd 100644
- --- a/caffe2/utils/math.h
- +++ b/caffe2/utils/math.h
- @@ -468,19 +468,6 @@ constexpr T roundUp(T a, T b) {
- return divUp<T>(a, b) * b;
- }
- -// Returns true if the given integer type is a power-of-2 (positive only)
- -// Note(jiayq): windows reported an error per
- -// https://github.com/caffe2/caffe2/issues/997
- -// and as a result will make it a macro.
- -#ifdef _MSC_VER
- -#define integerIsPowerOf2(v) ((v) && !((v) & ((v) - 1)))
- -#else // _MSC_VER
- -template <typename T>
- -constexpr bool integerIsPowerOf2(T v) {
- - return (v && !(v & (v - 1)));
- -}
- -#endif // _MSC_VER
- -
- // Returns log2(n) for a positive integer type
- template <typename T>
- constexpr int integerLog2(T n, int p = 0) {
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement