Unverified 提交 17c675a2 authored 作者: Copilot's avatar Copilot 提交者: GitHub

Remove `scalar_` prefix from several Ops (#1683)

* Initial plan * Rename ScalarMaximum/ScalarMinimum to Maximum/Minimum * Apply ruff formatting fixes * Remove custom names when class name matches desired name * Remove scalar_ prefix from log1mexp, xlogx, xlogy0 and fix numba imports * Fix xtensor test to use backward compat aliases in skip list --------- Co-authored-by: 's avatarcopilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
上级 3082ed5e
......@@ -25,11 +25,11 @@ from pytensor.scalar.basic import (
IsNan,
Log,
Log1p,
Maximum,
Minimum,
Mul,
Neg,
Pow,
ScalarMaximum,
ScalarMinimum,
Sign,
Sin,
Sqr,
......@@ -105,7 +105,7 @@ def mlx_funcify_CARreduce_OR(scalar_op, axis):
return any_reduce
@mlx_funcify_CAReduce_scalar_op.register(ScalarMaximum)
@mlx_funcify_CAReduce_scalar_op.register(Maximum)
def mlx_funcify_CARreduce_Maximum(scalar_op, axis):
def max_reduce(x):
return mx.max(x, axis=axis)
......@@ -113,7 +113,7 @@ def mlx_funcify_CARreduce_Maximum(scalar_op, axis):
return max_reduce
@mlx_funcify_CAReduce_scalar_op.register(ScalarMinimum)
@mlx_funcify_CAReduce_scalar_op.register(Minimum)
def mlx_funcify_CARreduce_Minimum(scalar_op, axis):
def min_reduce(x):
return mx.min(x, axis=axis)
......@@ -354,13 +354,13 @@ def mlx_funcify_Elemwise_scalar_OR(scalar_op):
return mx.bitwise_or
@mlx_funcify_Elemwise_scalar_op.register(ScalarMaximum)
def mlx_funcify_Elemwise_scalar_ScalarMaximum(scalar_op):
@mlx_funcify_Elemwise_scalar_op.register(Maximum)
def mlx_funcify_Elemwise_scalar_Maximum(scalar_op):
return mx.maximum
@mlx_funcify_Elemwise_scalar_op.register(ScalarMinimum)
def mlx_funcify_Elemwise_scalar_ScalarMinimum(scalar_op):
@mlx_funcify_Elemwise_scalar_op.register(Minimum)
def mlx_funcify_Elemwise_scalar_Minimum(scalar_op):
return mx.minimum
......
......@@ -26,13 +26,13 @@ from pytensor.scalar.basic import (
XOR,
Add,
IntDiv,
Maximum,
Minimum,
Mul,
ScalarMaximum,
ScalarMinimum,
Sub,
TrueDiv,
get_scalar_type,
scalar_maximum,
maximum,
)
from pytensor.scalar.basic import add as add_as
from pytensor.tensor.blas import BatchedDot
......@@ -104,16 +104,16 @@ def scalar_in_place_fn_IntDiv(op, idx, res, arr):
return f"{res}[{idx}] //= {arr}"
@scalar_in_place_fn.register(ScalarMaximum)
def scalar_in_place_fn_ScalarMaximum(op, idx, res, arr):
@scalar_in_place_fn.register(Maximum)
def scalar_in_place_fn_Maximum(op, idx, res, arr):
return f"""
if {res}[{idx}] < {arr}:
{res}[{idx}] = {arr}
"""
@scalar_in_place_fn.register(ScalarMinimum)
def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr):
@scalar_in_place_fn.register(Minimum)
def scalar_in_place_fn_Minimum(op, idx, res, arr):
return f"""
if {res}[{idx}] > {arr}:
{res}[{idx}] = {arr}
......@@ -459,7 +459,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_multiaxis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
)
reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
......@@ -523,7 +523,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_multiaxis_reducer(
scalar_maximum,
maximum,
-np.inf,
(axis,),
x_at.ndim,
......
......@@ -1855,7 +1855,7 @@ invert = Invert()
##############
# Arithmetic
##############
class ScalarMaximum(BinaryScalarOp):
class Maximum(BinaryScalarOp):
commutative = True
associative = True
nfunc_spec = ("maximum", 2, 1)
......@@ -1895,10 +1895,14 @@ class ScalarMaximum(BinaryScalarOp):
return (gx, gy)
scalar_maximum = ScalarMaximum(upcast_out, name="maximum")
maximum = Maximum(upcast_out)
# Backward compatibility
ScalarMaximum = Maximum
scalar_maximum = maximum
class ScalarMinimum(BinaryScalarOp):
class Minimum(BinaryScalarOp):
commutative = True
associative = True
nfunc_spec = ("minimum", 2, 1)
......@@ -1937,7 +1941,11 @@ class ScalarMinimum(BinaryScalarOp):
return (gx, gy)
scalar_minimum = ScalarMinimum(upcast_out, name="minimum")
minimum = Minimum(upcast_out)
# Backward compatibility
ScalarMinimum = Minimum
scalar_minimum = minimum
class Add(ScalarOp):
......
......@@ -32,8 +32,8 @@ from pytensor.scalar.basic import (
isinf,
log,
log1p,
maximum,
reciprocal,
scalar_maximum,
sqrt,
switch,
true_div,
......@@ -1315,7 +1315,7 @@ class Softplus(UnaryScalarOp):
return v
softplus = Softplus(upgrade_to_float, name="scalar_softplus")
softplus = Softplus(upgrade_to_float)
class Log1mexp(UnaryScalarOp):
......@@ -1360,7 +1360,7 @@ class Log1mexp(UnaryScalarOp):
raise NotImplementedError("only floating point is implemented")
log1mexp = Log1mexp(upgrade_to_float, name="scalar_log1mexp")
log1mexp = Log1mexp(upgrade_to_float)
class BetaInc(ScalarOp):
......@@ -1585,9 +1585,7 @@ def betainc_grad(p, q, x, wrtp: bool):
derivative_new = K * (F1 * dK + F2)
errapx = scalar_abs(derivative - derivative_new)
d_errapx = errapx / scalar_maximum(
err_threshold, scalar_abs(derivative_new)
)
d_errapx = errapx / maximum(err_threshold, scalar_abs(derivative_new))
min_iters_cond = n > (min_iters - 1)
derivative = switch(
......@@ -1833,7 +1831,7 @@ def _grad_2f1_loop(a, b, c, z, *, skip_loop, wrt, dtype):
if len(grad_incs) == 1:
[max_abs_grad_inc] = grad_incs
else:
max_abs_grad_inc = reduce(scalar_maximum, abs_grad_incs)
max_abs_grad_inc = reduce(maximum, abs_grad_incs)
return (
(*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k),
......
......@@ -948,8 +948,8 @@ class Gemm(GemmRelated):
z_shape, _, x_shape, y_shape, _ = input_shapes
return [
(
pytensor.scalar.scalar_maximum(z_shape[0], x_shape[0]),
pytensor.scalar.scalar_maximum(z_shape[1], y_shape[1]),
pytensor.scalar.maximum(z_shape[0], x_shape[0]),
pytensor.scalar.maximum(z_shape[1], y_shape[1]),
)
]
......
......@@ -357,12 +357,12 @@ fill_inplace = second_inplace
pprint.assign(fill_inplace, printing.FunctionPrinter(["fill="]))
@scalar_elemwise(symbolname="scalar_maximum_inplace")
@scalar_elemwise
def maximum_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
@scalar_elemwise(symbolname="scalar_minimum_inplace")
@scalar_elemwise
def minimum_inplace(a, b):
"""elementwise addition (inplace on `a`)"""
......
......@@ -399,7 +399,7 @@ class Max(NonZeroDimsCAReduce):
nfunc_spec = ("max", 1, 1)
def __init__(self, axis):
super().__init__(ps.scalar_maximum, axis)
super().__init__(ps.maximum, axis)
def clone(self, **kwargs):
axis = kwargs.get("axis", self.axis)
......@@ -457,7 +457,7 @@ class Min(NonZeroDimsCAReduce):
nfunc_spec = ("min", 1, 1)
def __init__(self, axis):
super().__init__(ps.scalar_minimum, axis)
super().__init__(ps.minimum, axis)
def clone(self, **kwargs):
axis = kwargs.get("axis", self.axis)
......@@ -2755,7 +2755,7 @@ def median(x: TensorLike, axis=None) -> TensorVariable:
return ifelse(even_k, even_median, odd_median, name="median")
@scalar_elemwise(symbolname="scalar_maximum")
@scalar_elemwise
def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor
......@@ -2791,7 +2791,7 @@ def maximum(x, y):
# see decorator for function body
@scalar_elemwise(symbolname="scalar_minimum")
@scalar_elemwise
def minimum(x, y):
"""elemwise minimum. See min for the minimum in one tensor
......
......@@ -60,7 +60,7 @@ def local_max_to_min(fgraph, node):
if (
max.owner
and isinstance(max.owner.op, CAReduce)
and max.owner.op.scalar_op == ps.scalar_maximum
and max.owner.op.scalar_op == ps.maximum
):
neg_node = max.owner.inputs[0]
if neg_node.owner and neg_node.owner.op == neg:
......
......@@ -31,7 +31,7 @@ class XlogX(ps.UnaryScalarOp):
raise NotImplementedError("only floatingpoint is implemented")
scalar_xlogx = XlogX(ps.upgrade_to_float, name="scalar_xlogx")
scalar_xlogx = XlogX(ps.upgrade_to_float)
xlogx = Elemwise(scalar_xlogx, name="xlogx")
......@@ -62,5 +62,5 @@ class XlogY0(ps.BinaryScalarOp):
raise NotImplementedError("only floatingpoint is implemented")
scalar_xlogy0 = XlogY0(ps.upgrade_to_float, name="scalar_xlogy0")
scalar_xlogy0 = XlogY0(ps.upgrade_to_float)
xlogy0 = Elemwise(scalar_xlogy0, name="xlogy0")
......@@ -388,11 +388,11 @@ def reciprocal(): ...
def round(): ...
@_as_xelemwise(ps.scalar_maximum)
@_as_xelemwise(ps.maximum)
def maximum(): ...
@_as_xelemwise(ps.scalar_minimum)
@_as_xelemwise(ps.minimum)
def minimum(): ...
......
......@@ -63,8 +63,8 @@ def reduce(x, dim: REDUCE_DIM = None, *, binary_op):
sum = partial(reduce, binary_op=ps.add)
prod = partial(reduce, binary_op=ps.mul)
max = partial(reduce, binary_op=ps.scalar_maximum)
min = partial(reduce, binary_op=ps.scalar_minimum)
max = partial(reduce, binary_op=ps.maximum)
min = partial(reduce, binary_op=ps.minimum)
def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op):
......
......@@ -30,9 +30,9 @@ def lower_reduce(fgraph, node):
tensor_op_class = All
case ps.or_:
tensor_op_class = Any
case ps.scalar_maximum:
case ps.maximum:
tensor_op_class = Max
case ps.scalar_minimum:
case ps.minimum:
tensor_op_class = Min
case _:
# Case without known/predefined Ops
......
......@@ -3887,10 +3887,7 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
fgraph = f.maker.fgraph.toposort()
for node in fgraph:
if (
hasattr(node.op, "scalar_op")
and node.op.scalar_op == ps.basic.scalar_maximum
):
if hasattr(node.op, "scalar_op") and node.op.scalar_op == ps.basic.maximum:
return
# In mode FAST_COMPILE, the rewrites don't replace the
......
......@@ -544,14 +544,14 @@ class TestCAReduce(unittest_tools.InferShapeTester):
elif scalar_op == ps.mul:
for axis in sorted(tosum, reverse=True):
zv = np.multiply.reduce(zv, axis)
elif scalar_op == ps.scalar_maximum:
elif scalar_op == ps.maximum:
# There is no identity value for the maximum function
# So we can't support shape of dimensions 0.
if np.prod(zv.shape) == 0:
continue
for axis in sorted(tosum, reverse=True):
zv = np.maximum.reduce(zv, axis)
elif scalar_op == ps.scalar_minimum:
elif scalar_op == ps.minimum:
# There is no identity value for the minimum function
# So we can't support shape of dimensions 0.
if np.prod(zv.shape) == 0:
......@@ -594,7 +594,7 @@ class TestCAReduce(unittest_tools.InferShapeTester):
tosum = list(range(len(xsh)))
f = pytensor.function([x], e.shape, mode=mode, on_unused_input="ignore")
if not (
scalar_op in [ps.scalar_maximum, ps.scalar_minimum]
scalar_op in [ps.maximum, ps.minimum]
and (xsh == () or np.prod(xsh) == 0)
):
assert all(f(xv) == zv.shape)
......@@ -606,8 +606,8 @@ class TestCAReduce(unittest_tools.InferShapeTester):
for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_mode(Mode(linker="py"), ps.add, dtype=dtype)
self.with_mode(Mode(linker="py"), ps.mul, dtype=dtype)
self.with_mode(Mode(linker="py"), ps.scalar_maximum, dtype=dtype)
self.with_mode(Mode(linker="py"), ps.scalar_minimum, dtype=dtype)
self.with_mode(Mode(linker="py"), ps.maximum, dtype=dtype)
self.with_mode(Mode(linker="py"), ps.minimum, dtype=dtype)
self.with_mode(Mode(linker="py"), ps.and_, dtype=dtype, tensor_op=pt_all)
self.with_mode(Mode(linker="py"), ps.or_, dtype=dtype, tensor_op=pt_any)
for dtype in ["int8", "uint8"]:
......@@ -619,12 +619,8 @@ class TestCAReduce(unittest_tools.InferShapeTester):
for dtype in ["floatX", "complex64", "complex128"]:
self.with_mode(Mode(linker="py"), ps.add, dtype=dtype, test_nan=True)
self.with_mode(Mode(linker="py"), ps.mul, dtype=dtype, test_nan=True)
self.with_mode(
Mode(linker="py"), ps.scalar_maximum, dtype=dtype, test_nan=True
)
self.with_mode(
Mode(linker="py"), ps.scalar_minimum, dtype=dtype, test_nan=True
)
self.with_mode(Mode(linker="py"), ps.maximum, dtype=dtype, test_nan=True)
self.with_mode(Mode(linker="py"), ps.minimum, dtype=dtype, test_nan=True)
self.with_mode(
Mode(linker="py"),
ps.or_,
......@@ -659,8 +655,8 @@ class TestCAReduce(unittest_tools.InferShapeTester):
self.with_mode(Mode(linker="c"), ps.add, dtype=dtype)
self.with_mode(Mode(linker="c"), ps.mul, dtype=dtype)
for dtype in ["bool", "floatX", "int8", "uint8"]:
self.with_mode(Mode(linker="c"), ps.scalar_minimum, dtype=dtype)
self.with_mode(Mode(linker="c"), ps.scalar_maximum, dtype=dtype)
self.with_mode(Mode(linker="c"), ps.minimum, dtype=dtype)
self.with_mode(Mode(linker="c"), ps.maximum, dtype=dtype)
self.with_mode(Mode(linker="c"), ps.and_, dtype=dtype, tensor_op=pt_all)
self.with_mode(Mode(linker="c"), ps.or_, dtype=dtype, tensor_op=pt_any)
for dtype in ["bool", "int8", "uint8"]:
......@@ -678,12 +674,8 @@ class TestCAReduce(unittest_tools.InferShapeTester):
self.with_mode(Mode(linker="c"), ps.add, dtype=dtype, test_nan=True)
self.with_mode(Mode(linker="c"), ps.mul, dtype=dtype, test_nan=True)
for dtype in ["floatX"]:
self.with_mode(
Mode(linker="c"), ps.scalar_minimum, dtype=dtype, test_nan=True
)
self.with_mode(
Mode(linker="c"), ps.scalar_maximum, dtype=dtype, test_nan=True
)
self.with_mode(Mode(linker="c"), ps.minimum, dtype=dtype, test_nan=True)
self.with_mode(Mode(linker="c"), ps.maximum, dtype=dtype, test_nan=True)
def test_infer_shape(self, dtype=None, pre_scalar_op=None):
if dtype is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论