提交 cc39ea30 authored 作者: Ben Mares's avatar Ben Mares 提交者: Ricardo Vieira

Rewrite log1mexp(log1mexp(x)) to x

上级 10e5c92f
......@@ -625,6 +625,13 @@ def local_exp_log_nan_switch(fgraph, node):
new_out = switch(ge(x, 0), log1p(-x), np.asarray(np.nan, old_out.dtype))
return [new_out]
# Case for log1mexp(log1mexp(x)) -> x
if isinstance(prev_op, ps_math.Log1mexp) and isinstance(node_op, ps_math.Log1mexp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(le(x, 0), x, np.asarray(np.nan, old_out.dtype))
return [new_out]
@register_canonicalize
@register_specialize
......
......@@ -2016,6 +2016,29 @@ class TestExpLog:
np.testing.assert_almost_equal(f(data_valid), expected)
assert np.all(np.isnan(f(data_invalid)))
def test_log1mexp_log1mexp(self):
# log1mexp(log1mexp(x)) -> x
data_valid = -np.random.random((4, 3)).astype("float32")
data_valid[0, 0] = 0 # edge case
data_invalid = data_valid + 1.1
x = fmatrix()
f = function([x], log1mexp(log1mexp(x)), mode=self.mode.excluding("inplace"))
assert equal_computations(
f.maker.fgraph.outputs,
[
pt.switch(
x <= np.array([[0]], dtype=np.int8),
x,
np.array([[np.nan]], dtype=np.float32),
)
],
)
expected = data_valid
np.testing.assert_almost_equal(f(data_valid), expected)
assert np.all(np.isnan(f(data_invalid)))
@pytest.mark.parametrize(
["nested_expression", "expected_switches"],
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论