提交 071d31b6 authored 作者: Frederic Bastien's avatar Frederic Bastien

add optimization to transform erf to erfc and erfc to erf in some case.

上级 091c2ab8
...@@ -2467,6 +2467,16 @@ def _is_1(expr): ...@@ -2467,6 +2467,16 @@ def _is_1(expr):
except TypeError: except TypeError:
return False return False
def _is_minus1(expr):
"""rtype bool. True iff expr is a constant close to -1
"""
try:
v = get_constant_value(expr)
return numpy.allclose(v, -1)
except TypeError:
return False
#1+erf(x)=>erfc(-x)
local_one_plus_erf = gof.PatternSub((T.add, local_one_plus_erf = gof.PatternSub((T.add,
dict(pattern='y', constraint = _is_1), dict(pattern='y', constraint = _is_1),
(T.erf, 'x')), (T.erf, 'x')),
...@@ -2476,6 +2486,71 @@ register_canonicalize(local_one_plus_erf, name='local_one_plus_erf') ...@@ -2476,6 +2486,71 @@ register_canonicalize(local_one_plus_erf, name='local_one_plus_erf')
register_stabilize(local_one_plus_erf, name='local_one_plus_erf') register_stabilize(local_one_plus_erf, name='local_one_plus_erf')
register_specialize(local_one_plus_erf, name='local_one_plus_erf') register_specialize(local_one_plus_erf, name='local_one_plus_erf')
#1-erf(x)=>erfc(x)
local_one_minus_erf = gof.PatternSub((T.sub,
dict(pattern='y', constraint = _is_1),
(T.erf, 'x')),
(T.erfc, 'x'),
allow_multiple_clients = True,)
register_canonicalize(local_one_minus_erf, name='local_one_minus_erf')
register_stabilize(local_one_minus_erf, name='local_one_minus_erf')
register_specialize(local_one_minus_erf, name='local_one_minus_erf')
#1+(-erf(x))=>erfc(x)
#This is a different graph then the previous as the canonicalize don't work completly
local_one_plus_neg_erf = gof.PatternSub((T.add,
dict(pattern='y', constraint = _is_1),
(T.neg,(T.erf, 'x'))),
(T.erfc, 'x'),
allow_multiple_clients = True,)
register_canonicalize(local_one_plus_neg_erf, name='local_one_plus_neg_erf')
register_stabilize(local_one_plus_neg_erf, name='local_one_plus_neg_erf')
register_specialize(local_one_plus_neg_erf, name='local_one_plus_neg_erf')
#(-1)+erf(x) => -erfc(x)
#don't need erf(x)+(-1) as the canonicalize will put the -1 as the first argument.
local_erf_minus_one = gof.PatternSub((T.add,
dict(pattern='y', constraint = _is_minus1),
(T.erf, 'x')),
(T.neg,(T.erfc, 'x')),
allow_multiple_clients = True,)
register_canonicalize(local_erf_minus_one, name='local_erf_minus_one')
register_stabilize(local_erf_minus_one, name='local_erf_minus_one')
register_specialize(local_erf_minus_one, name='local_erf_minus_one')
#1-erfc(x) => erf(x)
local_one_minus_erfc = gof.PatternSub((T.sub,
dict(pattern='y', constraint = _is_1),
(T.erfc, 'x')),
(T.erf, 'x'),
allow_multiple_clients = True,)
register_canonicalize(local_one_minus_erfc, name='local_one_minus_erfc')
register_stabilize(local_one_minus_erfc, name='local_one_minus_erfc')
register_specialize(local_one_minus_erfc, name='local_one_minus_erfc')
#1+(-erfc(x)) => erf(x)
#This is a different graph then the previous as the canonicalize don't work completly
local_one_add_neg_erfc = gof.PatternSub((T.add,
dict(pattern='y', constraint = _is_1),
(T.neg,(T.erfc, 'x'))),
(T.erf, 'x'),
allow_multiple_clients = True,)
register_canonicalize(local_one_add_neg_erfc, name='local_one_add_neg_erfc')
register_stabilize(local_one_add_neg_erfc, name='local_one_add_neg_erfc')
register_specialize(local_one_add_neg_erfc, name='local_one_add_neg_erfc')
#(-1)+erfc(-x)=>erf(x)
local_erf_neg_minus_one = gof.PatternSub((T.add,
dict(pattern='y', constraint = _is_minus1),
(T.erfc, (T.neg,'x'))),
(T.erf, 'x'),
allow_multiple_clients = True,)
register_canonicalize(local_erf_neg_minus_one, name='local_erf_neg_minus_one')
register_stabilize(local_erf_neg_minus_one, name='local_erf_neg_minus_one')
register_specialize(local_erf_neg_minus_one, name='local_erf_neg_minus_one')
#-erfc(x)+1=>erf(x)
# ############### # ###############
# # Loop fusion # # # Loop fusion #
# ############### # ###############
......
...@@ -1581,32 +1581,137 @@ def test_constant_get_stabilized(): ...@@ -1581,32 +1581,137 @@ def test_constant_get_stabilized():
#When this error is fixed, the following line should be ok. #When this error is fixed, the following line should be ok.
assert f()==800,f() assert f()==800,f()
def test_local_one_plus_erf(): class T_local_erf(unittest.TestCase):
mode = theano.config.mode def setUp(self):
if mode == 'FAST_COMPILE': self.mode = theano.compile.mode.get_default_mode().including('canonicalize').including('fast_run').excluding('fusion').excluding('gpu')
mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode)
mode = mode.excluding('fusion').excluding('gpu')
val = numpy.asarray([0,1,2,3,30]) def test_local_one_plus_erf(self):
val = numpy.asarray([-30,-3,-2,-1,0,1,2,3,30])
x = T.vector()
x = T.vector() f = theano.function([x],1+T.erf(x), mode=self.mode)
f = theano.function([x],1+T.erf(x), mode=mode) print f.maker.env.toposort()
print f.maker.env.toposort() assert [n.op for n in f.maker.env.toposort()]==[T.neg,inplace.erfc_inplace], f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.neg,inplace.erfc_inplace] f(val)
f(val)
f = theano.function([x],T.erf(x)+1, mode=mode) f = theano.function([x],T.erf(x)+1, mode=self.mode)
print f.maker.env.toposort() print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.neg,inplace.erfc_inplace] assert [n.op for n in f.maker.env.toposort()]==[T.neg,inplace.erfc_inplace], f.maker.env.toposort()
f(val) f(val)
f = theano.function([x],T.erf(x)+2, mode=mode)
topo = f.maker.env.toposort() f = theano.function([x],T.erf(x)+2, mode=self.mode)
print topo topo = f.maker.env.toposort()
assert len(topo)==2 print topo
assert topo[0].op==T.erf assert len(topo)==2
assert isinstance(topo[1].op,T.Elemwise) assert topo[0].op==T.erf
assert isinstance(topo[1].op.scalar_op,scal.Add) assert isinstance(topo[1].op,T.Elemwise)
f(val) assert isinstance(topo[1].op.scalar_op,scal.Add)
f(val)
def test_local_one_minus_erf(self):
val = numpy.asarray([-30,-3,-2,-1,0,1,2,3,30])
x = T.vector()
f = theano.function([x],1-T.erf(x), mode=self.mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.erfc], f.maker.env.toposort()
print f(val)
f = theano.function([x],1+(-T.erf(x)), mode=self.mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.erfc], f.maker.env.toposort()
print f(val)
f = theano.function([x],(-T.erf(x))+1, mode=self.mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.erfc], f.maker.env.toposort()
print f(val)
f = theano.function([x],2-T.erf(x), mode=self.mode)
topo = f.maker.env.toposort()
print topo
assert len(topo)==2, f.maker.env.toposort()
assert topo[0].op==T.erf, f.maker.env.toposort()
assert isinstance(topo[1].op,T.Elemwise), f.maker.env.toposort()
assert isinstance(topo[1].op.scalar_op,scal.Add) or isinstance(topo[1].op.scalar_op,scal.Sub), f.maker.env.toposort()
print f(val)
def test_local_erf_minus_one(self):
val = numpy.asarray([-30,-3,-2,-1,0,1,2,3,30])
x = T.vector()
f = theano.function([x],T.erf(x)-1, mode=self.mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.erfc,inplace.neg_inplace]
print f(val)
f = theano.function([x],T.erf(x)+(-1), mode=self.mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.erfc,inplace.neg_inplace]
print f(val)
f = theano.function([x],-1+T.erf(x), mode=self.mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.erfc,inplace.neg_inplace]
print f(val)
f = theano.function([x],T.erf(x)-2, mode=self.mode)
topo = f.maker.env.toposort()
print topo
assert len(topo)==2
assert topo[0].op==T.erf
assert isinstance(topo[1].op,T.Elemwise)
assert isinstance(topo[1].op.scalar_op,scal.Add) or isinstance(topo[1].op.scalar_op,scal.Sub)
print f(val)
class T_local_erfc(unittest.TestCase):
def setUp(self):
self.mode = theano.compile.mode.get_default_mode().including('canonicalize').including('fast_run').excluding('fusion').excluding('gpu')
def test_local_one_minus_erfc(self):
""" test opt: 1-erfc(x) => erf(x) and -erfc(x)+1 => erf(x)
"""
val = numpy.asarray([-30,-3,-2,-1,0,1,2,3,30])
x = T.vector()
f = theano.function([x],1-T.erfc(x), mode=self.mode)
theano.printing.debugprint(f)
assert [n.op for n in f.maker.env.toposort()]==[T.erf], f.maker.env.toposort()
print f(val)
f = theano.function([x],(-T.erfc(x))+1, mode=self.mode)
theano.printing.debugprint(f)
assert [n.op for n in f.maker.env.toposort()]==[T.erf], f.maker.env.toposort()
print f(val)
f = theano.function([x],2-T.erfc(x), mode=self.mode)
topo = f.maker.env.toposort()
theano.printing.debugprint(f)
assert len(topo)==2, f.maker.env.toposort()
assert topo[0].op==T.erfc, f.maker.env.toposort()
assert isinstance(topo[1].op,T.Elemwise), f.maker.env.toposort()
assert isinstance(topo[1].op.scalar_op,scal.Sub), f.maker.env.toposort()
print f(val)
def test_local_erf_neg_minus_one(self):
""" test opt: (-1)+erfc(-x)=>erf(x)"""
val = numpy.asarray([-30,-3,-2,-1,0,1,2,3,30])
x = T.vector()
f = theano.function([x],-1+T.erfc(-x), mode=self.mode)
theano.printing.debugprint(f)
assert [n.op for n in f.maker.env.toposort()]==[T.erf], f.maker.env.toposort()
print f(val)
f = theano.function([x],T.erfc(-x)-1, mode=self.mode)
theano.printing.debugprint(f)
assert [n.op for n in f.maker.env.toposort()]==[T.erf], f.maker.env.toposort()
print f(val)
f = theano.function([x],T.erfc(-x)+(-1), mode=self.mode)
theano.printing.debugprint(f)
assert [n.op for n in f.maker.env.toposort()]==[T.erf], f.maker.env.toposort()
print f(val)
class T_local_sum(unittest.TestCase): class T_local_sum(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论