提交 091eec50 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Add cases for `log1msigm` and `Softplus` in `local_exp_log` opt

上级 8717b8f7
......@@ -11,6 +11,7 @@ from functools import reduce
import numpy as np
import aesara.scalar.basic as aes
import aesara.scalar.math as aes_math
from aesara.assert_op import assert_op
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
......@@ -76,6 +77,7 @@ from aesara.tensor.math import (
ge,
int_div,
isinf,
le,
log,
log1mexp,
log1p,
......@@ -265,6 +267,24 @@ def local_exp_log(fgraph, node):
return
return [new_out]
# Case for exp(log1mexp(x))
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype))
if new_out.type != old_out.type:
return
return [new_out]
# Case for exp(softplus(x)) aka exp(log1pexp)
if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = add(1, exp(x))
if new_out.type != old_out.type:
return
return [new_out]
@register_canonicalize
@register_specialize
......
......@@ -58,6 +58,7 @@ from aesara.tensor.math import (
iround,
le,
log,
log1mexp,
log1p,
log2,
log10,
......@@ -70,7 +71,7 @@ from aesara.tensor.math import minimum, mul, neg, neq
from aesara.tensor.math import pow as aet_pow
from aesara.tensor.math import prod, rad2deg, reciprocal
from aesara.tensor.math import round as aet_round
from aesara.tensor.math import sgn, sigmoid, sin, sinh, sqr, sqrt, sub
from aesara.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.math_opt import (
......@@ -2510,9 +2511,10 @@ class TestExpLog:
np.testing.assert_array_equal(f(data), data)
def test_exp_log(self):
# exp(log(x)) -> switch(x > 0, x, nan)
# exp(log(x)) -> switch(x >= 0, x, nan)
data_valid = np.random.random((4, 3)).astype("float32")
data_invalid = data_valid * -1
data_valid[0, 0] = 0 # edge case
data_invalid = data_valid - 1
x = fmatrix()
f = function([x], exp(log(x)), mode=self.mode)
graph = f.maker.fgraph.toposort()
......@@ -2527,8 +2529,9 @@ class TestExpLog:
assert np.all(np.isnan(f(data_invalid)))
def test_exp_log1p(self):
# exp(log1p(x)) -> switch(x > -1, x + 1, nan)
# exp(log1p(x)) -> switch(x >= -1, x + 1, nan)
data_valid = np.random.random((4, 3)).astype("float32") * 2 - 1
data_valid[0, 0] = -1 # edge case
data_invalid = data_valid - 2
x = fmatrix()
f = function([x], exp(log1p(x)), mode=self.mode)
......@@ -2543,6 +2546,43 @@ class TestExpLog:
np.testing.assert_array_equal(f(data_valid), data_valid + 1)
assert np.all(np.isnan(f(data_invalid)))
def test_exp_log1mexp(self):
# exp(log1mexp(x)) -> switch(x <= 0, 1 - exp(x), nan)
data_valid = -np.random.random((4, 3)).astype("float32")
data_valid[0, 0] = 0 # edge case
data_invalid = data_valid + 1
x = fmatrix()
f = function([x], exp(log1mexp(x)), mode=self.mode)
graph = f.maker.fgraph.toposort()
ops_graph = [
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Log, aes.Log1mexp))
]
assert len(ops_graph) == 0
np.testing.assert_almost_equal(f(data_valid), 1 - np.exp(data_valid))
assert np.all(np.isnan(f(data_invalid)))
def test_exp_softplus(self):
# exp(softplus(x)) -> 1 + exp(x)
data_valid = np.random.random((4, 3)).astype("float32") * 2 - 1
x = fmatrix()
f = function([x], exp(softplus(x)), mode=self.mode)
graph = f.maker.fgraph.toposort()
ops_graph = [
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Log, aes.Softplus))
]
assert len(ops_graph) == 0
np.testing.assert_almost_equal(
f(data_valid),
1 + np.exp(data_valid),
decimal=6,
)
class TestLocalSwitchSink:
def setup_method(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论