提交 d0a9488a authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Cast output of `local_func_inv` and `local_exp_log` to float when needed

上级 43ed9011
......@@ -225,7 +225,12 @@ def local_func_inv(fgraph, node):
if is_inverse_pair(node_op, prev_op, inv_pair):
# We don't need to copy stack trace, because the optimization
# is trivial and maintains the earlier stack trace
return x.owner.inputs
ottype = node.out.dtype
inp = x.owner.inputs[0]
# Functions may have casted integer input to float
if inp.dtype != ottype:
inp = cast(inp, ottype)
return [inp]
return
......@@ -246,7 +251,12 @@ def local_exp_log(fgraph, node):
# Case for log(exp(x))
if isinstance(prev_op, aes.Exp) and isinstance(node_op, aes.Log):
return x.owner.inputs
new_out = x.owner.inputs[0]
old_out = node.outputs[0]
# Exp may have casted integer input to float
if new_out.dtype != old_out.dtype:
new_out = cast(new_out, old_out.dtype)
return [new_out]
# Case for exp(softplus(x)) aka exp(log1pexp)
if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Exp):
......
......@@ -2488,6 +2488,16 @@ class TestFuncInverse:
self.assert_func_pair_optimized(rad2deg, rad2deg, dx, should_copy=False)
self.assert_func_pair_optimized(rad2deg, cosh, dx, should_copy=False)
def test_integer_upcast(self):
"""
All invertible methods (except for `Neg`) can upgrade their input to float.
Here we test that the rewrite works with just one pair of methods
"""
x = ivector("x")
f = function([x], deg2rad(rad2deg(x)), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
class TestExpLog:
def setup_method(self):
......@@ -2512,6 +2522,17 @@ class TestExpLog:
assert len(ops_graph) == 0
np.testing.assert_array_equal(f(data), data)
def test_log_exp_integer_upcast(self):
x = ivector("x")
f = function([x], log(exp(x)), mode=self.mode)
ops_graph = [
node
for node in f.maker.fgraph.toposort()
if isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, (aes.Log, aes.Exp))
]
assert len(ops_graph) == 0
def test_exp_log(self):
# exp(log(x)) -> switch(x >= 0, x, nan)
data_valid = np.random.random((4, 3)).astype("float32")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论