Unverified commit b065112b, authored by Jesse Grabowski, committed by GitHub

Add rewrite for `1 ** x = 1` (#1179)

Parent: 2f1d25a9
......@@ -1905,13 +1905,40 @@ def local_reciprocal_canon(fgraph, node):
@register_canonicalize
@node_rewriter([pt_pow])
def local_pow_canonicalize(fgraph, node):
    """
    Rewrite power expressions with straightforward simplifications:

    1. ``x ** 0 -> 1``
    2. ``x ** 1 -> x``
    3. ``1 ** x -> 1``

    In all cases the shape of the output is the result of broadcasting the
    shapes of the two inputs, so the rewrite preserves the output shape of
    the original ``Pow`` node.
    """
    # NOTE(review): with raise_not_constant=False this is expected to return a
    # non-raising sentinel (not a symbolic variable) when no underlying scalar
    # constant exists, so the ``==`` comparisons below stay plain Python bools
    # — confirm against the pytensor version in use.
    cst_base = get_underlying_scalar_constant_value(
        node.inputs[0], only_process_constants=True, raise_not_constant=False
    )
    cst_exponent = get_underlying_scalar_constant_value(
        node.inputs[1], only_process_constants=True, raise_not_constant=False
    )

    new_out = None

    if cst_base == 1:
        # 1 ** x = 1: the base already is the constant 1, so broadcasting the
        # two inputs and keeping the base yields ones of the broadcast shape.
        new_out = broadcast_arrays(*node.inputs)[0]
    elif cst_exponent == 0:
        # x ** 0 = 1: ones with the base's dtype, broadcast against the exponent.
        new_out = broadcast_arrays(ones_like(node.inputs[0]), node.inputs[1])[0]
    elif cst_exponent == 1:
        # x ** 1 = x, broadcast to the output shape of the original node.
        new_out = broadcast_arrays(*node.inputs)[0]

    # BUG FIX: the previous ``if not new_out:`` invoked ``__bool__`` on a
    # symbolic TensorVariable whenever a rewrite matched; PyTensor variables
    # raise TypeError on truthiness, so every successful match would crash.
    # Compare identity against None instead.
    if new_out is None:
        return

    # The broadcast helpers may change dtype (e.g. ones_like vs. the original
    # output dtype after upcasting); cast back to the node's declared dtype.
    if new_out.dtype != node.out.dtype:
        new_out = cast(new_out, dtype=node.out.dtype)

    return [new_out]
@register_specialize
......
......@@ -4571,3 +4571,22 @@ def test_log_kv_stabilization():
out.eval({x: 1000.0}, mode=mode),
-1003.2180912984705,
)
@pytest.mark.parametrize("shape", [(), (4, 5, 6)], ids=["scalar", "tensor"])
def test_pow_1_rewrite(shape):
    """``1 ** x`` starts as a Pow Elemwise and must be rewritten away by compilation."""

    def _is_pow(op):
        return isinstance(op, Elemwise) and isinstance(op.scalar_op, ps.basic.Pow)

    x = pt.tensor("x", shape=shape)
    z = 1**x

    # Before rewriting, the graph is literally an Elemwise Pow node.
    assert _is_pow(z.owner.op)

    fn = pytensor.function([x], z)

    # After compilation, no Pow node may survive anywhere in the graph.
    assert all(not _is_pow(apply.op) for apply in fn.maker.fgraph.toposort())

    # The rewritten function must agree numerically with the unoptimized graph.
    x_val = np.random.random(shape).astype(config.floatX)
    np.testing.assert_allclose(z.eval({x: x_val}), fn(x_val))
Markdown format
0%
You added 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment