提交 ad4c3a84 authored 作者: Frederic's avatar Frederic

Add optional opt local_remove_all_assert.

上级 a4b15802
...@@ -1582,6 +1582,23 @@ def local_remove_useless_assert(node): ...@@ -1582,6 +1582,23 @@ def local_remove_useless_assert(node):
return [assert_(node.inputs[0], *cond)] return [assert_(node.inputs[0], *cond)]
@gof.local_optimizer([Assert])
def local_remove_all_assert(node):
"""An optimization disable by default that remove all assert from
the graph.
:note: See the :ref:`unsafe` section to know how to enable it.
"""
if not isinstance(node.op, Assert):
return
return [node.inputs[0]]
# Disabled by default
compile.optdb['canonicalize'].register('local_remove_all_assert',
local_remove_all_assert)
@register_specialize @register_specialize
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
def local_elemwise_alloc(node): def local_elemwise_alloc(node):
......
...@@ -3224,8 +3224,8 @@ class test_assert(utt.InferShapeTester): ...@@ -3224,8 +3224,8 @@ class test_assert(utt.InferShapeTester):
f(1, 1) f(1, 1)
self.assertRaises(AssertionError, f, 1, 0) self.assertRaises(AssertionError, f, 1, 0)
def test1(self): def test_local_remove_useless_assert1(self):
#remove assert that are always true # remove assert that are always true
mode = theano.config.mode mode = theano.config.mode
if mode == 'FAST_COMPILE': if mode == 'FAST_COMPILE':
mode = 'FAST_RUN' mode = 'FAST_RUN'
...@@ -3239,8 +3239,8 @@ class test_assert(utt.InferShapeTester): ...@@ -3239,8 +3239,8 @@ class test_assert(utt.InferShapeTester):
assert len(topo) == 1 assert len(topo) == 1
assert topo[0].op == deep_copy_op assert topo[0].op == deep_copy_op
def test2(self): def test_test_local_remove_useless_assert2(self):
#remove assert condition that are always true # remove assert condition that are always true
mode = theano.config.mode mode = theano.config.mode
if mode == 'FAST_COMPILE': if mode == 'FAST_COMPILE':
mode = 'FAST_RUN' mode = 'FAST_RUN'
...@@ -3257,8 +3257,8 @@ class test_assert(utt.InferShapeTester): ...@@ -3257,8 +3257,8 @@ class test_assert(utt.InferShapeTester):
assert len(topo[0].inputs) == 2 assert len(topo[0].inputs) == 2
assert topo[1].op == deep_copy_op assert topo[1].op == deep_copy_op
def test3(self): def test_local_remove_useless_assert3(self):
#don't remove assert condition that are always false # don't remove assert condition that are always false
mode = theano.config.mode mode = theano.config.mode
if mode == 'FAST_COMPILE': if mode == 'FAST_COMPILE':
mode = 'FAST_RUN' mode = 'FAST_RUN'
...@@ -3274,6 +3274,22 @@ class test_assert(utt.InferShapeTester): ...@@ -3274,6 +3274,22 @@ class test_assert(utt.InferShapeTester):
assert len(topo[0].inputs) == 3 assert len(topo[0].inputs) == 3
assert topo[1].op == deep_copy_op assert topo[1].op == deep_copy_op
def test_local_remove_all_assert1(self):
# remove assert condition that are unknown
mode = theano.config.mode
if mode == 'FAST_COMPILE':
mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode).including('local_remove_all_assert')
x = T.scalar()
y = T.scalar()
f = theano.function([x, y], theano.tensor.opt.assert_op(x, y),
mode=mode)
f(1, 0) # Without opt, it should fail.
topo = f.maker.fgraph.toposort()
assert len(topo) == 1, topo
assert topo[0].op == deep_copy_op, topo
def test_infer_shape(self): def test_infer_shape(self):
adscal = dscalar() adscal = dscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论