提交 b453a024 authored 作者: Samira Shabanian's avatar Samira Shabanian

opt is almost done

上级 0aa5ff77
......@@ -3576,6 +3576,38 @@ def local_join_make_vector(node):
return [ret]
#################
# Exp stability #
#################
@register_stabilize
@register_specialize
@register_canonicalize
@gof.local_optimizer([T.Elemwise])
def local_expm1(node):
"""
This optimization detects exp(a)-1 and converts this to expm1(a).
"""
if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.basic.Sub)):
in1, in2 = node.inputs
out = node.outputs[0]
#import pdb; pdb.set_trace()
if (in1.owner and isinstance(in1.owner.op, T.Elemwise) and isinstance(in1.owner.op.scalar_op, theano.scalar.basic.Exp) and
T.extract_constant(in2) == 1):
in11 = in1.owner.inputs[0]
new_out = T.expm1(in11)
# import pdb; pdb.set_trace()
if new_out.dtype != out.dtype:
new_out = T.cast(new_out, dtype=out.dtype)
# The ones could have forced a specific length
if new_out.type != out.type:
new_out = broadcast_like(new_out, out, node.fgraph)
return [new_out]
###############
# Switch opts #
###############
......@@ -6756,7 +6788,7 @@ def local_grad_clip(node):
@register_stabilize
@register_specialize
@gof.local_optimizer([T.Alloc])
def local_merge_alloc(node):
def local_merge_allo(cnode):
# This opt takes care of several cases:
# Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
# Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w)
......
......@@ -37,7 +37,8 @@ from theano.tensor.opt import (
Shape_i,
Assert,
MakeVector,
make_vector
make_vector,
local_expm1
)
from theano import tensor
from theano import tensor as T
......@@ -6122,6 +6123,36 @@ class TestIntDivByOne(unittest.TestCase):
assert len(divs) == 0
def test_local_expm1():
x = matrix('x')
u = T.scalar('u')
y = T.exp(x) - 1.
z = T.exp(x) - 2.
t = T.exp(x) - x
s = T.exp(u) - numpy.ones((4, 3)).astype(config.floatX)
#f = function([x], y)
#g = function([x], z)
#h = function([x], t)
r = function([u],s)
x_val = numpy.random.rand(4, 3).astype(config.floatX)
#f_val = f(x_val)
#g_val = g(x_val)
#h_val = h(x_val)
u_val = r(1.0)
#assert any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
# for n in f.maker.fgraph.toposort())
#assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
# for n in g.maker.fgraph.toposort())
#assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
# for n in h.maker.fgraph.toposort())
def test_local_merge_alloc():
# Add this opt to the default mode,
# otherwise, FAST_COMPILE fails.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论