Unverified 提交 b55c4730 authored 作者: Will Dean's avatar Will Dean 提交者: GitHub

Add rewrite for log(sqrt(x)) (#1555)

上级 892a8f0e
......@@ -552,6 +552,29 @@ def local_sqrt_sqr(fgraph, node):
return [new_out]
@register_specialize
@node_rewriter([log])
def local_log_sqrt(fgraph, node):
x = node.inputs[0]
if (
not x.owner
or not isinstance(x.owner.op, Elemwise)
or not isinstance(x.owner.op.scalar_op, ps.Sqrt)
):
return
# Case for log(sqrt(x)) -> 0.5 * log(x)
x = x.owner.inputs[0]
old_out = node.outputs[0]
new_out = mul(as_tensor_variable(0.5, dtype=x.dtype), log(x))
if new_out.dtype != old_out.dtype:
new_out = cast(new_out, old_out.dtype)
copy_stack_trace(node.out, new_out)
return [new_out]
@register_specialize
@node_rewriter([exp, expm1])
def local_exp_log_nan_switch(fgraph, node):
......
......@@ -1989,6 +1989,18 @@ class TestExpLog:
assert len(ops_graph) == expected_switches
def test_log_sqrt() -> None:
x = pt.tensor("x", shape=(None, None))
out = log(sqrt(x))
out = rewrite_graph(out, include=["specialize"])
assert utt.assert_equal_computations(
[out],
[mul(pt.as_tensor_variable([[0.5]], dtype=x.dtype), log(x))],
)
class TestSqrSqrt:
def setup_method(self):
mode = get_default_mode()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论