提交 01305008 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix grad

上级 e57898bc
......@@ -2249,11 +2249,12 @@ class GpuDnnRNNOp(DnnBase):
dy = as_gpuarray_variable(dhy[-1],
context_name=dhy.type.context_name)
if isinstance(dhy.type, DisconnectedType):
dhy = as_gpuarray_variable(hy.zeros_like(),
context_name=hy.type.context_name)
dhy = None
if dcy and isinstance(dcy.type, DisconnectedType):
dcy = None
dinputs = GpuDnnRNNGradInputs()(
dinputs = GpuDnnRNNGradInputs(rnn_mode=self.rnn_mode,
grad_h=(dhy is not None),
grad_c=(dcy is not None))(
desc, x, y, dy, dhy, dcy, w, hx, cx, reserve, return_list=True)
reserve2, dx, dhx = dinputs[:3]
dw = GpuDnnRNNGradWeights()(
......@@ -2270,28 +2271,55 @@ class GpuDnnRNNOp(DnnBase):
class GpuDnnRNNGradInputs(DnnBase):
__props__ = ()
__props__ = ('rnn_mode', 'grad_c', 'grad_h')
_cop_num_inputs = 10
_cop_num_outputs = 4
def __init__(self):
def __init__(self, rnn_mode, grad_h, grad_c):
DnnBase.__init__(self, ['dnn_rnn_gi.c'], 'dnn_rnn_gi')
self.rnn_mode = rnn_mode
self.grad_h = grad_h
self.grad_c = grad_c
if self.grad_c:
assert self.rnn_mode == 'lstm'
def dnn_context(self, node):
return node.outputs[1].type.context_name
def make_node(self, desc, x, y, dy, dhy, dcy, w, hx, cx, reserve):
# We trust the callers here
xshp = as_scalar(x.shape[2]).astype('uint64')
inputs = [desc, xshp, y, dy, dhy, w, hx, reserve]
inputs = [desc, xshp, y, dy, w, hx, reserve]
outputs = [reserve.type(), x.type(), hx.type()]
if dcy is not None:
inputs.append(dcy)
if self.rnn_mode == 'lstm':
inputs.append(cx)
outputs.append(cx.type())
if self.grad_h:
inputs.append(dhy)
if self.grad_c:
inputs.append(dcy)
return Apply(self, inputs, outputs)
# We have special requirements so this is hooking into COp
def format_c_function_args(self, inp, out):
rinp = inp[:7]
others = inp[7:]
if self.rnn_mode == 'lstm':
rinp.append(others.pop(0))
else:
rinp.append('NULL')
if self.grad_h:
rinp.append(others.pop(0))
else:
rinp.append('NULL')
if self.grad_c:
rinp.append(others.pop(0))
else:
rinp.append('NULL')
assert len(others) == 0
return COp.format_c_function_args(self, rinp, out)
class GpuDnnRNNGradWeights(DnnBase):
__props__ = ()
......
......@@ -2,11 +2,12 @@
int dnn_rnn_gi(cudnnRNNDescriptor_t desc, npy_uint64 xshp,
PyGpuArrayObject *y, PyGpuArrayObject *dy,
PyGpuArrayObject *dhy, PyGpuArrayObject *w,
PyGpuArrayObject *hx, gpudata *reserve, PyGpuArrayObject *dcy,
PyGpuArrayObject *cx, gpudata **oreserve,
PyGpuArrayObject **dx, PyGpuArrayObject **dhx,
PyGpuArrayObject **dcx, cudnnHandle_t _handle) {
PyGpuArrayObject *w, PyGpuArrayObject *hx,
gpudata *reserve, PyGpuArrayObject *cx,
PyGpuArrayObject *dhy, PyGpuArrayObject *dcy,
gpudata **oreserve, PyGpuArrayObject **dx,
PyGpuArrayObject **dhx, PyGpuArrayObject **dcx,
cudnnHandle_t _handle) {
PyGpuContextObject *c = y->context;
cudnnTensorDescriptor_t ydesc = NULL;
cudnnTensorDescriptor_t dhydesc = NULL;
......@@ -72,8 +73,9 @@ int dnn_rnn_gi(cudnnRNNDescriptor_t desc, npy_uint64 xshp,
goto fail;
}
if (c_make_tensorNd(dhy, &dhydesc) != 0)
goto fail;
if (dhy != NULL)
if (c_make_tensorNd(dhy, &dhydesc) != 0)
goto fail;
if (dcy != NULL)
if (c_make_tensorNd(dcy, &dcydesc) != 0)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论