提交 af6a20e9 authored 作者: Frederic Bastien's avatar Frederic Bastien

make stabilization about erf happen in the stabilize phase and test that.

上级 ae415c3b
......@@ -2576,6 +2576,16 @@ register_canonicalize(local_one_minus_erf, name='local_one_minus_erf')
register_stabilize(local_one_minus_erf, name='local_one_minus_erf')
register_specialize(local_one_minus_erf, name='local_one_minus_erf')
local_one_minus_erf2 = gof.PatternSub((T.add,
1,
(T.mul,-1,(T.erf, 'x'))),
(T.erfc, 'x'),
allow_multiple_clients = True,
name='local_one_minus_erf2')
register_canonicalize(local_one_minus_erf2)
register_stabilize(local_one_minus_erf2)
register_specialize(local_one_minus_erf2)
#1+(-erf(x))=>erfc(x)
#This is a different graph then the previous as the canonicalize don't work completly
local_one_plus_neg_erf = gof.PatternSub((T.add,
......
......@@ -1637,7 +1637,8 @@ class T_local_switch_sink(unittest.TestCase):
class T_local_erf(unittest.TestCase):
def setUp(self):
self.mode = theano.compile.mode.get_default_mode().including('canonicalize').including('fast_run').excluding('fusion').excluding('gpu')
self.mode = theano.compile.mode.get_default_mode().including('canonicalize','fast_run').excluding('gpu','fusion')
self.mode._optimizer.position_cutoff = 1.50001
def test_local_one_plus_erf(self):
val = numpy.asarray([-30,-3,-2,-1,0,1,2,3,30])
......@@ -1645,12 +1646,12 @@ class T_local_erf(unittest.TestCase):
f = theano.function([x],1+T.erf(x), mode=self.mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.neg,inplace.erfc_inplace], f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.mul,T.erfc], f.maker.env.toposort()
f(val)
f = theano.function([x],T.erf(x)+1, mode=self.mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.neg,inplace.erfc_inplace], f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.mul,T.erfc], f.maker.env.toposort()
f(val)
f = theano.function([x],T.erf(x)+2, mode=self.mode)
......@@ -1696,17 +1697,17 @@ class T_local_erf(unittest.TestCase):
f = theano.function([x],T.erf(x)-1, mode=self.mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.erfc,inplace.neg_inplace]
assert [n.op for n in f.maker.env.toposort()]==[T.erfc,T.mul]
print f(val)
f = theano.function([x],T.erf(x)+(-1), mode=self.mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.erfc,inplace.neg_inplace]
assert [n.op for n in f.maker.env.toposort()]==[T.erfc,T.mul]
print f(val)
f = theano.function([x],-1+T.erf(x), mode=self.mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.erfc,inplace.neg_inplace]
assert [n.op for n in f.maker.env.toposort()]==[T.erfc,T.mul]
print f(val)
f = theano.function([x],T.erf(x)-2, mode=self.mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论