提交 d52d1e7d 作者: notoraptor

Wrap Op params for theano.gpuarray.dnn.GpuDnnSoftmaxBase:

- algo (C enum)
- mode (C enum)
- handle (handle_type)

Extend DnnBase.get_params() to take ParamsType into account.
上级 1b65de9a
......@@ -12,7 +12,7 @@ from theano import Op, Apply, tensor, config, Variable
from theano.scalar import as_scalar, constant, Log, get_scalar_type
from theano.tensor import as_tensor_variable
from theano.gradient import DisconnectedType, grad_not_implemented
from theano.gof import Optimizer, local_optimizer, COp
from theano.gof import Optimizer, local_optimizer, COp, ParamsType, CEnumType
from theano.gof.cmodule import GCC_compiler
from theano.gof.type import CDataType, Generic
from theano.compile import optdb
......@@ -234,6 +234,11 @@ class DnnBase(COp):
ptr = ctx.cudnn_handle.value
res = handle_type.make_value(ptr)
ctx.cudnn_handle_param = res
if isinstance(self.params_type, ParamsType):
if not self.params_type.has_type(handle_type):
raise TypeError('DnnBase: params_type must take into account the cuDNN handle type.')
handle_field = self.params_type.get_field(handle_type)
return self.params_type.get_params(self, **{handle_field: ctx.cudnn_handle_param})
return ctx.cudnn_handle_param
def __init__(self, files=None, c_func=None):
......@@ -1504,6 +1509,18 @@ class GpuDnnSoftmaxBase(DnnBase):
"""
__props__ = ('mode', 'algo')
# Neither inputs nor output types properties are used
# neither in dnn_base.c nor in dnn_softmax*.c,
# so we can disable input checking.
check_input = False
params_type = ParamsType(algo=CEnumType(('CUDNN_SOFTMAX_FAST', 'fast'),
('CUDNN_SOFTMAX_LOG', 'log'),
('CUDNN_SOFTMAX_ACCURATE', 'accurate'),
ctype='cudnnSoftmaxAlgorithm_t'),
mode=CEnumType(('CUDNN_SOFTMAX_MODE_INSTANCE', 'instance'),
('CUDNN_SOFTMAX_MODE_CHANNEL', 'channel'),
ctype='cudnnSoftmaxMode_t'),
handle=handle_type)
def __init__(self, algo, mode):
DnnBase.__init__(self, [self.file], self.c_func)
......@@ -1520,21 +1537,6 @@ class GpuDnnSoftmaxBase(DnnBase):
else:
return [shape[1]]
def get_op_params(self):
    """
    Return the C macro definitions that select the cuDNN softmax variant.

    ``self.mode`` and ``self.algo`` are translated into the names of the
    matching cuDNN enum constants.

    Returns
    -------
    list of (str, str)
        ``[("SOFTMAX_MODE", <mode constant>), ("SOFTMAX_ALGO", <algo constant>)]``

    Notes
    -----
    Unrecognised values are not rejected here: any mode other than
    ``'instance'`` maps to the channel constant, and any algo other than
    ``'fast'`` or ``'log'`` maps to the accurate constant.

    """
    if self.mode == 'instance':
        mode = "CUDNN_SOFTMAX_MODE_INSTANCE"
    else:
        # Fallback for every other mode value.
        mode = "CUDNN_SOFTMAX_MODE_CHANNEL"

    # Lookup table for the explicitly handled algorithms; everything else
    # falls back to the accurate implementation.
    algo_constants = {
        'fast': "CUDNN_SOFTMAX_FAST",
        'log': "CUDNN_SOFTMAX_LOG",
    }
    algo = algo_constants.get(self.algo, "CUDNN_SOFTMAX_ACCURATE")

    return [("SOFTMAX_MODE", mode), ("SOFTMAX_ALGO", algo)]
class GpuDnnSoftmax(GpuDnnSoftmaxBase):
......
......@@ -35,7 +35,7 @@ if (APPLY_SPECIFIC(output) != NULL)
int APPLY_SPECIFIC(softmax)(PyGpuArrayObject *x,
PyGpuArrayObject **out,
cudnnHandle_t _handle) {
PARAMS_TYPE* wrapper) {
PyGpuContextObject *c = x->context;
cudnnStatus_t err;
......@@ -83,9 +83,9 @@ int APPLY_SPECIFIC(softmax)(PyGpuArrayObject *x,
cuda_wait((*out)->ga.data, GPUARRAY_CUDA_WAIT_WRITE);
err = cudnnSoftmaxForward(
_handle,
SOFTMAX_ALGO,
SOFTMAX_MODE,
wrapper->handle,
wrapper->algo,
wrapper->mode,
alpha,
APPLY_SPECIFIC(input),
PyGpuArray_DEV_DATA(x),
......
......@@ -46,7 +46,7 @@ if (APPLY_SPECIFIC(dx) != NULL)
int APPLY_SPECIFIC(softmax_grad)(PyGpuArrayObject *dy,
PyGpuArrayObject *sm,
PyGpuArrayObject **dx,
cudnnHandle_t _handle) {
PARAMS_TYPE* wrapper) {
PyGpuContextObject *c = dy->context;
cudnnStatus_t err;
......@@ -97,9 +97,9 @@ int APPLY_SPECIFIC(softmax_grad)(PyGpuArrayObject *dy,
cuda_wait((*dx)->ga.data, GPUARRAY_CUDA_WAIT_WRITE);
err = cudnnSoftmaxBackward(
_handle,
SOFTMAX_ALGO,
SOFTMAX_MODE,
wrapper->handle,
wrapper->algo,
wrapper->mode,
alpha,
APPLY_SPECIFIC(sm),
PyGpuArray_DEV_DATA(sm),
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论