提交 21181c87 authored 作者: Frederic's avatar Frederic

Fix tests folling fusion changes.

The 2 subtensors tests change probably due to optimization order changes. The 2->22 is that in the past, the Composite wasn't flattened, so the outer Composite was smaller.
上级 e2ece259
...@@ -1706,8 +1706,8 @@ class test_local_subtensor_lift(unittest.TestCase): ...@@ -1706,8 +1706,8 @@ class test_local_subtensor_lift(unittest.TestCase):
f = function([x, y, z], tensor.exp(x + y + z)[0], mode=mode_opt) f = function([x, y, z], tensor.exp(x + y + z)[0], mode=mode_opt)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert isinstance(prog[1].op, tensor.DimShuffle) assert isinstance(prog[0].op, tensor.DimShuffle)
assert isinstance(prog[0].op, tensor.Subtensor) # first subtensor assert isinstance(prog[1].op, tensor.Subtensor) # first subtensor
assert isinstance(prog[2].op, tensor.Subtensor) # first subtensor assert isinstance(prog[2].op, tensor.Subtensor) # first subtensor
assert isinstance(prog[3].op.scalar_op, theano.scalar. assert isinstance(prog[3].op.scalar_op, theano.scalar.
Composite) # Composite{add,add} Composite) # Composite{add,add}
...@@ -1723,8 +1723,8 @@ class test_local_subtensor_lift(unittest.TestCase): ...@@ -1723,8 +1723,8 @@ class test_local_subtensor_lift(unittest.TestCase):
f = function([x, y, z], tensor.exp(x + y + z)[0:2], mode=mode_opt) f = function([x, y, z], tensor.exp(x + y + z)[0:2], mode=mode_opt)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert isinstance(prog[1].op, tensor.DimShuffle) assert isinstance(prog[0].op, tensor.DimShuffle)
assert isinstance(prog[0].op, tensor.Subtensor) # first subtensor assert isinstance(prog[1].op, tensor.Subtensor) # first subtensor
assert isinstance(prog[2].op, tensor.Subtensor) # first subtensor assert isinstance(prog[2].op, tensor.Subtensor) # first subtensor
assert isinstance(prog[3].op.scalar_op, theano.scalar. assert isinstance(prog[3].op.scalar_op, theano.scalar.
Composite) # Composite{add,add} Composite) # Composite{add,add}
...@@ -3432,7 +3432,7 @@ class T_local_erfc(unittest.TestCase): ...@@ -3432,7 +3432,7 @@ class T_local_erfc(unittest.TestCase):
assert len(f.maker.fgraph.apply_nodes) == 1, len(f.maker.fgraph.apply_nodes) assert len(f.maker.fgraph.apply_nodes) == 1, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
assert len(f.maker.fgraph.toposort()[0].fgraph.toposort()[ assert len(f.maker.fgraph.toposort()[0].fgraph.toposort()[
0].op.scalar_op.fgraph.apply_nodes)==2,len(f.maker.fgraph.toposort()[0].fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) 0].op.scalar_op.fgraph.apply_nodes)==22,len(f.maker.fgraph.toposort()[0].fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes)
#TODO: fix this problem #TODO: fix this problem
if theano.config.floatX=="float32" and theano.config.mode in ["DebugMode", "DEBUG_MODE"]: if theano.config.floatX=="float32" and theano.config.mode in ["DebugMode", "DEBUG_MODE"]:
raise KnownFailureTest( raise KnownFailureTest(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论