提交 1f50d5d1 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

coding style improved for the new test following Pascal's suggestions

上级 5acda106
...@@ -1067,14 +1067,9 @@ class T_Scan(unittest.TestCase): ...@@ -1067,14 +1067,9 @@ class T_Scan(unittest.TestCase):
def caching_nsteps_by_scan_op(self): def caching_nsteps_by_scan_op(self):
import theano W = theano.tensor.matrix('weights')
import theano.tensor as T initial = theano.tensor.vector('initial')
import scipy inpt = theano.tensor.matrix('inpt')
W = T.matrix('weights')
initial = T.vector('initial')
inpt = T.matrix('inpt')
def one_step(x_t, h_tm1, W): def one_step(x_t, h_tm1, W):
expr = T.dot(h_tm1, W) + x_t expr = T.dot(h_tm1, W) + x_t
...@@ -1086,30 +1081,29 @@ class T_Scan(unittest.TestCase): ...@@ -1086,30 +1081,29 @@ class T_Scan(unittest.TestCase):
outputs_info=[initial], outputs_info=[initial],
non_sequences=[W]) non_sequences=[W])
floatX = theano.config.floatX
sh = expr.shape[0] sh = expr.shape[0]
init_val = theano.shared( numpy.ones(5, dtype=floatX))
inpt_val = theano.shared( numpy.ones((5,5), dtype=floatX))
shapef = theano.function([W], expr, shapef = theano.function([W], expr,
givens={initial: theano.shared( givens={initial: init_val,
scipy.ones(5, inpt: inpt_val })
dtype=theano.config.floatX)),
inpt: theano.shared(
scipy.ones((5, 5),
dtype=theano.config.floatX))})
# First execution to cache n_steps # First execution to cache n_steps
shapef(scipy.ones((5, 5), dtype=theano.config.floatX)) val0 = numpy.ones((5,5), dtype = floatX)
shapef(val0)
cost = expr.sum() cost = expr.sum()
d_cost_wrt_W = T.grad(cost, [W]) d_cost_wrt_W = T.grad(cost, [W])
init_val = theano.shared( numpy.zeros(5, dtype =floatX))
f = theano.function([W, inpt], d_cost_wrt_W, f = theano.function([W, inpt], d_cost_wrt_W,
givens={initial: theano.shared(scipy.zeros(5))}) givens={initial: init_val})
rval = numpy.asarray([[5187989]*5]*5, dtype = theano.config.floatX) rval = numpy.asarray([[5187989]*5]*5, dtype = floatX)
assert numpy.allclose( f(scipy.ones((5, 5), x = numpy.ones((5,5), dtype = floatX)
dtype=theano.config.floatX) y = numpy.ones((10,5), dtype = floatX)
, scipy.ones((10, 5), t_rval = f( x,y)
dtype=theano.config.floatX)) assert numpy.allclose( t_rval, rval)
,rval)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论