提交 e03605e4 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Rewrite `sqr(sqrt(x)) -> |x|` and `sqrt(sqr(x)) -> x`

上级 271c2463
......@@ -400,6 +400,37 @@ def local_exp_log(fgraph, node):
return [exp(x)]
@register_canonicalize
@register_specialize
@node_rewriter([sqrt, sqr])
def local_sqrt_sqr(fgraph, node):
x = node.inputs[0]
if not (x.owner and isinstance(x.owner.op, Elemwise)):
return
prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op
# Case for sqrt(sqr(x)) -> |x|
if isinstance(prev_op, ps.Sqrt) and isinstance(node_op, ps.Sqr):
new_out = pt_abs(x.owner.inputs[0])
old_out = node.outputs[0]
# Handle potential integer to float cast by sqr
if new_out.dtype != old_out.dtype:
new_out = cast(new_out, old_out.dtype)
return [new_out]
# Case for sqr(sqrt(x)) -> x
if isinstance(prev_op, ps.Sqr) and isinstance(node_op, ps.Sqrt):
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype))
return [new_out]
@register_specialize
@node_rewriter([exp, expm1])
def local_exp_log_nan_switch(fgraph, node):
......
......@@ -2031,6 +2031,45 @@ class TestExpLog:
assert len(ops_graph) == expected_switches
class TestSqrSqrt:
def setup_method(self):
mode = get_default_mode()
self.mode = mode.including(
"local_sqrt_sqr",
).excluding("fusion")
self.rng = np.random.default_rng()
def test_sqr_sqrt(self):
# sqrt(x) ** 2 -> x
x = pt.tensor("x", shape=(None, None))
out = sqr(sqrt(x))
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
assert equal_computations([out], [pt_abs(x)])
def test_sqrt_sqr(self):
x = pt.tensor("x", shape=(None, None))
out = sqrt(sqr(x))
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
expected = switch(
ge(x, np.zeros((1, 1), dtype="int8")),
x,
np.full((1, 1), np.nan, dtype=x.type.dtype),
)
assert equal_computations([out], [expected])
def test_sqr_sqrt_integer_upcast(self):
x = ivector("x")
out = sqr(sqrt(x))
dtype = out.type.dtype
out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"])
expected = pt.cast(pt_abs(x), dtype=dtype)
assert equal_computations([out], [expected])
class TestLocalSwitchSink:
def setup_method(self):
# condition values
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论