提交 9f540975 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Add opts for log(exp(x)), exp(log(x)), and exp(log1p(x))

上级 1549649f
...@@ -72,6 +72,7 @@ from aesara.tensor.math import ( ...@@ -72,6 +72,7 @@ from aesara.tensor.math import (
erfc, erfc,
exp, exp,
expm1, expm1,
ge,
int_div, int_div,
inv, inv,
log, log,
...@@ -226,6 +227,43 @@ def local_func_inv(fgraph, node): ...@@ -226,6 +227,43 @@ def local_func_inv(fgraph, node):
return return
@register_canonicalize
@register_specialize
@local_optimizer([Elemwise])
def local_exp_log(fgraph, node):
x = node.inputs[0]
if not isinstance(node.op, Elemwise):
return
if not x.owner or not isinstance(x.owner.op, Elemwise):
return
prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op
# Case for log(exp(x))
if isinstance(prev_op, aes.Exp) and isinstance(node_op, aes.Log):
return x.owner.inputs
# Case for exp(log(x))
if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype))
if new_out.type != old_out.type:
return
return [new_out]
# Case for exp(log1p(x))
if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, -1), add(1, x), np.asarray(np.nan, old_out.dtype))
if new_out.type != old_out.type:
return
return [new_out]
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Sum]) @local_optimizer([Sum])
......
...@@ -2482,6 +2482,61 @@ class TestFuncInverse: ...@@ -2482,6 +2482,61 @@ class TestFuncInverse:
self.assert_func_pair_optimized(rad2deg, cosh, dx, should_copy=False) self.assert_func_pair_optimized(rad2deg, cosh, dx, should_copy=False)
class TestExpLog:
def setup_method(self):
mode = get_default_mode()
self.mode = mode.including("local_exp_log").excluding("fusion")
def test_log_exp(self):
# log(exp(x)) -> x
data = np.random.rand(4, 3).astype("float32")
x = fmatrix()
f = function([x], log(exp(x)), mode=self.mode)
graph = f.maker.fgraph.toposort()
ops_graph = [
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Log, aes.Exp))
]
assert len(ops_graph) == 0
np.testing.assert_array_equal(f(data), data)
def test_exp_log(self):
# exp(log(x)) -> switch(x > 0, x, nan)
data_valid = np.random.rand(4, 3).astype("float32")
data_invalid = data_valid * -1
x = fmatrix()
f = function([x], exp(log(x)), mode=self.mode)
graph = f.maker.fgraph.toposort()
ops_graph = [
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Log, aes.Exp))
]
assert len(ops_graph) == 0
np.testing.assert_array_equal(f(data_valid), data_valid)
assert np.all(np.isnan(f(data_invalid)))
def test_exp_log1p(self):
# exp(log1p(x)) -> switch(x > -1, x + 1, nan)
data_valid = np.random.rand(4, 3).astype("float32") * 2 - 1
data_invalid = data_valid - 2
x = fmatrix()
f = function([x], exp(log1p(x)), mode=self.mode)
graph = f.maker.fgraph.toposort()
ops_graph = [
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Log, aes.Exp))
]
assert len(ops_graph) == 0
np.testing.assert_array_equal(f(data_valid), data_valid + 1)
assert np.all(np.isnan(f(data_invalid)))
class TestLocalSwitchSink: class TestLocalSwitchSink:
def setup_method(self): def setup_method(self):
# condition values # condition values
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论