提交 e57898bc 作者: Arnaud Bergeron

Fix some things and remove extraneous dimshuffles.

上级 7b23b23b
......@@ -1982,7 +1982,7 @@ class _RNNSplitParams(DnnBase):
assert w.ndim == 1
layer = as_scalar(layer).astype('int32')
isize = as_tensor_variable(isize).astype('uint64')
assert isize.ndim == 2
assert isize.ndim == 1
typecode = as_scalar(typecode).astype('int32')
_1d = GpuArrayType(w.type.dtype, [False],
context_name=w.type.context_name)
......@@ -2230,7 +2230,7 @@ class GpuDnnRNNOp(DnnBase):
return Apply(self, inputs, outputs)
def grad2(self, inputs, outputs, output_grads):
def L_op(self, inputs, outputs, output_grads):
desc, w, x, hx = inputs[:4]
cx = inputs[4] if len(inputs) == 5 else None
reserve, y, hy = outputs[:3]
......
......@@ -73,14 +73,14 @@ class Layer(object):
class GRU(Layer):
def __init__(self, input_dim, output_dim, input_layer, s0=None, batch_normalize=False, name=""):
def __init__(self, input_dim, output_dim, input_layer, s0=None, name=""):
'''Layers information'''
self.name = name
self.input_dim = input_dim
self.hidden_dim = output_dim
self.output_dim = output_dim
self.input_layer = input_layer
self.X = input_layer.output().dimshuffle(1, 0, 2)
self.X = input_layer.output()
self.s0 = s0
self.params = []
......@@ -127,7 +127,7 @@ class GRU(Layer):
outputs_info=outputs_info
)
self.Y = states.dimshuffle(1, 0, 2)
self.Y = states
def output(self):
return self.Y
......
......@@ -1457,7 +1457,7 @@ def test_dnn_rnn_gru():
numpy.zeros((psize,), dtype=theano.config.floatX))
model = Model()
last_layer = WrapperLayer(X.dimshuffle(1, 0, 2))
last_layer = WrapperLayer(X)
last_dim = input_dim
for i in range(depth):
gru = GRU(last_dim, hidden_dim, last_layer, s0=h0[i, :, :])
......@@ -1478,7 +1478,7 @@ def test_dnn_rnn_gru():
grad_fn = theano.function([X, Y, h0], grad, mode=mode_with_gpu)
return fn, grad_fn
ref_fn, ref_grad_fn = funcs(last_layer.output().dimshuffle((1, 0, 2)),
ref_fn, ref_grad_fn = funcs(last_layer.output(),
model.get_params())
cudnn_fn, cudnn_grad_fn = funcs(rnnb.apply(params_cudnn, X, h0)[0],
[params_cudnn])
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论