提交 e8d02ddb authored 作者: Frederic's avatar Frederic

fix tests: changing the optimizer after a mode is created is no longer allowed

上级 6757ec46
...@@ -304,12 +304,11 @@ class test_canonize(unittest.TestCase): ...@@ -304,12 +304,11 @@ class test_canonize(unittest.TestCase):
# We must be sure that the Canonizer is working, but that we don't have other # We must be sure that the Canonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion # optimisation that could hide bug in the Canonizer as local_elemwise_fusion
mode = compile.mode.get_default_mode() mode = compile.mode.get_default_mode()
old_optimizer = mode._optimizer opt = gof.Query(["canonicalize"])
try: opt = opt.excluding('local_elemwise_fusion')
mode._optimizer = gof.Query(["canonicalize"]) mode = mode.__class__(linker=mode.linker, optimizer=opt)
mode._optimizer = mode._optimizer.excluding( for id, [g, sym_inputs, val_inputs,
'local_elemwise_fusion') nb_elemwise, out_dtype] in enumerate(cases):
for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases):
if isinstance(out_dtype, dict): if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy] out_dtype = out_dtype[config.cast_policy]
f = compile.function(list(sym_inputs), g, f = compile.function(list(sym_inputs), g,
...@@ -319,8 +318,6 @@ class test_canonize(unittest.TestCase): ...@@ -319,8 +318,6 @@ class test_canonize(unittest.TestCase):
out = f(*val_inputs) out = f(*val_inputs)
assert(len(f.maker.fgraph.toposort()) == nb_elemwise) assert(len(f.maker.fgraph.toposort()) == nb_elemwise)
assert(out_dtype == out.dtype) assert(out_dtype == out.dtype)
finally:
mode._optimizer = old_optimizer
def test_elemwise_multiple_inputs_optimisation2(self): def test_elemwise_multiple_inputs_optimisation2(self):
""" """
...@@ -455,13 +452,12 @@ class test_canonize(unittest.TestCase): ...@@ -455,13 +452,12 @@ class test_canonize(unittest.TestCase):
# We must be sure that the Canonizer is working, but that we don't have other # We must be sure that the Canonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion # optimisation that could hide bug in the Canonizer as local_elemwise_fusion
mode = compile.mode.get_default_mode() mode = compile.mode.get_default_mode()
old_optimizer = mode._optimizer
try: try:
mode._optimizer = gof.Query(["canonicalize"]) opt = gof.Query(["canonicalize"])
mode._optimizer = mode._optimizer.including('ShapeOpt') opt = opt.including('ShapeOpt')
mode._optimizer = mode._optimizer.excluding( opt = opt.excluding(
'local_elemwise_fusion') 'local_elemwise_fusion')
mode = mode.__class__(linker=mode.linker, optimizer=opt)
# test x / x -> 1 # test x / x -> 1
for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate([(fx/fx, [fx], [fxv], 'float32'), for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate([(fx/fx, [fx], [fxv], 'float32'),
(dx/dx, [dx], [dxv], 'float64'), (dx/dx, [dx], [dxv], 'float64'),
...@@ -644,7 +640,7 @@ class test_canonize(unittest.TestCase): ...@@ -644,7 +640,7 @@ class test_canonize(unittest.TestCase):
assert numpy.allclose(out, numpy.sign(val_inputs[0]) * 2 / 3) assert numpy.allclose(out, numpy.sign(val_inputs[0]) * 2 / 3)
assert(out_dtype == out.dtype) assert(out_dtype == out.dtype)
finally: finally:
mode._optimizer = old_optimizer pass
def test_abs_mul_div(self): def test_abs_mul_div(self):
""" """
...@@ -705,12 +701,11 @@ class test_canonize(unittest.TestCase): ...@@ -705,12 +701,11 @@ class test_canonize(unittest.TestCase):
# We must be sure that the Canonizer is working, but that we don't have other # We must be sure that the Canonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion # optimisation that could hide bug in the Canonizer as local_elemwise_fusion
mode = compile.mode.get_default_mode() mode = compile.mode.get_default_mode()
old_optimizer = mode._optimizer
try: try:
mode._optimizer = gof.Query(["canonicalize"]) opt = gof.Query(["canonicalize"])
mode._optimizer = mode._optimizer.excluding( opt = opt.excluding(
'local_elemwise_fusion') 'local_elemwise_fusion')
mode = mode.__class__(linker=mode.linker, optimizer=opt)
# test fail! # test fail!
# test x / y / z -> x / (y * z) # test x / y / z -> x / (y * z)
for (g, sym_inputs, val_inputs, out_dtype) in [ for (g, sym_inputs, val_inputs, out_dtype) in [
...@@ -749,7 +744,7 @@ class test_canonize(unittest.TestCase): ...@@ -749,7 +744,7 @@ class test_canonize(unittest.TestCase):
assert(out_dtype == out.dtype) assert(out_dtype == out.dtype)
finally: finally:
mode._optimizer = old_optimizer pass
def test_dont_merge_if_multiple_client(self): def test_dont_merge_if_multiple_client(self):
""" test those case take from the comment in Canonizer """ test those case take from the comment in Canonizer
...@@ -3316,6 +3311,8 @@ class test_shapeoptimizer(unittest.TestCase): ...@@ -3316,6 +3311,8 @@ class test_shapeoptimizer(unittest.TestCase):
# Register the optimization # Register the optimization
opt.register_specialize(local_identity_noshape_to_identity_shape) opt.register_specialize(local_identity_noshape_to_identity_shape)
mode = theano.compile.get_default_mode().including(
'ShapeOpt', 'specialize')
# With the optimization # With the optimization
# The identity_shape op should not be needed anymore to compute # The identity_shape op should not be needed anymore to compute
# the shape # the shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论