提交 a8358593 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fix pregreedy optimizer + test

上级 2e5fcea2
......@@ -1685,10 +1685,13 @@ def pre_greedy_local_optimizer(list_optimizations, out):
else:
break
return results, optimized_vars
if out.owner:
out_index = out.owner.outputs.index(out)
else:
out_index = 0
final_outs, optimized_nodes = local_recursive_function(
list_optimizations, out, {}, 0)
return final_outs[0]
return final_outs[out_index]
############
......
......@@ -3295,6 +3295,20 @@ class T_Scan(unittest.TestCase):
cost = x.sum()
self.assertRaises(ValueError, tensor.grad, cost, y0)
def test_pregreedy_optimizer(self):
W = tensor.zeros((5, 4))
bv = tensor.zeros((5,))
bh = tensor.zeros((4,))
v = tensor.matrix('v')
(bv_t, bh_t), _ = theano.scan(lambda _: [bv, bh], sequences=v,
outputs_info=[None, None])
chain, _ = theano.scan(
lambda x: tensor.dot(tensor.dot(x, W) + bh_t, W.T) + bv_t,
outputs_info=v,
n_steps=2)
theano.function([v], chain)(numpy.zeros((3, 5)))
def test_savemem_does_not_duplicate_number_of_scan_nodes(self):
var = tensor.ones(())
values, _ = theano.scan(lambda x: ([x], (), theano.scan_module.until(x)),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论