提交 1d7e3679 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Make scan pushout computation work when you push everything

Note also a test was added for this case.
上级 840176c8
......@@ -276,6 +276,22 @@ class PushOutNonSeqScan(gof.Optimizer):
env.replace_all_validate(zip(node.outputs, nw_node.outputs),
reason = 'scan_push_computation_out')
return True
elif to_keep == []:
# Nothing in the inner graph should be kept
replace_with = {}
for idx, out in enumerate(to_replace):
if out in local_env.outputs:
x = node.outputs[local_env.outputs.index(out)]
y = replace_with_out[idx]
shape = [y.shape[idx] for idx in xrange(y.ndim)]
replace_with[x] = tensor.alloc(y,
node.inputs[0],
*shape)
# We need to add one extra dimension to the outputs
env.replace_all_validate(replace_with.items(),
reason = 'scan_push_computation_out')
else:
return False
......
......@@ -2260,6 +2260,37 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose(vnh0, tnh0, atol = 1e-6)
assert numpy.allclose(vnW , tnW , atol = 1e-6)
def test_pushout_all(self):
W1 = tensor.matrix('W1')
W2 = tensor.matrix('W2')
h0 = tensor.vector('h0')
def lambda_fn(h, W1, W2):
return tensor.dot(h, W1 + W2)
o, _ = theano.scan(lambda_fn,
non_sequences =[h0,W1,W2],
n_steps = 5)
f = theano.function([h0,W1,W2], o, mode= mode_with_opt)
scan_nodes = [x for x in f.maker.env.toposort()
if isinstance(x.op,
theano.scan_module.scan_op.Scan)]
assert len(scan_nodes) == 0
seed = utt.fetch_seed()
rng = numpy.random.RandomState(seed)
floatX = theano.config.floatX
v_h = numpy.array(rng.uniform(size=(2,)), dtype= floatX)
v_W1 = numpy.array(rng.uniform(size=(2,2)), dtype=floatX)
v_W2 = numpy.array(rng.uniform(size=(2,2)), dtype = floatX)
v_out = numpy.dot(v_h, v_W1+ v_W2)
sol = numpy.zeros((5,2))
sol[:,:] = v_out
assert numpy.allclose(sol, f(v_h, v_W1, v_W2))
def test_pushout(self):
W1 = tensor.matrix('W1')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论