提交 525c9c84 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron 提交者: --global

Rename c_set_tensor4d to c_set_tensorNd to make it more generic (in the C code).

上级 4eb4a899
......@@ -41,6 +41,20 @@ static inline const char *cudnnGetErrorString(cudnnStatus_t err) {
/* Make cudnnTensorDescriptor_t another name for the 4-d descriptor type, so
 * code written against the generic descriptor name can use the 4-d one.
 * NOTE(review): presumably only meant for cuDNN versions that lack the
 * generic type -- confirm against the enclosing version guard. */
typedef cudnnTensor4dDescriptor_t cudnnTensorDescriptor_t;
/*
 * Compatibility shim: provides the N-d tensor descriptor setter on top of
 * cudnnSetTensor4dDescriptorEx().
 *
 * tensorDesc  descriptor to fill in.
 * dataType    element data type, forwarded unchanged.
 * nbDims      number of dimensions described by dimA/strideA; only 4 is
 *             supported by this shim.
 * dimA        dimension sizes; entries 0..3 are read.
 * strideA     per-dimension strides; entries 0..3 are read.
 *
 * Returns CUDNN_STATUS_NOT_SUPPORTED when nbDims is not 4; otherwise returns
 * the status reported by cudnnSetTensor4dDescriptorEx().
 */
static inline cudnnStatus_t
cudnnSetTensorNdDescriptor(
  cudnnTensorDescriptor_t tensorDesc,
  cudnnDataType_t dataType,
  int nbDims,
  const int dimA[],
  const int strideA[]) {
  /* The 4-d call below reads exactly four entries from each array, so any
   * other rank is rejected up front. */
  if (nbDims != 4) return CUDNN_STATUS_NOT_SUPPORTED;
  return cudnnSetTensor4dDescriptorEx(
    tensorDesc, dataType,
    dimA[0], dimA[1], dimA[2], dimA[3],
    strideA[0], strideA[1], strideA[2], strideA[3]);
}
static inline cudnnStatus_t
cudnnGetConvolution2dForwardOutputDim(
const cudnnConvolutionDescriptor_t convDesc,
......
......@@ -7,17 +7,15 @@ c_set_tensorNd(CudaNdarray *var, int dim, cudnnTensorDescriptor_t desc) {
int strides[dim];
int default_str = 1;
for (int i = 0; i < dim; ++i)
for (int i = dim-1; i >= 0; i--)
{
if (CudaNdarray_HOST_STRIDES(var)[i])
strides[i] = CudaNdarray_HOST_STRIDES(var)[i];
else
{
strides[i] = 1;
for (int j = i + 1; j < dim; ++j)
strides[i] *= CudaNdarray_HOST_DIMS(var)[j];
}
strides[i] = default_str;
default_str *= CudaNdarray_HOST_DIMS(var)[i];
}
cudnnStatus_t err = cudnnSetTensorNdDescriptor(desc, CUDNN_DATA_FLOAT, dim,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论