提交 40b51621 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Add rewrites to remove unnecessary expm1 operations

上级 e8c2782f
...@@ -281,7 +281,7 @@ def local_exp_log(fgraph, node): ...@@ -281,7 +281,7 @@ def local_exp_log(fgraph, node):
prev_op = x.owner.op.scalar_op prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op node_op = node.op.scalar_op
# Case for log(exp(x)) # Case for log(exp(x)) -> x
if isinstance(prev_op, aes.Exp) and isinstance(node_op, aes.Log): if isinstance(prev_op, aes.Exp) and isinstance(node_op, aes.Log):
new_out = x.owner.inputs[0] new_out = x.owner.inputs[0]
old_out = node.outputs[0] old_out = node.outputs[0]
...@@ -290,11 +290,25 @@ def local_exp_log(fgraph, node): ...@@ -290,11 +290,25 @@ def local_exp_log(fgraph, node):
new_out = cast(new_out, old_out.dtype) new_out = cast(new_out, old_out.dtype)
return [new_out] return [new_out]
# Case for exp(softplus(x)) aka exp(log1pexp) # Case for log1p(expm1(x)) -> x
if isinstance(prev_op, aes.Expm1) and isinstance(node_op, aes.Log1p):
new_out = x.owner.inputs[0]
old_out = node.outputs[0]
# Expm1 may have cast integer input to float
if new_out.dtype != old_out.dtype:
new_out = cast(new_out, old_out.dtype)
return [new_out]
# Case for exp(softplus(x)) aka exp(log1pexp) -> 1 + exp(x)
if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Exp): if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0] x = x.owner.inputs[0]
return [add(1, exp(x))] return [add(1, exp(x))]
# Case for expm1(softplus(x)) aka expm1(log1pexp) -> exp(x)
if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Expm1):
x = x.owner.inputs[0]
return [exp(x)]
@register_specialize @register_specialize
@local_optimizer([Elemwise]) @local_optimizer([Elemwise])
...@@ -310,27 +324,48 @@ def local_exp_log_nan_switch(fgraph, node): ...@@ -310,27 +324,48 @@ def local_exp_log_nan_switch(fgraph, node):
prev_op = x.owner.op.scalar_op prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op node_op = node.op.scalar_op
# Case for exp(log(x)) # Case for exp(log(x)) -> x
if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Exp): if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0] x = x.owner.inputs[0]
old_out = node.outputs[0] old_out = node.outputs[0]
new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype)) new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype))
return [new_out] return [new_out]
# Case for exp(log1p(x)) # Case for exp(log1p(x)) -> x + 1
if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Exp): if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0] x = x.owner.inputs[0]
old_out = node.outputs[0] old_out = node.outputs[0]
new_out = switch(ge(x, -1), add(1, x), np.asarray(np.nan, old_out.dtype)) new_out = switch(ge(x, -1), add(1, x), np.asarray(np.nan, old_out.dtype))
return [new_out] return [new_out]
# Case for exp(log1mexp(x)) # Case for expm1(log(x)) -> x - 1
if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Expm1):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, 0), sub(x, 1), np.asarray(np.nan, old_out.dtype))
return [new_out]
# Case for expm1(log1p(x)) -> x
if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Expm1):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, -1), x, np.asarray(np.nan, old_out.dtype))
return [new_out]
# Case for exp(log1mexp(x)) -> 1 - exp(x)
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Exp): if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Exp):
x = x.owner.inputs[0] x = x.owner.inputs[0]
old_out = node.outputs[0] old_out = node.outputs[0]
new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype)) new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype))
return [new_out] return [new_out]
# Case for expm1(log1mexp(x)) -> -exp(x)
if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Expm1):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype))
return [new_out]
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
......
...@@ -2557,76 +2557,129 @@ class TestExpLog: ...@@ -2557,76 +2557,129 @@ class TestExpLog:
] ]
assert len(ops_graph) == 0 assert len(ops_graph) == 0
def test_exp_log(self): @pytest.mark.parametrize("dtype", ["float32", "int32"])
def test_log1p_expm1(self, dtype):
# log1p(expm1(x)) -> x
data = (np.random.random((4, 3)) * 100).astype(dtype)
x = matrix(dtype=dtype)
f = function([x], log1p(expm1(x)), mode=self.mode, allow_input_downcast=True)
graph = f.maker.fgraph.toposort()
ops_graph = [
node
for node in graph
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Log, aes.Exp, aes.Log1p, aes.Expm1))
]
assert len(ops_graph) == 0
np.testing.assert_array_equal(f(data), data)
@pytest.mark.parametrize("exp_op", [exp, expm1])
def test_exp_log(self, exp_op):
# exp(log(x)) -> switch(x >= 0, x, nan) # exp(log(x)) -> switch(x >= 0, x, nan)
# expm1(log(x)) -> switch(x >= 0, x - 1, nan)
data_valid = np.random.random((4, 3)).astype("float32") data_valid = np.random.random((4, 3)).astype("float32")
data_valid[0, 0] = 0 # edge case data_valid[0, 0] = 0 # edge case
data_invalid = data_valid - 1 data_invalid = data_valid - 1
x = fmatrix() x = fmatrix()
f = function([x], exp(log(x)), mode=self.mode) f = function([x], exp_op(log(x)), mode=self.mode)
graph = f.maker.fgraph.toposort() graph = f.maker.fgraph.toposort()
ops_graph = [ ops_graph = [
node node
for node in graph for node in graph
if isinstance(node.op, Elemwise) if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Log, aes.Exp)) and isinstance(node.op.scalar_op, (aes.Log, aes.Log1p, aes.Exp, aes.Expm1))
] ]
assert len(ops_graph) == 0 assert len(ops_graph) == 0
np.testing.assert_array_equal(f(data_valid), data_valid)
if exp_op == exp:
expected = data_valid
else:
expected = data_valid - 1
np.testing.assert_almost_equal(f(data_valid), expected)
assert np.all(np.isnan(f(data_invalid))) assert np.all(np.isnan(f(data_invalid)))
def test_exp_log1p(self): @pytest.mark.parametrize("exp_op", [exp, expm1])
def test_exp_log1p(self, exp_op):
# exp(log1p(x)) -> switch(x >= -1, x + 1, nan) # exp(log1p(x)) -> switch(x >= -1, x + 1, nan)
# expm1(log1p(x)) -> switch(x >= -1, x, nan)
data_valid = np.random.random((4, 3)).astype("float32") * 2 - 1 data_valid = np.random.random((4, 3)).astype("float32") * 2 - 1
data_valid[0, 0] = -1 # edge case data_valid[0, 0] = -1 # edge case
data_invalid = data_valid - 2 data_invalid = data_valid - 2
x = fmatrix() x = fmatrix()
f = function([x], exp(log1p(x)), mode=self.mode) f = function([x], exp_op(log1p(x)), mode=self.mode)
graph = f.maker.fgraph.toposort() graph = f.maker.fgraph.toposort()
ops_graph = [ ops_graph = [
node node
for node in graph for node in graph
if isinstance(node.op, Elemwise) if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Log, aes.Exp)) and isinstance(node.op.scalar_op, (aes.Log, aes.Log1p, aes.Exp, aes.Expm1))
] ]
assert len(ops_graph) == 0 assert len(ops_graph) == 0
np.testing.assert_array_equal(f(data_valid), data_valid + 1)
if exp_op == exp:
expected = data_valid + 1
else:
expected = data_valid
np.testing.assert_almost_equal(f(data_valid), expected)
assert np.all(np.isnan(f(data_invalid))) assert np.all(np.isnan(f(data_invalid)))
def test_exp_log1mexp(self): @pytest.mark.parametrize("exp_op", [exp, expm1])
def test_exp_log1mexp(self, exp_op):
# exp(log1mexp(x)) -> switch(x <= 0, 1 - exp(x), nan) # exp(log1mexp(x)) -> switch(x <= 0, 1 - exp(x), nan)
# expm1(log1mexp(x)) -> switch(x <= 0, - exp(x), nan)
data_valid = -np.random.random((4, 3)).astype("float32") data_valid = -np.random.random((4, 3)).astype("float32")
data_valid[0, 0] = 0 # edge case data_valid[0, 0] = 0 # edge case
data_invalid = data_valid + 1 data_invalid = data_valid + 1
x = fmatrix() x = fmatrix()
f = function([x], exp(log1mexp(x)), mode=self.mode) f = function([x], exp_op(log1mexp(x)), mode=self.mode)
graph = f.maker.fgraph.toposort() graph = f.maker.fgraph.toposort()
ops_graph = [ ops_graph = [
node node
for node in graph for node in graph
if isinstance(node.op, Elemwise) if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Log, aes.Log1mexp)) and isinstance(
node.op.scalar_op, (aes.Log, aes.Log1p, aes.Log1mexp, aes.Expm1)
)
] ]
assert len(ops_graph) == 0 assert len(ops_graph) == 0
np.testing.assert_almost_equal(f(data_valid), 1 - np.exp(data_valid))
if exp_op == exp:
expected = 1 - np.exp(data_valid)
else:
expected = -np.exp(data_valid)
np.testing.assert_almost_equal(f(data_valid), expected)
assert np.all(np.isnan(f(data_invalid))) assert np.all(np.isnan(f(data_invalid)))
def test_exp_softplus(self): @pytest.mark.parametrize("exp_op", [exp, expm1])
def test_exp_softplus(self, exp_op):
# exp(softplus(x)) -> 1 + exp(x) # exp(softplus(x)) -> 1 + exp(x)
# expm1(softplus(x)) -> exp(x)
data_valid = np.random.random((4, 3)).astype("float32") * 2 - 1 data_valid = np.random.random((4, 3)).astype("float32") * 2 - 1
x = fmatrix() x = fmatrix()
f = function([x], exp(softplus(x)), mode=self.mode) f = function([x], exp_op(softplus(x)), mode=self.mode)
graph = f.maker.fgraph.toposort() graph = f.maker.fgraph.toposort()
ops_graph = [ ops_graph = [
node node
for node in graph for node in graph
if isinstance(node.op, Elemwise) if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Log, aes.Softplus)) and isinstance(
node.op.scalar_op,
(aes.Log, aes.Log1p, aes.Softplus, aes.Expm1, aes.Switch),
)
] ]
assert len(ops_graph) == 0 assert len(ops_graph) == 0
if exp_op == exp:
expected = 1 + np.exp(data_valid)
else:
expected = np.exp(data_valid)
np.testing.assert_almost_equal( np.testing.assert_almost_equal(
f(data_valid), f(data_valid),
1 + np.exp(data_valid), expected,
decimal=6, decimal=6,
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论