提交 87640d7e authored 作者: Frederic's avatar Frederic 提交者: Arnaud Bergeron

make cudnn softmax support R2

上级 487148a8
...@@ -38,4 +38,11 @@ static inline const int cudnnVersionMacro(){ ...@@ -38,4 +38,11 @@ static inline const int cudnnVersionMacro(){
} }
//some macro to help support cudnn R1 and R2.
#ifdef CUDNN_VERSION
#define cudnnTensor4dDescriptor_t cudnnTensorDescriptor_t
#define cudnnCreateTensor4dDescriptor cudnnCreateTensorDescriptor
#define cudnnDestroyTensor4dDescriptor cudnnDestroyTensorDescriptor
#endif
#endif #endif
...@@ -1180,7 +1180,7 @@ if (CudaNdarray_prep_output(&%(outs)s, 4, CudaNdarray_HOST_DIMS(%(ins)s)) != 0) ...@@ -1180,7 +1180,7 @@ if (CudaNdarray_prep_output(&%(outs)s, 4, CudaNdarray_HOST_DIMS(%(ins)s)) != 0)
return result return result
def c_code_cache_version(self): def c_code_cache_version(self):
return (0, 6) return (0, 6, version())
def method(self): def method(self):
raise NotImplementedError('GpuDnnSoftmaxBase::method') raise NotImplementedError('GpuDnnSoftmaxBase::method')
...@@ -1196,6 +1196,7 @@ class GpuDnnSoftmax(GpuDnnSoftmaxBase): ...@@ -1196,6 +1196,7 @@ class GpuDnnSoftmax(GpuDnnSoftmaxBase):
def method(self): def method(self):
return """ return """
#ifndef CUDNN_VERSION
err%(name)s = cudnnSoftmaxForward( err%(name)s = cudnnSoftmaxForward(
_handle, _handle,
algo%(name)s, algo%(name)s,
...@@ -1205,6 +1206,23 @@ err%(name)s = cudnnSoftmaxForward( ...@@ -1205,6 +1206,23 @@ err%(name)s = cudnnSoftmaxForward(
softmax_output_%(name)s, softmax_output_%(name)s,
CudaNdarray_DEV_DATA(%(outs)s) CudaNdarray_DEV_DATA(%(outs)s)
); );
#else
{
const float alpha = 1.;
const float beta = 0.;
err%(name)s = cudnnSoftmaxForward(
_handle,
algo%(id)d,
mode%(id)d,
(void*) &alpha,
softmax_input_%(id)d,
CudaNdarray_DEV_DATA(%(ins)s),
(void*) &beta,
softmax_output_%(id)d,
CudaNdarray_DEV_DATA(%(outs)s)
);
}
#endif
""" """
def grad(self, inp, grads): def grad(self, inp, grads):
...@@ -1230,6 +1248,7 @@ class GpuDnnSoftmaxGrad(GpuDnnSoftmaxBase): ...@@ -1230,6 +1248,7 @@ class GpuDnnSoftmaxGrad(GpuDnnSoftmaxBase):
def method(self): def method(self):
return """ return """
#ifndef CUDNN_VERSION
err%(name)s = cudnnSoftmaxBackward( err%(name)s = cudnnSoftmaxBackward(
_handle, _handle,
algo%(name)s, algo%(name)s,
...@@ -1241,7 +1260,26 @@ err%(name)s = cudnnSoftmaxBackward( ...@@ -1241,7 +1260,26 @@ err%(name)s = cudnnSoftmaxBackward(
softmax_output_%(name)s, softmax_output_%(name)s,
CudaNdarray_DEV_DATA(%(outs)s) CudaNdarray_DEV_DATA(%(outs)s)
); );
""" #else
{
const float alpha = 1.;
const float beta = 0.;
err%(name)s = cudnnSoftmaxBackward(
_handle,
algo%(id)d,
mode%(id)d,
(void*) &alpha,
%(name1)s_%(id)d,
CudaNdarray_DEV_DATA(%(ins1)s),
%(name0)s_%(id)d,
CudaNdarray_DEV_DATA(%(ins0)s),
(void*) &beta,
softmax_output_%(id)d,
CudaNdarray_DEV_DATA(%(outs)s)
);
}
#endif
"""
# Intentation for history # Intentation for history
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论