提交 15a77a59 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

correcting rop2 test

上级 b0f5fb53
...@@ -2535,10 +2535,10 @@ class T_Scan(unittest.TestCase): ...@@ -2535,10 +2535,10 @@ class T_Scan(unittest.TestCase):
def rnn_fn(_u, _y, _W): def rnn_fn(_u, _y, _W):
srng = theano.tensor.shared_randomstreams.RandomStreams(seed) srng = theano.tensor.shared_randomstreams.RandomStreams(seed)
sl_o = theano.tensor.tanh(theano.tensor.dot(_W, (_u + _y + \ tmp_val = _u + _y + srng.uniform(size=v_h0.shape) *\
srng.uniform(size=v_h0.shape) * numpy.asarray(1e-6, dtype=floatX)
numpy.asarray(1e-6, dtype=floatX)))) sl_o = theano.tensor.tanh(theano.tensor.dot(_W, tmp_val))
return sl_o return sl_o, tmp_val
u = theano.tensor.matrix('U') u = theano.tensor.matrix('U')
h0 = theano.tensor.vector('h0') h0 = theano.tensor.vector('h0')
...@@ -2551,9 +2551,9 @@ class T_Scan(unittest.TestCase): ...@@ -2551,9 +2551,9 @@ class T_Scan(unittest.TestCase):
_W = theano.tensor.specify_shape(W, v_W.shape) _W = theano.tensor.specify_shape(W, v_W.shape)
_W.name = '_W' _W.name = '_W'
o, _ = theano.scan(rnn_fn, [o,_], _ = theano.scan(rnn_fn,
sequences=_u, sequences=_u,
outputs_info=_h0, outputs_info=[_h0, None],
non_sequences=_W, non_sequences=_W,
name='rnn_fn') name='rnn_fn')
o = o[-1] o = o[-1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论