提交 4c9f0eaf authored 作者: Frederic's avatar Frederic

Fix test in fast_compile

上级 a14c3e9b
...@@ -1576,12 +1576,13 @@ def test_log_add(): ...@@ -1576,12 +1576,13 @@ def test_log_add():
def test_local_useless_inc_subtensor(): def test_local_useless_inc_subtensor():
x = tensor.matrix('x') x = tensor.matrix('x')
y = tensor.matrix('y') y = tensor.matrix('y')
mode = compile.get_default_mode().including("local_useless_inc_subtensor")
for sub in [slice(None), slice(None, None, -1)]: for sub in [slice(None), slice(None, None, -1)]:
o = tensor.set_subtensor(x[::, sub], y) o = tensor.set_subtensor(x[::, sub], y)
f = theano.function([x, y], o) f = theano.function([x, y], o, mode=mode)
o_shape = tensor.set_subtensor(x[::, sub], o_shape = tensor.set_subtensor(x[::, sub],
tensor.specify_shape(y, x.shape)) tensor.specify_shape(y, x.shape))
f_shape = theano.function([x, y], o_shape) f_shape = theano.function([x, y], o_shape, mode=mode)
# Test with shape info # Test with shape info
topo = f_shape.maker.fgraph.toposort() topo = f_shape.maker.fgraph.toposort()
...@@ -1614,7 +1615,7 @@ def test_local_useless_inc_subtensor(): ...@@ -1614,7 +1615,7 @@ def test_local_useless_inc_subtensor():
tensor.specify_shape(y, sub.shape)) tensor.specify_shape(y, sub.shape))
f_shape = theano.function([x, y], o_shape) f_shape = theano.function([x, y], o_shape)
topo = f_shape.maker.fgraph.toposort() topo = f_shape.maker.fgraph.toposort()
theano.printing.debugprint(f_shape) # theano.printing.debugprint(f_shape)
assert any(isinstance(n.op, tensor.IncSubtensor) for n in topo) assert any(isinstance(n.op, tensor.IncSubtensor) for n in topo)
out = f_shape([[2, 3, 6, 7]], [[8, 9]]) out = f_shape([[2, 3, 6, 7]], [[8, 9]])
assert (out == numpy.asarray([[8, 3, 9, 7]])).all() assert (out == numpy.asarray([[8, 3, 9, 7]])).all()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论