提交 143896b3 authored 作者: Samira Shabanian's avatar Samira Shabanian

local_exmp1 function and test are added

Conflicts: theano/tensor/opt.py theano/tensor/tests/test_opt.py
上级 b453a024
......@@ -3584,34 +3584,29 @@ def local_join_make_vector(node):
@register_canonicalize
@gof.local_optimizer([T.Elemwise])
def local_expm1(node):
"""
"""
This optimization detects exp(a)-1 and converts this to expm1(a).
"""
if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.basic.Sub)):
in1, in2 = node.inputs
out = node.outputs[0]
#import pdb; pdb.set_trace()
if (in1.owner and isinstance(in1.owner.op, T.Elemwise) and isinstance(in1.owner.op.scalar_op, theano.scalar.basic.Exp) and
if (in1.owner and isinstance(in1.owner.op, T.Elemwise) and isinstance(in1.owner.op.scalar_op, theano.scalar.basic.Exp) and
T.extract_constant(in2) == 1):
in11 = in1.owner.inputs[0]
new_out = T.expm1(in11)
# import pdb; pdb.set_trace()
if new_out.dtype != out.dtype:
new_out = T.cast(new_out, dtype=out.dtype)
# The ones could have forced a specific length
new_out = T.cast(new_out, dtype=out.dtype)
if new_out.type != out.type:
new_out = broadcast_like(new_out, out, node.fgraph)
return
return [new_out]
###############
# Switch opts #
###############
@register_canonicalize('fast_compile', 'local_remove_switch_const_cond')
@register_specialize
@gof.local_optimizer([T.Elemwise])
......
......@@ -6132,25 +6132,27 @@ def test_local_expm1():
t = T.exp(x) - x
s = T.exp(u) - numpy.ones((4, 3)).astype(config.floatX)
#f = function([x], y)
#g = function([x], z)
#h = function([x], t)
r = function([u],s)
f = function([x], y)
g = function([x], z)
h = function([x], t)
r = function([u], s)
x_val = numpy.random.rand(4, 3).astype(config.floatX)
#f_val = f(x_val)
#g_val = g(x_val)
#h_val = h(x_val)
u_val = r(1.0)
f_val = f(x_val)
f_test = function([x], T.expm1(x))
#assert any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
# for n in f.maker.fgraph.toposort())
assert numpy.all(f_val == f_test(x_val))
#assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
# for n in g.maker.fgraph.toposort())
assert any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
for n in f.maker.fgraph.toposort())
#assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
# for n in h.maker.fgraph.toposort())
assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
for n in g.maker.fgraph.toposort())
assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
for n in h.maker.fgraph.toposort())
assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
for n in r.maker.fgraph.toposort())
def test_local_merge_alloc():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论