提交 8895bc7e authored 作者: Frederic Bastien's avatar Frederic Bastien

added an optimization 1+erf(x)->erfc(-x)

上级 9b562c6f
...@@ -2458,6 +2458,24 @@ local_transposed_dot = gof.PatternSub((inplace_matrix_transpose, (T.dot, 'x', 'y ...@@ -2458,6 +2458,24 @@ local_transposed_dot = gof.PatternSub((inplace_matrix_transpose, (T.dot, 'x', 'y
(T.dot, (inplace_matrix_transpose, 'y'), (inplace_matrix_transpose, 'x'))) (T.dot, (inplace_matrix_transpose, 'y'), (inplace_matrix_transpose, 'x')))
register_canonicalize(local_transposed_dot, name='local_transposed_dot') register_canonicalize(local_transposed_dot, name='local_transposed_dot')
def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1
"""
try:
v = get_constant_value(expr)
return numpy.allclose(v, 1)
except TypeError:
return False
local_one_plus_erf = gof.PatternSub((T.add,
dict(pattern='y', constraint = _is_1),
(T.erf, 'x')),
(T.erfc, (T.neg, 'x')),
allow_multiple_clients = True,)
register_canonicalize(local_one_plus_erf, name='local_one_plus_erf')
register_stabilize(local_one_plus_erf, name='local_one_plus_erf')
register_specialize(local_one_plus_erf, name='local_one_plus_erf')
# ############### # ###############
# # Loop fusion # # # Loop fusion #
# ############### # ###############
......
...@@ -1581,6 +1581,33 @@ def test_constant_get_stabilized(): ...@@ -1581,6 +1581,33 @@ def test_constant_get_stabilized():
#When this error is fixed, the following line should be ok. #When this error is fixed, the following line should be ok.
assert f()==800,f() assert f()==800,f()
def test_local_one_plus_erf():
mode = theano.config.mode
if mode == 'FAST_COMPILE':
mode = 'FAST_RUN'
mode = compile.mode.get_mode(mode)
mode = mode.excluding('fusion').excluding('gpu')
val = numpy.asarray([0,1,2,3,30])
x = T.vector()
f = theano.function([x],1+T.erf(x), mode=mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.neg,inplace.erfc_inplace]
f(val)
f = theano.function([x],T.erf(x)+1, mode=mode)
print f.maker.env.toposort()
assert [n.op for n in f.maker.env.toposort()]==[T.neg,inplace.erfc_inplace]
f(val)
f = theano.function([x],T.erf(x)+2, mode=mode)
topo = f.maker.env.toposort()
print topo
assert len(topo)==2
assert topo[0].op==T.erf
assert isinstance(topo[1].op,T.Elemwise)
assert isinstance(topo[1].op.scalar_op,scal.Add)
f(val)
class T_local_sum(unittest.TestCase): class T_local_sum(unittest.TestCase):
def setUp(self): def setUp(self):
self.mode = theano.compile.get_default_mode().including('canonicalize') self.mode = theano.compile.get_default_mode().including('canonicalize')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论