diff --git a/include/caffe/util/cudnn.hpp b/include/caffe/util/cudnn.hpp index a7d8dbb..0fe4f6a 100644 --- a/include/caffe/util/cudnn.hpp +++ b/include/caffe/util/cudnn.hpp @@ -109,8 +109,13 @@ template inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv, cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter, int pad_h, int pad_w, int stride_h, int stride_w) { +#if CUDNN_VERSION_MIN(6, 0, 0) CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv, - pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION)); + pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION, + dataType::type)); +#else + CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv, + pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION)); } template