Unverified 提交 17c675a2 authored 作者: Copilot's avatar Copilot 提交者: GitHub

Remove `scalar_` prefix from several Ops (#1683)

* Initial plan * Rename ScalarMaximum/ScalarMinimum to Maximum/Minimum * Apply ruff formatting fixes * Remove custom names when class name matches desired name * Remove scalar_ prefix from log1mexp, xlogx, xlogy0 and fix numba imports * Fix xtensor test to use backward compat aliases in skip list --------- Co-authored-by: 's avatarcopilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
上级 3082ed5e
...@@ -25,11 +25,11 @@ from pytensor.scalar.basic import ( ...@@ -25,11 +25,11 @@ from pytensor.scalar.basic import (
IsNan, IsNan,
Log, Log,
Log1p, Log1p,
Maximum,
Minimum,
Mul, Mul,
Neg, Neg,
Pow, Pow,
ScalarMaximum,
ScalarMinimum,
Sign, Sign,
Sin, Sin,
Sqr, Sqr,
...@@ -105,7 +105,7 @@ def mlx_funcify_CARreduce_OR(scalar_op, axis): ...@@ -105,7 +105,7 @@ def mlx_funcify_CARreduce_OR(scalar_op, axis):
return any_reduce return any_reduce
@mlx_funcify_CAReduce_scalar_op.register(ScalarMaximum) @mlx_funcify_CAReduce_scalar_op.register(Maximum)
def mlx_funcify_CARreduce_Maximum(scalar_op, axis): def mlx_funcify_CARreduce_Maximum(scalar_op, axis):
def max_reduce(x): def max_reduce(x):
return mx.max(x, axis=axis) return mx.max(x, axis=axis)
...@@ -113,7 +113,7 @@ def mlx_funcify_CARreduce_Maximum(scalar_op, axis): ...@@ -113,7 +113,7 @@ def mlx_funcify_CARreduce_Maximum(scalar_op, axis):
return max_reduce return max_reduce
@mlx_funcify_CAReduce_scalar_op.register(ScalarMinimum) @mlx_funcify_CAReduce_scalar_op.register(Minimum)
def mlx_funcify_CARreduce_Minimum(scalar_op, axis): def mlx_funcify_CARreduce_Minimum(scalar_op, axis):
def min_reduce(x): def min_reduce(x):
return mx.min(x, axis=axis) return mx.min(x, axis=axis)
...@@ -354,13 +354,13 @@ def mlx_funcify_Elemwise_scalar_OR(scalar_op): ...@@ -354,13 +354,13 @@ def mlx_funcify_Elemwise_scalar_OR(scalar_op):
return mx.bitwise_or return mx.bitwise_or
@mlx_funcify_Elemwise_scalar_op.register(ScalarMaximum) @mlx_funcify_Elemwise_scalar_op.register(Maximum)
def mlx_funcify_Elemwise_scalar_ScalarMaximum(scalar_op): def mlx_funcify_Elemwise_scalar_Maximum(scalar_op):
return mx.maximum return mx.maximum
@mlx_funcify_Elemwise_scalar_op.register(ScalarMinimum) @mlx_funcify_Elemwise_scalar_op.register(Minimum)
def mlx_funcify_Elemwise_scalar_ScalarMinimum(scalar_op): def mlx_funcify_Elemwise_scalar_Minimum(scalar_op):
return mx.minimum return mx.minimum
......
...@@ -26,13 +26,13 @@ from pytensor.scalar.basic import ( ...@@ -26,13 +26,13 @@ from pytensor.scalar.basic import (
XOR, XOR,
Add, Add,
IntDiv, IntDiv,
Maximum,
Minimum,
Mul, Mul,
ScalarMaximum,
ScalarMinimum,
Sub, Sub,
TrueDiv, TrueDiv,
get_scalar_type, get_scalar_type,
scalar_maximum, maximum,
) )
from pytensor.scalar.basic import add as add_as from pytensor.scalar.basic import add as add_as
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
...@@ -104,16 +104,16 @@ def scalar_in_place_fn_IntDiv(op, idx, res, arr): ...@@ -104,16 +104,16 @@ def scalar_in_place_fn_IntDiv(op, idx, res, arr):
return f"{res}[{idx}] //= {arr}" return f"{res}[{idx}] //= {arr}"
@scalar_in_place_fn.register(ScalarMaximum) @scalar_in_place_fn.register(Maximum)
def scalar_in_place_fn_ScalarMaximum(op, idx, res, arr): def scalar_in_place_fn_Maximum(op, idx, res, arr):
return f""" return f"""
if {res}[{idx}] < {arr}: if {res}[{idx}] < {arr}:
{res}[{idx}] = {arr} {res}[{idx}] = {arr}
""" """
@scalar_in_place_fn.register(ScalarMinimum) @scalar_in_place_fn.register(Minimum)
def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr): def scalar_in_place_fn_Minimum(op, idx, res, arr):
return f""" return f"""
if {res}[{idx}] > {arr}: if {res}[{idx}] > {arr}:
{res}[{idx}] = {arr} {res}[{idx}] = {arr}
...@@ -459,7 +459,7 @@ def numba_funcify_Softmax(op, node, **kwargs): ...@@ -459,7 +459,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
if axis is not None: if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim) axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_multiaxis_reducer( reduce_max_py = create_multiaxis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
) )
reduce_sum_py = create_multiaxis_reducer( reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
...@@ -523,7 +523,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): ...@@ -523,7 +523,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
if axis is not None: if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim) axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_multiaxis_reducer( reduce_max_py = create_multiaxis_reducer(
scalar_maximum, maximum,
-np.inf, -np.inf,
(axis,), (axis,),
x_at.ndim, x_at.ndim,
......
...@@ -1855,7 +1855,7 @@ invert = Invert() ...@@ -1855,7 +1855,7 @@ invert = Invert()
############## ##############
# Arithmetic # Arithmetic
############## ##############
class ScalarMaximum(BinaryScalarOp): class Maximum(BinaryScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ("maximum", 2, 1) nfunc_spec = ("maximum", 2, 1)
...@@ -1895,10 +1895,14 @@ class ScalarMaximum(BinaryScalarOp): ...@@ -1895,10 +1895,14 @@ class ScalarMaximum(BinaryScalarOp):
return (gx, gy) return (gx, gy)
scalar_maximum = ScalarMaximum(upcast_out, name="maximum") maximum = Maximum(upcast_out)
# Backward compatibility
ScalarMaximum = Maximum
scalar_maximum = maximum
class ScalarMinimum(BinaryScalarOp):
class Minimum(BinaryScalarOp):
commutative = True commutative = True
associative = True associative = True
nfunc_spec = ("minimum", 2, 1) nfunc_spec = ("minimum", 2, 1)
...@@ -1937,7 +1941,11 @@ class ScalarMinimum(BinaryScalarOp): ...@@ -1937,7 +1941,11 @@ class ScalarMinimum(BinaryScalarOp):
return (gx, gy) return (gx, gy)
scalar_minimum = ScalarMinimum(upcast_out, name="minimum") minimum = Minimum(upcast_out)
# Backward compatibility
ScalarMinimum = Minimum
scalar_minimum = minimum
class Add(ScalarOp): class Add(ScalarOp):
......
...@@ -32,8 +32,8 @@ from pytensor.scalar.basic import ( ...@@ -32,8 +32,8 @@ from pytensor.scalar.basic import (
isinf, isinf,
log, log,
log1p, log1p,
maximum,
reciprocal, reciprocal,
scalar_maximum,
sqrt, sqrt,
switch, switch,
true_div, true_div,
...@@ -1315,7 +1315,7 @@ class Softplus(UnaryScalarOp): ...@@ -1315,7 +1315,7 @@ class Softplus(UnaryScalarOp):
return v return v
softplus = Softplus(upgrade_to_float, name="scalar_softplus") softplus = Softplus(upgrade_to_float)
class Log1mexp(UnaryScalarOp): class Log1mexp(UnaryScalarOp):
...@@ -1360,7 +1360,7 @@ class Log1mexp(UnaryScalarOp): ...@@ -1360,7 +1360,7 @@ class Log1mexp(UnaryScalarOp):
raise NotImplementedError("only floating point is implemented") raise NotImplementedError("only floating point is implemented")
log1mexp = Log1mexp(upgrade_to_float, name="scalar_log1mexp") log1mexp = Log1mexp(upgrade_to_float)
class BetaInc(ScalarOp): class BetaInc(ScalarOp):
...@@ -1585,9 +1585,7 @@ def betainc_grad(p, q, x, wrtp: bool): ...@@ -1585,9 +1585,7 @@ def betainc_grad(p, q, x, wrtp: bool):
derivative_new = K * (F1 * dK + F2) derivative_new = K * (F1 * dK + F2)
errapx = scalar_abs(derivative - derivative_new) errapx = scalar_abs(derivative - derivative_new)
d_errapx = errapx / scalar_maximum( d_errapx = errapx / maximum(err_threshold, scalar_abs(derivative_new))
err_threshold, scalar_abs(derivative_new)
)
min_iters_cond = n > (min_iters - 1) min_iters_cond = n > (min_iters - 1)
derivative = switch( derivative = switch(
...@@ -1833,7 +1831,7 @@ def _grad_2f1_loop(a, b, c, z, *, skip_loop, wrt, dtype): ...@@ -1833,7 +1831,7 @@ def _grad_2f1_loop(a, b, c, z, *, skip_loop, wrt, dtype):
if len(grad_incs) == 1: if len(grad_incs) == 1:
[max_abs_grad_inc] = grad_incs [max_abs_grad_inc] = grad_incs
else: else:
max_abs_grad_inc = reduce(scalar_maximum, abs_grad_incs) max_abs_grad_inc = reduce(maximum, abs_grad_incs)
return ( return (
(*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k), (*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k),
......
...@@ -948,8 +948,8 @@ class Gemm(GemmRelated): ...@@ -948,8 +948,8 @@ class Gemm(GemmRelated):
z_shape, _, x_shape, y_shape, _ = input_shapes z_shape, _, x_shape, y_shape, _ = input_shapes
return [ return [
( (
pytensor.scalar.scalar_maximum(z_shape[0], x_shape[0]), pytensor.scalar.maximum(z_shape[0], x_shape[0]),
pytensor.scalar.scalar_maximum(z_shape[1], y_shape[1]), pytensor.scalar.maximum(z_shape[1], y_shape[1]),
) )
] ]
......
...@@ -357,12 +357,12 @@ fill_inplace = second_inplace ...@@ -357,12 +357,12 @@ fill_inplace = second_inplace
pprint.assign(fill_inplace, printing.FunctionPrinter(["fill="])) pprint.assign(fill_inplace, printing.FunctionPrinter(["fill="]))
@scalar_elemwise(symbolname="scalar_maximum_inplace") @scalar_elemwise
def maximum_inplace(a, b): def maximum_inplace(a, b):
"""elementwise addition (inplace on `a`)""" """elementwise addition (inplace on `a`)"""
@scalar_elemwise(symbolname="scalar_minimum_inplace") @scalar_elemwise
def minimum_inplace(a, b): def minimum_inplace(a, b):
"""elementwise addition (inplace on `a`)""" """elementwise addition (inplace on `a`)"""
......
...@@ -399,7 +399,7 @@ class Max(NonZeroDimsCAReduce): ...@@ -399,7 +399,7 @@ class Max(NonZeroDimsCAReduce):
nfunc_spec = ("max", 1, 1) nfunc_spec = ("max", 1, 1)
def __init__(self, axis): def __init__(self, axis):
super().__init__(ps.scalar_maximum, axis) super().__init__(ps.maximum, axis)
def clone(self, **kwargs): def clone(self, **kwargs):
axis = kwargs.get("axis", self.axis) axis = kwargs.get("axis", self.axis)
...@@ -457,7 +457,7 @@ class Min(NonZeroDimsCAReduce): ...@@ -457,7 +457,7 @@ class Min(NonZeroDimsCAReduce):
nfunc_spec = ("min", 1, 1) nfunc_spec = ("min", 1, 1)
def __init__(self, axis): def __init__(self, axis):
super().__init__(ps.scalar_minimum, axis) super().__init__(ps.minimum, axis)
def clone(self, **kwargs): def clone(self, **kwargs):
axis = kwargs.get("axis", self.axis) axis = kwargs.get("axis", self.axis)
...@@ -2755,7 +2755,7 @@ def median(x: TensorLike, axis=None) -> TensorVariable: ...@@ -2755,7 +2755,7 @@ def median(x: TensorLike, axis=None) -> TensorVariable:
return ifelse(even_k, even_median, odd_median, name="median") return ifelse(even_k, even_median, odd_median, name="median")
@scalar_elemwise(symbolname="scalar_maximum") @scalar_elemwise
def maximum(x, y): def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor """elemwise maximum. See max for the maximum in one tensor
...@@ -2791,7 +2791,7 @@ def maximum(x, y): ...@@ -2791,7 +2791,7 @@ def maximum(x, y):
# see decorator for function body # see decorator for function body
@scalar_elemwise(symbolname="scalar_minimum") @scalar_elemwise
def minimum(x, y): def minimum(x, y):
"""elemwise minimum. See min for the minimum in one tensor """elemwise minimum. See min for the minimum in one tensor
......
...@@ -60,7 +60,7 @@ def local_max_to_min(fgraph, node): ...@@ -60,7 +60,7 @@ def local_max_to_min(fgraph, node):
if ( if (
max.owner max.owner
and isinstance(max.owner.op, CAReduce) and isinstance(max.owner.op, CAReduce)
and max.owner.op.scalar_op == ps.scalar_maximum and max.owner.op.scalar_op == ps.maximum
): ):
neg_node = max.owner.inputs[0] neg_node = max.owner.inputs[0]
if neg_node.owner and neg_node.owner.op == neg: if neg_node.owner and neg_node.owner.op == neg:
......
...@@ -31,7 +31,7 @@ class XlogX(ps.UnaryScalarOp): ...@@ -31,7 +31,7 @@ class XlogX(ps.UnaryScalarOp):
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
scalar_xlogx = XlogX(ps.upgrade_to_float, name="scalar_xlogx") scalar_xlogx = XlogX(ps.upgrade_to_float)
xlogx = Elemwise(scalar_xlogx, name="xlogx") xlogx = Elemwise(scalar_xlogx, name="xlogx")
...@@ -62,5 +62,5 @@ class XlogY0(ps.BinaryScalarOp): ...@@ -62,5 +62,5 @@ class XlogY0(ps.BinaryScalarOp):
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
scalar_xlogy0 = XlogY0(ps.upgrade_to_float, name="scalar_xlogy0") scalar_xlogy0 = XlogY0(ps.upgrade_to_float)
xlogy0 = Elemwise(scalar_xlogy0, name="xlogy0") xlogy0 = Elemwise(scalar_xlogy0, name="xlogy0")
...@@ -388,11 +388,11 @@ def reciprocal(): ... ...@@ -388,11 +388,11 @@ def reciprocal(): ...
def round(): ... def round(): ...
@_as_xelemwise(ps.scalar_maximum) @_as_xelemwise(ps.maximum)
def maximum(): ... def maximum(): ...
@_as_xelemwise(ps.scalar_minimum) @_as_xelemwise(ps.minimum)
def minimum(): ... def minimum(): ...
......
...@@ -63,8 +63,8 @@ def reduce(x, dim: REDUCE_DIM = None, *, binary_op): ...@@ -63,8 +63,8 @@ def reduce(x, dim: REDUCE_DIM = None, *, binary_op):
sum = partial(reduce, binary_op=ps.add) sum = partial(reduce, binary_op=ps.add)
prod = partial(reduce, binary_op=ps.mul) prod = partial(reduce, binary_op=ps.mul)
max = partial(reduce, binary_op=ps.scalar_maximum) max = partial(reduce, binary_op=ps.maximum)
min = partial(reduce, binary_op=ps.scalar_minimum) min = partial(reduce, binary_op=ps.minimum)
def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op): def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op):
......
...@@ -30,9 +30,9 @@ def lower_reduce(fgraph, node): ...@@ -30,9 +30,9 @@ def lower_reduce(fgraph, node):
tensor_op_class = All tensor_op_class = All
case ps.or_: case ps.or_:
tensor_op_class = Any tensor_op_class = Any
case ps.scalar_maximum: case ps.maximum:
tensor_op_class = Max tensor_op_class = Max
case ps.scalar_minimum: case ps.minimum:
tensor_op_class = Min tensor_op_class = Min
case _: case _:
# Case without known/predefined Ops # Case without known/predefined Ops
......
...@@ -3887,10 +3887,7 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None): ...@@ -3887,10 +3887,7 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
fgraph = f.maker.fgraph.toposort() fgraph = f.maker.fgraph.toposort()
for node in fgraph: for node in fgraph:
if ( if hasattr(node.op, "scalar_op") and node.op.scalar_op == ps.basic.maximum:
hasattr(node.op, "scalar_op")
and node.op.scalar_op == ps.basic.scalar_maximum
):
return return
# In mode FAST_COMPILE, the rewrites don't replace the # In mode FAST_COMPILE, the rewrites don't replace the
......
...@@ -544,14 +544,14 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -544,14 +544,14 @@ class TestCAReduce(unittest_tools.InferShapeTester):
elif scalar_op == ps.mul: elif scalar_op == ps.mul:
for axis in sorted(tosum, reverse=True): for axis in sorted(tosum, reverse=True):
zv = np.multiply.reduce(zv, axis) zv = np.multiply.reduce(zv, axis)
elif scalar_op == ps.scalar_maximum: elif scalar_op == ps.maximum:
# There is no identity value for the maximum function # There is no identity value for the maximum function
# So we can't support shape of dimensions 0. # So we can't support shape of dimensions 0.
if np.prod(zv.shape) == 0: if np.prod(zv.shape) == 0:
continue continue
for axis in sorted(tosum, reverse=True): for axis in sorted(tosum, reverse=True):
zv = np.maximum.reduce(zv, axis) zv = np.maximum.reduce(zv, axis)
elif scalar_op == ps.scalar_minimum: elif scalar_op == ps.minimum:
# There is no identity value for the minimum function # There is no identity value for the minimum function
# So we can't support shape of dimensions 0. # So we can't support shape of dimensions 0.
if np.prod(zv.shape) == 0: if np.prod(zv.shape) == 0:
...@@ -594,7 +594,7 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -594,7 +594,7 @@ class TestCAReduce(unittest_tools.InferShapeTester):
tosum = list(range(len(xsh))) tosum = list(range(len(xsh)))
f = pytensor.function([x], e.shape, mode=mode, on_unused_input="ignore") f = pytensor.function([x], e.shape, mode=mode, on_unused_input="ignore")
if not ( if not (
scalar_op in [ps.scalar_maximum, ps.scalar_minimum] scalar_op in [ps.maximum, ps.minimum]
and (xsh == () or np.prod(xsh) == 0) and (xsh == () or np.prod(xsh) == 0)
): ):
assert all(f(xv) == zv.shape) assert all(f(xv) == zv.shape)
...@@ -606,8 +606,8 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -606,8 +606,8 @@ class TestCAReduce(unittest_tools.InferShapeTester):
for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]: for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_mode(Mode(linker="py"), ps.add, dtype=dtype) self.with_mode(Mode(linker="py"), ps.add, dtype=dtype)
self.with_mode(Mode(linker="py"), ps.mul, dtype=dtype) self.with_mode(Mode(linker="py"), ps.mul, dtype=dtype)
self.with_mode(Mode(linker="py"), ps.scalar_maximum, dtype=dtype) self.with_mode(Mode(linker="py"), ps.maximum, dtype=dtype)
self.with_mode(Mode(linker="py"), ps.scalar_minimum, dtype=dtype) self.with_mode(Mode(linker="py"), ps.minimum, dtype=dtype)
self.with_mode(Mode(linker="py"), ps.and_, dtype=dtype, tensor_op=pt_all) self.with_mode(Mode(linker="py"), ps.and_, dtype=dtype, tensor_op=pt_all)
self.with_mode(Mode(linker="py"), ps.or_, dtype=dtype, tensor_op=pt_any) self.with_mode(Mode(linker="py"), ps.or_, dtype=dtype, tensor_op=pt_any)
for dtype in ["int8", "uint8"]: for dtype in ["int8", "uint8"]:
...@@ -619,12 +619,8 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -619,12 +619,8 @@ class TestCAReduce(unittest_tools.InferShapeTester):
for dtype in ["floatX", "complex64", "complex128"]: for dtype in ["floatX", "complex64", "complex128"]:
self.with_mode(Mode(linker="py"), ps.add, dtype=dtype, test_nan=True) self.with_mode(Mode(linker="py"), ps.add, dtype=dtype, test_nan=True)
self.with_mode(Mode(linker="py"), ps.mul, dtype=dtype, test_nan=True) self.with_mode(Mode(linker="py"), ps.mul, dtype=dtype, test_nan=True)
self.with_mode( self.with_mode(Mode(linker="py"), ps.maximum, dtype=dtype, test_nan=True)
Mode(linker="py"), ps.scalar_maximum, dtype=dtype, test_nan=True self.with_mode(Mode(linker="py"), ps.minimum, dtype=dtype, test_nan=True)
)
self.with_mode(
Mode(linker="py"), ps.scalar_minimum, dtype=dtype, test_nan=True
)
self.with_mode( self.with_mode(
Mode(linker="py"), Mode(linker="py"),
ps.or_, ps.or_,
...@@ -659,8 +655,8 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -659,8 +655,8 @@ class TestCAReduce(unittest_tools.InferShapeTester):
self.with_mode(Mode(linker="c"), ps.add, dtype=dtype) self.with_mode(Mode(linker="c"), ps.add, dtype=dtype)
self.with_mode(Mode(linker="c"), ps.mul, dtype=dtype) self.with_mode(Mode(linker="c"), ps.mul, dtype=dtype)
for dtype in ["bool", "floatX", "int8", "uint8"]: for dtype in ["bool", "floatX", "int8", "uint8"]:
self.with_mode(Mode(linker="c"), ps.scalar_minimum, dtype=dtype) self.with_mode(Mode(linker="c"), ps.minimum, dtype=dtype)
self.with_mode(Mode(linker="c"), ps.scalar_maximum, dtype=dtype) self.with_mode(Mode(linker="c"), ps.maximum, dtype=dtype)
self.with_mode(Mode(linker="c"), ps.and_, dtype=dtype, tensor_op=pt_all) self.with_mode(Mode(linker="c"), ps.and_, dtype=dtype, tensor_op=pt_all)
self.with_mode(Mode(linker="c"), ps.or_, dtype=dtype, tensor_op=pt_any) self.with_mode(Mode(linker="c"), ps.or_, dtype=dtype, tensor_op=pt_any)
for dtype in ["bool", "int8", "uint8"]: for dtype in ["bool", "int8", "uint8"]:
...@@ -678,12 +674,8 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -678,12 +674,8 @@ class TestCAReduce(unittest_tools.InferShapeTester):
self.with_mode(Mode(linker="c"), ps.add, dtype=dtype, test_nan=True) self.with_mode(Mode(linker="c"), ps.add, dtype=dtype, test_nan=True)
self.with_mode(Mode(linker="c"), ps.mul, dtype=dtype, test_nan=True) self.with_mode(Mode(linker="c"), ps.mul, dtype=dtype, test_nan=True)
for dtype in ["floatX"]: for dtype in ["floatX"]:
self.with_mode( self.with_mode(Mode(linker="c"), ps.minimum, dtype=dtype, test_nan=True)
Mode(linker="c"), ps.scalar_minimum, dtype=dtype, test_nan=True self.with_mode(Mode(linker="c"), ps.maximum, dtype=dtype, test_nan=True)
)
self.with_mode(
Mode(linker="c"), ps.scalar_maximum, dtype=dtype, test_nan=True
)
def test_infer_shape(self, dtype=None, pre_scalar_op=None): def test_infer_shape(self, dtype=None, pre_scalar_op=None):
if dtype is None: if dtype is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论