提交 143896b3 authored 作者: Samira Shabanian's avatar Samira Shabanian

local_exmp1 function and test are added

Conflicts: theano/tensor/opt.py theano/tensor/tests/test_opt.py
上级 b453a024
...@@ -3584,34 +3584,29 @@ def local_join_make_vector(node): ...@@ -3584,34 +3584,29 @@ def local_join_make_vector(node):
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
def local_expm1(node): def local_expm1(node):
""" """
This optimization detects exp(a)-1 and converts this to expm1(a). This optimization detects exp(a)-1 and converts this to expm1(a).
""" """
if (isinstance(node.op, T.Elemwise) and if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, theano.scalar.basic.Sub)): isinstance(node.op.scalar_op, theano.scalar.basic.Sub)):
in1, in2 = node.inputs in1, in2 = node.inputs
out = node.outputs[0] out = node.outputs[0]
#import pdb; pdb.set_trace()
if (in1.owner and isinstance(in1.owner.op, T.Elemwise) and isinstance(in1.owner.op.scalar_op, theano.scalar.basic.Exp) and if (in1.owner and isinstance(in1.owner.op, T.Elemwise) and isinstance(in1.owner.op.scalar_op, theano.scalar.basic.Exp) and
T.extract_constant(in2) == 1): T.extract_constant(in2) == 1):
in11 = in1.owner.inputs[0] in11 = in1.owner.inputs[0]
new_out = T.expm1(in11) new_out = T.expm1(in11)
# import pdb; pdb.set_trace()
if new_out.dtype != out.dtype: if new_out.dtype != out.dtype:
new_out = T.cast(new_out, dtype=out.dtype) new_out = T.cast(new_out, dtype=out.dtype)
# The ones could have forced a specific length
if new_out.type != out.type: if new_out.type != out.type:
new_out = broadcast_like(new_out, out, node.fgraph) return
return [new_out] return [new_out]
############### ###############
# Switch opts # # Switch opts #
############### ###############
@register_canonicalize('fast_compile', 'local_remove_switch_const_cond') @register_canonicalize('fast_compile', 'local_remove_switch_const_cond')
@register_specialize @register_specialize
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
......
...@@ -6132,25 +6132,27 @@ def test_local_expm1(): ...@@ -6132,25 +6132,27 @@ def test_local_expm1():
t = T.exp(x) - x t = T.exp(x) - x
s = T.exp(u) - numpy.ones((4, 3)).astype(config.floatX) s = T.exp(u) - numpy.ones((4, 3)).astype(config.floatX)
#f = function([x], y) f = function([x], y)
#g = function([x], z) g = function([x], z)
#h = function([x], t) h = function([x], t)
r = function([u],s) r = function([u], s)
x_val = numpy.random.rand(4, 3).astype(config.floatX) x_val = numpy.random.rand(4, 3).astype(config.floatX)
#f_val = f(x_val) f_val = f(x_val)
#g_val = g(x_val) f_test = function([x], T.expm1(x))
#h_val = h(x_val)
u_val = r(1.0)
#assert any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1) assert numpy.all(f_val == f_test(x_val))
# for n in f.maker.fgraph.toposort())
#assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1) assert any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
# for n in g.maker.fgraph.toposort()) for n in f.maker.fgraph.toposort())
#assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1) assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
# for n in h.maker.fgraph.toposort()) for n in g.maker.fgraph.toposort())
assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
for n in h.maker.fgraph.toposort())
assert not any(isinstance(n.op, T.Elemwise) and isinstance(n.op.scalar_op, theano.scalar.basic.Expm1)
for n in r.maker.fgraph.toposort())
def test_local_merge_alloc(): def test_local_merge_alloc():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论