提交 285abe3b authored 作者: Amjad Almahairi's avatar Amjad Almahairi

fix test for FAST_COMPILE mode

上级 c6afafdd
...@@ -3138,13 +3138,13 @@ def test_local_fill_useless(): ...@@ -3138,13 +3138,13 @@ def test_local_fill_useless():
def assert_eqs_const(topo, val): def assert_eqs_const(topo, val):
elem = topo[0] elem = topo[0]
assert len(topo) == 1 assert len(topo) == 1, topo
assert elem.op == deep_copy_op assert elem.op == deep_copy_op, elem.op
assert len(elem.inputs) == 1 assert len(elem.inputs) == 1, elem.inputs
assert isinstance(elem.inputs[0], T.TensorConstant) assert isinstance(elem.inputs[0], T.TensorConstant), elem
assert T.extract_constant(elem.inputs[0]) == val assert T.extract_constant(elem.inputs[0]) == val, val
class Test_local_useless_elemwise_comparison(unittest.TestCase): class Test_local_useless_elemwise_comparison(unittest.TestCase):
def test_local_useless_elemwise_comparison(self): def test_local_useless_elemwise_comparison(self):
# TODO: test each case individually. # TODO: test each case individually.
...@@ -3157,23 +3157,75 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3157,23 +3157,75 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
non_sequences=None) non_sequences=None)
Z = X_sum + Y Z = X_sum + Y
theano.printing.debugprint(Z) theano.printing.debugprint(Z)
# here is the output for the debug print:
"""
Elemwise{add,no_inplace} [@A] ''
|for{cpu,scan_fn} [@B] ''
| |Subtensor{int64} [@C] ''
| | |Shape [@D] ''
| | | |Subtensor{int64::} [@E] 'X[0:]'
| | | |X [@F]
| | | |Constant{0} [@G]
| | |Constant{0} [@H]
| |Subtensor{:int64:} [@I] ''
| | |Subtensor{int64::} [@E] 'X[0:]'
| | |ScalarFromTensor [@J] ''
| | |Subtensor{int64} [@C] ''
| |Subtensor{int64} [@C] ''
|Y [@K]
Inner graphs of the scan ops:
for{cpu,scan_fn} [@B] ''
>Sum{acc_dtype=float64} [@L] ''
> |X[t] [@M] -> [@I]
"""
mode = theano.compile.get_default_mode().excluding('fusion') mode = theano.compile.get_default_mode().excluding('fusion')
f = theano.function([X, Y], Z, mode=mode) f = theano.function([X, Y], Z, mode=mode)
theano.printing.debugprint(f, print_type=True) theano.printing.debugprint(f, print_type=True)
# here is the output for the debug print:
"""
Elemwise{Add}[(0, 0)] [@A] <TensorType(float64, vector)> '' 7
|for{cpu,scan_fn} [@B] <TensorType(float64, vector)> '' 6
| |Shape_i{0} [@C] <TensorType(int64, scalar)> '' 0
| | |X [@D] <TensorType(float64, matrix)>
| |Subtensor{int64:int64:int8} [@E] <TensorType(float64, matrix)> '' 5
| | |X [@D] <TensorType(float64, matrix)>
| | |ScalarFromTensor [@F] <int64> '' 4
| | | |Elemwise{switch,no_inplace} [@G] <TensorType(int64, scalar)> '' 3
| | | |Elemwise{le,no_inplace} [@H] <TensorType(int8, scalar)> '' 2
| | | | |Shape_i{0} [@C] <TensorType(int64, scalar)> '' 0
| | | | |TensorConstant{0} [@I] <TensorType(int8, scalar)>
| | | |TensorConstant{0} [@I] <TensorType(int8, scalar)>
| | | |TensorConstant{0} [@J] <TensorType(int64, scalar)>
| | |ScalarFromTensor [@K] <int64> '' 1
| | | |Shape_i{0} [@C] <TensorType(int64, scalar)> '' 0
| | |Constant{1} [@L] <int8>
| |Shape_i{0} [@C] <TensorType(int64, scalar)> '' 0
|Y [@M] <TensorType(float64, vector)>
Inner graphs of the scan ops:
for{cpu,scan_fn} [@B] <TensorType(float64, vector)> ''
>Sum{acc_dtype=float64} [@N] <TensorType(float64, scalar)> ''
> |X[t] [@O] <TensorType(float64, vector)> -> [@E]
"""
def test_inequality_with_self(self): def test_inequality_with_self(self):
x = T.scalar('x', dtype=config.floatX) x = T.scalar('x', dtype=config.floatX)
mode = theano.compile.get_default_mode().including('local_useless_elemwise_comparison')
f = theano.function([x], T.lt(x, x))
f = theano.function([x], T.lt(x, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0) assert_eqs_const(f.maker.fgraph.toposort(), 0)
f = theano.function([x], T.le(x, x)) f = theano.function([x], T.le(x, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 1) assert_eqs_const(f.maker.fgraph.toposort(), 1)
f = theano.function([x], T.gt(x, x)) f = theano.function([x], T.gt(x, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0) assert_eqs_const(f.maker.fgraph.toposort(), 0)
f = theano.function([x], T.ge(x, x)) f = theano.function([x], T.ge(x, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 1) assert_eqs_const(f.maker.fgraph.toposort(), 1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论