提交 40b51621,作者:Ricardo,提交者:Ricardo Vieira

Add rewrites to remove unnecessary expm1 operations

上级 e8c2782f
......@@ -281,7 +281,7 @@ def local_exp_log(fgraph, node):
prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op
# Case for log(exp(x))
# Case for log(exp(x)) -> x
if isinstance(prev_op, aes.Exp) and isinstance(node_op, aes.Log):
new_out = x.owner.inputs[0]
old_out = node.outputs[0]
......@@ -290,11 +290,25 @@ def local_exp_log(fgraph, node):
new_out = cast(new_out, old_out.dtype)
return [new_out]
# Case for exp(softplus(x)) aka exp(log1pexp)
# Case for log1p(expm1(x)) -> x
if isinstance(prev_op, aes.Expm1) and isinstance(node_op, aes.Log1p):
new_out = x.owner.inputs[0]
old_out = node.outputs[0]
# Expm1 may have cast integer input to float
if new_out.dtype != old_out.dtype:
new_out = cast(new_out, old_out.dtype)
return [new_out]
# Case for exp(softplus(x)) aka exp(log1pexp) -> 1 + exp(x)
if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
return [add(1, exp(x))]
# Case for expm1(softplus(x)) aka expm1(log1pexp) -> exp(x)
if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Expm1):
x = x.owner.inputs[0]
return [exp(x)]
@register_specialize
@local_optimizer([Elemwise])
......@@ -310,27 +324,48 @@ def local_exp_log_nan_switch(fgraph, node):
prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op
# Case for exp(log(x))
# Case for exp(log(x)) -> x
if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype))
return [new_out]
# Case for exp(log1p(x))
# Case for exp(log1p(x)) -> x + 1
if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, -1), add(1, x), np.asarray(np.nan, old_out.dtype))
return [new_out]
# Case for exp(log1mexp(x))
# Case for expm1(log(x)) -> x - 1
if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Expm1):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, 0), sub(x, 1), np.asarray(np.nan, old_out.dtype))
return [new_out]
# Case for expm1(log1p(x)) -> x
if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Expm1):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, -1), x, np.asarray(np.nan, old_out.dtype))
return [new_out]
# Case for exp(log1mexp(x)) -> 1 - exp(x)
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype))
return [new_out]
# Case for expm1(log1mexp(x)) -> -exp(x)
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Expm1):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype))
return [new_out]
@register_canonicalize
@register_specialize
......
......@@ -2557,76 +2557,129 @@ class TestExpLog:
]
assert len(ops_graph) == 0
@pytest.mark.parametrize("dtype", ["float32", "int32"])
def test_log1p_expm1(self, dtype):
    """Check that ``log1p(expm1(x))`` is rewritten away entirely, leaving ``x``.

    Covers both a float and an integer input dtype; for the integer case
    ``expm1`` would normally upcast, so the rewrite must also insert a cast
    back to the original output dtype.
    """
    # log1p(expm1(x)) -> x
    data = (np.random.random((4, 3)) * 100).astype(dtype)
    x = matrix(dtype=dtype)
    f = function([x], log1p(expm1(x)), mode=self.mode, allow_input_downcast=True)
    graph = f.maker.fgraph.toposort()
    # No log/exp-family Elemwise ops should survive the rewrite.
    ops_graph = [
        node
        for node in graph
        if isinstance(node.op, Elemwise)
        and isinstance(node.op.scalar_op, (aes.Log, aes.Exp, aes.Log1p, aes.Expm1))
    ]
    assert len(ops_graph) == 0
    # The compiled function must be the identity on valid data.
    np.testing.assert_array_equal(f(data), data)
@pytest.mark.parametrize("exp_op", [exp, expm1])
def test_exp_log(self, exp_op):
    """Check the nan-guarded rewrites of ``exp``/``expm1`` over ``log``.

    exp(log(x))   -> switch(x >= 0, x, nan)
    expm1(log(x)) -> switch(x >= 0, x - 1, nan)
    """
    data_valid = np.random.random((4, 3)).astype("float32")
    data_valid[0, 0] = 0  # edge case: log(0) boundary must stay valid
    data_invalid = data_valid - 1  # negative values -> nan branch
    x = fmatrix()
    f = function([x], exp_op(log(x)), mode=self.mode)
    graph = f.maker.fgraph.toposort()
    # No log/exp-family Elemwise ops should survive the rewrite.
    ops_graph = [
        node
        for node in graph
        if isinstance(node.op, Elemwise)
        and isinstance(node.op.scalar_op, (aes.Log, aes.Log1p, aes.Exp, aes.Expm1))
    ]
    assert len(ops_graph) == 0
    if exp_op == exp:
        expected = data_valid
    else:
        expected = data_valid - 1
    np.testing.assert_almost_equal(f(data_valid), expected)
    # Out-of-domain inputs must map to nan via the inserted switch.
    assert np.all(np.isnan(f(data_invalid)))
