提交 8fccf460 authored 作者: Frederic's avatar Frederic

fix gh-1344. Scan opt error/warning.

上级 c0291c58
......@@ -506,7 +506,7 @@ class PushOutSeqScan(gof.Optimizer):
replace_with[x] = y
# We need to add one extra dimension to the outputs
if replace_with:
if replace_with and len(replace_with) == len(node.outputs):
fgraph.replace_all_validate_remove(
replace_with.items(),
remove=[node],
......
......@@ -2739,6 +2739,26 @@ class T_Scan(unittest.TestCase):
assert len([x for x in scan_node.op.fn.maker.fgraph.toposort()
if isinstance(x.op, theano.tensor.Elemwise)]) == 0
def test_pushout_nomodif(self):
inp = tensor.matrix('inp')
def fn(i, i_tm1):
return i + 10, i_tm1
([i_t, i_tm1], _) = theano.scan(
fn, sequences=[inp],
outputs_info=[numpy.asarray([0.0, 0.0], theano.config.floatX),
None])
f = theano.function([inp], [i_t, i_tm1])
val = numpy.arange(10).reshape(5, 2)
ret = f(val)
utt.assert_allclose(ret[0], val+10)
utt.assert_allclose(ret[1], [[0., 0.],
[10., 11.],
[12., 13.],
[14., 15.],
[16., 17.]])
def test_alloc_inputs1(self):
W1 = tensor.matrix('W1')
W2 = tensor.matrix('W2')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论