提交 e8d02ddb authored 作者: Frederic's avatar Frederic

fix test as now we do not allow to change the optimizer after a mode is created

上级 6757ec46
......@@ -304,23 +304,20 @@ class test_canonize(unittest.TestCase):
# We must be sure that the Canonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion
mode = compile.mode.get_default_mode()
old_optimizer = mode._optimizer
try:
mode._optimizer = gof.Query(["canonicalize"])
mode._optimizer = mode._optimizer.excluding(
'local_elemwise_fusion')
for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases):
if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy]
f = compile.function(list(sym_inputs), g,
# we need the optimisation enabled, debug do this.
mode=mode)
opt = gof.Query(["canonicalize"])
opt = opt.excluding('local_elemwise_fusion')
mode = mode.__class__(linker=mode.linker, optimizer=opt)
for id, [g, sym_inputs, val_inputs,
nb_elemwise, out_dtype] in enumerate(cases):
if isinstance(out_dtype, dict):
out_dtype = out_dtype[config.cast_policy]
f = compile.function(list(sym_inputs), g,
# we need the optimisation enabled, debug do this.
mode=mode)
out = f(*val_inputs)
assert(len(f.maker.fgraph.toposort()) == nb_elemwise)
assert(out_dtype == out.dtype)
finally:
mode._optimizer = old_optimizer
out = f(*val_inputs)
assert(len(f.maker.fgraph.toposort()) == nb_elemwise)
assert(out_dtype == out.dtype)
def test_elemwise_multiple_inputs_optimisation2(self):
"""
......@@ -455,13 +452,12 @@ class test_canonize(unittest.TestCase):
# We must be sure that the Canonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion
mode = compile.mode.get_default_mode()
old_optimizer = mode._optimizer
try:
mode._optimizer = gof.Query(["canonicalize"])
mode._optimizer = mode._optimizer.including('ShapeOpt')
mode._optimizer = mode._optimizer.excluding(
opt = gof.Query(["canonicalize"])
opt = opt.including('ShapeOpt')
opt = opt.excluding(
'local_elemwise_fusion')
mode = mode.__class__(linker=mode.linker, optimizer=opt)
# test x / x -> 1
for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate([(fx/fx, [fx], [fxv], 'float32'),
(dx/dx, [dx], [dxv], 'float64'),
......@@ -644,7 +640,7 @@ class test_canonize(unittest.TestCase):
assert numpy.allclose(out, numpy.sign(val_inputs[0]) * 2 / 3)
assert(out_dtype == out.dtype)
finally:
mode._optimizer = old_optimizer
pass
def test_abs_mul_div(self):
"""
......@@ -705,12 +701,11 @@ class test_canonize(unittest.TestCase):
# We must be sure that the Canonizer is working, but that we don't have other
# optimisation that could hide bug in the Canonizer as local_elemwise_fusion
mode = compile.mode.get_default_mode()
old_optimizer = mode._optimizer
try:
mode._optimizer = gof.Query(["canonicalize"])
mode._optimizer = mode._optimizer.excluding(
opt = gof.Query(["canonicalize"])
opt = opt.excluding(
'local_elemwise_fusion')
mode = mode.__class__(linker=mode.linker, optimizer=opt)
# test fail!
# test x / y / z -> x / (y * z)
for (g, sym_inputs, val_inputs, out_dtype) in [
......@@ -749,7 +744,7 @@ class test_canonize(unittest.TestCase):
assert(out_dtype == out.dtype)
finally:
mode._optimizer = old_optimizer
pass
def test_dont_merge_if_multiple_client(self):
""" test those case take from the comment in Canonizer
......@@ -3316,6 +3311,8 @@ class test_shapeoptimizer(unittest.TestCase):
# Register the optimization
opt.register_specialize(local_identity_noshape_to_identity_shape)
mode = theano.compile.get_default_mode().including(
'ShapeOpt', 'specialize')
# With the optimization
# The identity_shape op should not be needed anymore to compute
# the shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论