提交 c6f6c6ae authored 作者: Razvan Pascanu's avatar Razvan Pascanu

new test for the bug Justin found

上级 92ed3123
......@@ -1065,6 +1065,52 @@ class T_Scan(unittest.TestCase):
f_scan([1,2,3], numpy.arange(12).reshape([4,3]), 1.)
f_grad([1,2,3], numpy.arange(12).reshape([4,3]), 1.)
def caching_nsteps_by_scan_op(self):
import theano
import theano.tensor as T
import scipy
W = T.matrix('weights')
initial = T.vector('initial')
inpt = T.matrix('inpt')
def one_step(x_t, h_tm1, W):
expr = T.dot(h_tm1, W) + x_t
return expr
expr, _ = theano.scan(
fn=one_step,
sequences=[inpt],
outputs_info=[initial],
non_sequences=[W])
sh = expr.shape[0]
shapef = theano.function([W], expr,
givens={initial: theano.shared(
scipy.ones(5,
dtype=theano.config.floatX)),
inpt: theano.shared(
scipy.ones((5, 5),
dtype=theano.config.floatX))})
# First execution to cache n_steps
shapef(scipy.ones((5, 5), dtype=theano.config.floatX))
cost = expr.sum()
d_cost_wrt_W = T.grad(cost, [W])
f = theano.function([W, inpt], d_cost_wrt_W,
givens={initial: theano.shared(scipy.zeros(5))})
rval = numpy.asarray([[5187989]*5]*5, dtype = theano.config.floatX)
assert numpy.allclose( f(scipy.ones((5, 5),
dtype=theano.config.floatX)
, scipy.ones((10, 5),
dtype=theano.config.floatX))
,rval)
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论