提交 ddcf0532 authored 作者: Luca Citi's avatar Luca Citi 提交者: Ricardo Vieira

Added log1mexp(log(x)) -> log1p(-x) and its test

Also implemented tests as suggested by ricardoV94
上级 b9ea6dfb
......@@ -576,7 +576,7 @@ def local_log_sqrt(fgraph, node):
@register_specialize
@node_rewriter([exp, expm1, softplus])
@node_rewriter([exp, expm1, log1pexp, log1mexp])
def local_exp_log_nan_switch(fgraph, node):
# Rewrites of the kind exp(log...(x)) that require a `nan` switch
x = node.inputs[0]
......@@ -629,13 +629,20 @@ def local_exp_log_nan_switch(fgraph, node):
new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype))
return [new_out]
# Case for softplus(log(x)) -> log1p(x)
# Case for log1pexp(log(x)) -> log1p(x) (log1pexp aka softplus)
if isinstance(prev_op, ps.Log) and isinstance(node_op, ps_math.Softplus):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, 0), log1p(x), np.asarray(np.nan, old_out.dtype))
return [new_out]
# Case for log1mexp(log(x)) -> log1p(-x)
if isinstance(prev_op, ps.Log) and isinstance(node_op, ps_math.Log1mexp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, 0), log1p(-x), np.asarray(np.nan, old_out.dtype))
return [new_out]
@register_canonicalize
@register_specialize
......
......@@ -69,6 +69,7 @@ from pytensor.tensor.math import (
log,
log1mexp,
log1p,
log1pexp,
lt,
maximum,
minimum,
......@@ -1968,27 +1969,53 @@ class TestExpLog:
decimal=6,
)
def test_softplus_log(self):
# softplus(log(x)) -> log1p(x)
def test_log1pexp_log(self):
# log1pexp(log(x)) -> log1p(x)
data_valid = np.random.random((4, 3)).astype("float32") * 2
data_valid[0, 0] = 0 # edge case
data_invalid = data_valid - 2
x = fmatrix()
f = function([x], softplus(log(x)), mode=self.mode)
graph = f.maker.fgraph.toposort()
ops_graph = [
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, ps.Log | ps.Exp | ps.Softplus)
]
assert len(ops_graph) == 0
f = function([x], log1pexp(log(x)), mode=self.mode.excluding("inplace"))
assert equal_computations(
f.maker.fgraph.outputs,
[
pt.switch(
x >= np.array([[0]], dtype=np.int8),
pt.log1p(x),
np.array([[np.nan]], dtype=np.float32),
)
],
)
expected = np.log1p(data_valid)
np.testing.assert_almost_equal(f(data_valid), expected)
assert np.all(np.isnan(f(data_invalid)))
def test_log1mexp_log(self):
# log1mexp(log(x)) -> log1p(-x)
data_valid = np.random.random((4, 3)).astype("float32")
data_valid[0, 0] = 0 # edge case
data_valid[0, 1] = 1 # another edge case
data_invalid = np.concatenate([data_valid + 1.1, data_valid - 1.1])
x = fmatrix()
f = function([x], log1mexp(log(x)), mode=self.mode.excluding("inplace"))
assert equal_computations(
f.maker.fgraph.outputs,
[
pt.switch(
x >= np.array([[0]], dtype=np.int8),
pt.log1p(-x),
np.array([[np.nan]], dtype=np.float32),
)
],
)
expected = np.log1p(-data_valid)
np.testing.assert_almost_equal(f(data_valid), expected)
assert np.all(np.isnan(f(data_invalid)))
@pytest.mark.parametrize(
["nested_expression", "expected_switches"],
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论