提交 c67b2121 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #4151 from shabanian/opt

Opt
......@@ -3576,10 +3576,37 @@ def local_join_make_vector(node):
return [ret]
#################
# Exp stability #
#################
@register_stabilize
@register_specialize
@register_canonicalize
@gof.local_optimizer([T.Elemwise])
def local_expm1(node):
"""
This optimization detects exp(a)-1 and converts this to expm1(a).
"""
if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.basic.Sub)):
in1, in2 = node.inputs
out = node.outputs[0]
if (in1.owner and isinstance(in1.owner.op, T.Elemwise) and isinstance(in1.owner.op.scalar_op, theano.scalar.basic.Exp) and
T.extract_constant(in2, only_process_constants=False) == 1):
in11 = in1.owner.inputs[0]
new_out = T.expm1(in11)
if new_out.dtype != out.dtype:
new_out = T.cast(new_out, dtype=out.dtype)
if new_out.type != out.type:
return
return [new_out]
###############
# Switch opts #
###############
@register_canonicalize('fast_compile', 'local_remove_switch_const_cond')
@register_specialize
@gof.local_optimizer([T.Elemwise])
......
......@@ -37,7 +37,8 @@ from theano.tensor.opt import (
Shape_i,
Assert,
MakeVector,
make_vector
make_vector,
local_expm1
)
from theano import tensor
from theano import tensor as T
......@@ -6122,6 +6123,38 @@ class TestIntDivByOne(unittest.TestCase):
assert len(divs) == 0
def test_local_expm1():
x = matrix('x')
u = T.scalar('u')
y = T.exp(x) - 1.
z = T.exp(x) - 2.
t = T.exp(x) - x
s = T.exp(u) - numpy.ones((4, 3)).astype(config.floatX)
MODE = theano.compile.get_default_mode().including('local_expm1')
f = function([x], y, mode=MODE)
g = function([x], z, mode=MODE)
h = function([x], t, mode=MODE)
r = function([u], s, mode=MODE)
x_val = numpy.random.rand(4, 3).astype(config.floatX)
f_val = f(x_val)
f_test = function([x], T.expm1(x), mode=MODE)
assert numpy.allclose(f_val, f_test(x_val))
assert any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
for n in f.maker.fgraph.toposort())
assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
for n in g.maker.fgraph.toposort())
assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
for n in h.maker.fgraph.toposort())
assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
for n in r.maker.fgraph.toposort())
def test_local_merge_alloc():
# Add this opt to the default mode,
# otherwise, FAST_COMPILE fails.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论