提交 ad4c3a84 authored 作者: Frederic's avatar Frederic

Add optional opt local_remove_all_assert.

上级 a4b15802
......@@ -1582,6 +1582,23 @@ def local_remove_useless_assert(node):
return [assert_(node.inputs[0], *cond)]
@gof.local_optimizer([Assert])
def local_remove_all_assert(node):
"""An optimization disable by default that remove all assert from
the graph.
:note: See the :ref:`unsafe` section to know how to enable it.
"""
if not isinstance(node.op, Assert):
return
return [node.inputs[0]]
# Disabled by default
compile.optdb['canonicalize'].register('local_remove_all_assert',
local_remove_all_assert)
@register_specialize
@gof.local_optimizer([T.Elemwise])
def local_elemwise_alloc(node):
......
......@@ -3224,8 +3224,8 @@ class test_assert(utt.InferShapeTester):
f(1, 1)
self.assertRaises(AssertionError, f, 1, 0)
def test1(self):
#remove assert that are always true
def test_local_remove_useless_assert1(self):
# remove assert that are always true
mode = theano.config.mode
if mode == 'FAST_COMPILE':
mode = 'FAST_RUN'
......@@ -3239,8 +3239,8 @@ class test_assert(utt.InferShapeTester):
assert len(topo) == 1
assert topo[0].op == deep_copy_op
def test2(self):
#remove assert condition that are always true
def test_test_local_remove_useless_assert2(self):
# remove assert condition that are always true
mode = theano.config.mode
if mode == 'FAST_COMPILE':
mode = 'FAST_RUN'
......@@ -3257,8 +3257,8 @@ class test_assert(utt.InferShapeTester):
assert len(topo[0].inputs) == 2
assert topo[1].op == deep_copy_op
def test3(self):
#don't remove assert condition that are always false
def test_local_remove_useless_assert3(self):
# don't remove assert condition that are always false
mode = theano.config.mode
if mode == 'FAST_COMPILE':
mode = 'FAST_RUN'
......@@ -3274,6 +3274,22 @@ class test_assert(utt.InferShapeTester):
assert len(topo[0].inputs) == 3
assert topo[1].op == deep_copy_op
def test_local_remove_all_assert1(self):
# remove assert condition that are unknown
mode = theano.config.mode
if mode == 'FAST_COMPILE':
mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode).including('local_remove_all_assert')
x = T.scalar()
y = T.scalar()
f = theano.function([x, y], theano.tensor.opt.assert_op(x, y),
mode=mode)
f(1, 0) # Without opt, it should fail.
topo = f.maker.fgraph.toposort()
assert len(topo) == 1, topo
assert topo[0].op == deep_copy_op, topo
def test_infer_shape(self):
adscal = dscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论