提交 15bb0b8e authored 作者: Frederic Bastien's avatar Frederic Bastien

remove Assert condition that are always true and remove Assert op that are always true. +test.

上级 71486d94
...@@ -660,6 +660,26 @@ class Assert(T.Op): ...@@ -660,6 +660,26 @@ class Assert(T.Op):
assert_ = Assert() assert_ = Assert()
@register_specialize
@gof.local_optimizer([Assert])
def local_remove_useless_assert(node):
if isinstance(node.op, Assert):
cond=[]
for c in node.inputs[1:]:
try:
const = get_constant_value(c)
if 0!=const.ndim or const==0:
#Should we raise an error here? How to be sure it is not catched?
cond.append(c)
except TypeError:
cond.append(c)
if len(cond)==0:
return [node.inputs[0]]
if len(cond)!=len(node.inputs)-1:
return [assert_(node.inputs[0],*cond)]
@gof.local_optimizer([T.Alloc]) @gof.local_optimizer([T.Alloc])
def local_alloc_elemwise(node): def local_alloc_elemwise(node):
""" """
...@@ -730,10 +750,9 @@ def local_alloc_elemwise(node): ...@@ -730,10 +750,9 @@ def local_alloc_elemwise(node):
return [node.op(*new)] return [node.op(*new)]
#TODO, T.eq if both input are the same, remove! #TODO, T.eq if both input are the same, remove!
#TODO, op that check the condition are all true and remove the Assert. Also remove the constant condition.
#TODO, global optimizer that lift the assert to the beginning of the graph. #TODO, global optimizer that lift the assert to the beginning of the graph.
#TODO, var.tag.shape to propagate the shape and lower the overhead of this op #TODO, var.tag.shape to propagate the shape and lower the overhead of this op
#TODO, when all can be optimizer do all except one #TODO, when all inputs can be optimized do all except one
theano.configparser.AddConfigVar('experimental.local_alloc_elemwise', theano.configparser.AddConfigVar('experimental.local_alloc_elemwise',
"If True enable the experimental optimization local_alloc_elemwise", "If True enable the experimental optimization local_alloc_elemwise",
......
...@@ -1093,6 +1093,51 @@ class test_assert(unittest.TestCase): ...@@ -1093,6 +1093,51 @@ class test_assert(unittest.TestCase):
f(1,1) f(1,1)
self.failUnlessRaises(AssertionError, f, 1,0) self.failUnlessRaises(AssertionError, f, 1,0)
def test1(self):
#remove assert that are always true
mode = theano.config.mode
if mode == 'FAST_COMPILE':
mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode)
x=T.scalar()
f = theano.function([x],theano.tensor.opt.assert_(x,1),mode=mode)
assert f(1)==1
assert f(5)==5
topo=f.maker.env.toposort()
assert len(topo)==0
def test2(self):
#remove assert condition that are always true
mode = theano.config.mode
if mode == 'FAST_COMPILE':
mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode)
x=T.scalar()
y=T.scalar()
f = theano.function([x,y],theano.tensor.opt.assert_(x,y,1),mode=mode)
assert f(1,1)==1
assert f(5,1)==5
topo=f.maker.env.toposort()
assert len(topo)==1
assert len(topo[0].inputs)==2
def test3(self):
#don't remove assert condition that are always false
mode = theano.config.mode
if mode == 'FAST_COMPILE':
mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode)
x=T.scalar()
y=T.scalar()
f = theano.function([x,y],theano.tensor.opt.assert_(x,y,0),mode=mode)
self.failUnlessRaises(AssertionError, f, 1,0)
topo=f.maker.env.toposort()
assert len(topo)==1
assert len(topo[0].inputs)==3
def test_local_mul_specialize(): def test_local_mul_specialize():
# test a few cases to make sure that the basics are covered # test a few cases to make sure that the basics are covered
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论