@pytest.mark.parametrize("exp_op", [exp, expm1])
def test_exp_log1p(self, exp_op):
    """Check the nan-guarded rewrites of ``exp``/``expm1`` over ``log1p``.

    exp(log1p(x))   -> switch(x >= -1, x + 1, nan)
    expm1(log1p(x)) -> switch(x >= -1, x, nan)
    """
    data_valid = np.random.random((4, 3)).astype("float32") * 2 - 1
    data_valid[0, 0] = -1  # edge case: log1p(-1) boundary must stay valid
    data_invalid = data_valid - 2  # values below -1 -> nan branch
    x = fmatrix()
    f = function([x], exp_op(log1p(x)), mode=self.mode)
    graph = f.maker.fgraph.toposort()
    # No log/exp-family Elemwise ops should survive the rewrite.
    ops_graph = [
        node
        for node in graph
        if isinstance(node.op, Elemwise)
        and isinstance(node.op.scalar_op, (aes.Log, aes.Log1p, aes.Exp, aes.Expm1))
    ]
    assert len(ops_graph) == 0
    if exp_op == exp:
        expected = data_valid + 1
    else:
        expected = data_valid
    np.testing.assert_almost_equal(f(data_valid), expected)
    # Out-of-domain inputs must map to nan via the inserted switch.
    assert np.all(np.isnan(f(data_invalid)))
@pytest.mark.parametrize("exp_op", [exp, expm1])
def test_exp_log1mexp(self, exp_op):
    """Check the nan-guarded rewrites of ``exp``/``expm1`` over ``log1mexp``.

    exp(log1mexp(x))   -> switch(x <= 0, 1 - exp(x), nan)
    expm1(log1mexp(x)) -> switch(x <= 0, -exp(x), nan)
    """
    data_valid = -np.random.random((4, 3)).astype("float32")
    data_valid[0, 0] = 0  # edge case: log1mexp(0) boundary must stay valid
    data_invalid = data_valid + 1  # positive values -> nan branch
    x = fmatrix()
    f = function([x], exp_op(log1mexp(x)), mode=self.mode)
    graph = f.maker.fgraph.toposort()
    # A plain Exp may legitimately remain in the rewritten graph, so it is
    # deliberately absent from this blacklist.
    ops_graph = [
        node
        for node in graph
        if isinstance(node.op, Elemwise)
        and isinstance(
            node.op.scalar_op, (aes.Log, aes.Log1p, aes.Log1mexp, aes.Expm1)
        )
    ]
    assert len(ops_graph) == 0
    if exp_op == exp:
        expected = 1 - np.exp(data_valid)
    else:
        expected = -np.exp(data_valid)
    np.testing.assert_almost_equal(f(data_valid), expected)
    # Out-of-domain inputs must map to nan via the inserted switch.
    assert np.all(np.isnan(f(data_invalid)))
@pytest.mark.parametrize("exp_op", [exp, expm1])
def test_exp_softplus(self, exp_op):
    """Check the rewrites of ``exp``/``expm1`` over ``softplus``.

    exp(softplus(x))   -> 1 + exp(x)
    expm1(softplus(x)) -> exp(x)
    """
    data_valid = np.random.random((4, 3)).astype("float32") * 2 - 1
    x = fmatrix()
    f = function([x], exp_op(softplus(x)), mode=self.mode)
    graph = f.maker.fgraph.toposort()
    # A plain Exp may legitimately remain; Switch is blacklisted because the
    # softplus rewrites should not need a nan guard.
    ops_graph = [
        node
        for node in graph
        if isinstance(node.op, Elemwise)
        and isinstance(
            node.op.scalar_op,
            (aes.Log, aes.Log1p, aes.Softplus, aes.Expm1, aes.Switch),
        )
    ]
    assert len(ops_graph) == 0
    if exp_op == exp:
        expected = 1 + np.exp(data_valid)
    else:
        expected = np.exp(data_valid)
    np.testing.assert_almost_equal(
        f(data_valid),
        expected,
        decimal=6,
    )
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册或者登录后发表评论