提交 01305008 作者: Arnaud Bergeron

Fix grad

上级 e57898bc
...@@ -2249,11 +2249,12 @@ class GpuDnnRNNOp(DnnBase): ...@@ -2249,11 +2249,12 @@ class GpuDnnRNNOp(DnnBase):
dy = as_gpuarray_variable(dhy[-1], dy = as_gpuarray_variable(dhy[-1],
context_name=dhy.type.context_name) context_name=dhy.type.context_name)
if isinstance(dhy.type, DisconnectedType): if isinstance(dhy.type, DisconnectedType):
dhy = as_gpuarray_variable(hy.zeros_like(), dhy = None
context_name=hy.type.context_name)
if dcy and isinstance(dcy.type, DisconnectedType): if dcy and isinstance(dcy.type, DisconnectedType):
dcy = None dcy = None
dinputs = GpuDnnRNNGradInputs()( dinputs = GpuDnnRNNGradInputs(rnn_mode=self.rnn_mode,
grad_h=(dhy is not None),
grad_c=(dcy is not None))(
desc, x, y, dy, dhy, dcy, w, hx, cx, reserve, return_list=True) desc, x, y, dy, dhy, dcy, w, hx, cx, reserve, return_list=True)
reserve2, dx, dhx = dinputs[:3] reserve2, dx, dhx = dinputs[:3]
dw = GpuDnnRNNGradWeights()( dw = GpuDnnRNNGradWeights()(
...@@ -2270,28 +2271,55 @@ class GpuDnnRNNOp(DnnBase): ...@@ -2270,28 +2271,55 @@ class GpuDnnRNNOp(DnnBase):
class GpuDnnRNNGradInputs(DnnBase): class GpuDnnRNNGradInputs(DnnBase):
__props__ = () __props__ = ('rnn_mode', 'grad_c', 'grad_h')
_cop_num_inputs = 10 _cop_num_inputs = 10
_cop_num_outputs = 4 _cop_num_outputs = 4
def __init__(self): def __init__(self, rnn_mode, grad_h, grad_c):
DnnBase.__init__(self, ['dnn_rnn_gi.c'], 'dnn_rnn_gi') DnnBase.__init__(self, ['dnn_rnn_gi.c'], 'dnn_rnn_gi')
self.rnn_mode = rnn_mode
self.grad_h = grad_h
self.grad_c = grad_c
if self.grad_c:
assert self.rnn_mode == 'lstm'
def dnn_context(self, node): def dnn_context(self, node):
return node.outputs[1].type.context_name return node.outputs[1].type.context_name
def make_node(self, desc, x, y, dy, dhy, dcy, w, hx, cx, reserve): def make_node(self, desc, x, y, dy, dhy, dcy, w, hx, cx, reserve):
# We trust the callers here # We trust the callers here
xshp = as_scalar(x.shape[2]).astype('uint64') xshp = as_scalar(x.shape[2]).astype('uint64')
inputs = [desc, xshp, y, dy, dhy, w, hx, reserve] inputs = [desc, xshp, y, dy, w, hx, reserve]
outputs = [reserve.type(), x.type(), hx.type()] outputs = [reserve.type(), x.type(), hx.type()]
if dcy is not None: if self.rnn_mode == 'lstm':
inputs.append(dcy)
inputs.append(cx) inputs.append(cx)
outputs.append(cx.type()) outputs.append(cx.type())
if self.grad_h:
inputs.append(dhy)
if self.grad_c:
inputs.append(dcy)
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
# We have special requirements so this is hooking into COp
def format_c_function_args(self, inp, out):
rinp = inp[:7]
others = inp[7:]
if self.rnn_mode == 'lstm':
rinp.append(others.pop(0))
else:
rinp.append('NULL')
if self.grad_h:
rinp.append(others.pop(0))
else:
rinp.append('NULL')
if self.grad_c:
rinp.append(others.pop(0))
else:
rinp.append('NULL')
assert len(others) == 0
return COp.format_c_function_args(self, rinp, out)
class GpuDnnRNNGradWeights(DnnBase): class GpuDnnRNNGradWeights(DnnBase):
__props__ = () __props__ = ()
......
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
int dnn_rnn_gi(cudnnRNNDescriptor_t desc, npy_uint64 xshp, int dnn_rnn_gi(cudnnRNNDescriptor_t desc, npy_uint64 xshp,
PyGpuArrayObject *y, PyGpuArrayObject *dy, PyGpuArrayObject *y, PyGpuArrayObject *dy,
PyGpuArrayObject *dhy, PyGpuArrayObject *w, PyGpuArrayObject *w, PyGpuArrayObject *hx,
PyGpuArrayObject *hx, gpudata *reserve, PyGpuArrayObject *dcy, gpudata *reserve, PyGpuArrayObject *cx,
PyGpuArrayObject *cx, gpudata **oreserve, PyGpuArrayObject *dhy, PyGpuArrayObject *dcy,
PyGpuArrayObject **dx, PyGpuArrayObject **dhx, gpudata **oreserve, PyGpuArrayObject **dx,
PyGpuArrayObject **dcx, cudnnHandle_t _handle) { PyGpuArrayObject **dhx, PyGpuArrayObject **dcx,
cudnnHandle_t _handle) {
PyGpuContextObject *c = y->context; PyGpuContextObject *c = y->context;
cudnnTensorDescriptor_t ydesc = NULL; cudnnTensorDescriptor_t ydesc = NULL;
cudnnTensorDescriptor_t dhydesc = NULL; cudnnTensorDescriptor_t dhydesc = NULL;
...@@ -72,6 +73,7 @@ int dnn_rnn_gi(cudnnRNNDescriptor_t desc, npy_uint64 xshp, ...@@ -72,6 +73,7 @@ int dnn_rnn_gi(cudnnRNNDescriptor_t desc, npy_uint64 xshp,
goto fail; goto fail;
} }
if (dhy != NULL)
if (c_make_tensorNd(dhy, &dhydesc) != 0) if (c_make_tensorNd(dhy, &dhydesc) != 0)
goto fail; goto fail;
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论