Unverified commit e6e6d69f, authored by Dhruvanshu-Joshi, committed by GitHub

Break MaxAndArgmax Op into separate TensorMax Op and Argmax Op (#731)

* Break MaxAndArgmax into TensorMax and Argmax separately
* XFAIL pytensor tests for uint64 data type
* Deprecate and raise AttributeError for MaxAndArgmax
Parent: dbe0e09a
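The user-facing helpers are unchanged by this split; only the graph nodes differ. A minimal sketch (variable names are illustrative; `pt.max_and_argmax`, `pt.max`, and `pt.argmax` are the existing public helpers):

```python
# Minimal migration sketch: the public API still returns both outputs,
# but the graph now holds separate Max and Argmax nodes rather than a
# single MaxAndArgmax node. Variable names are illustrative.
import pytensor.tensor as pt

x = pt.dmatrix("x")
max_x, argmax_x = pt.max_and_argmax(x, axis=0)  # unchanged user-facing call
max_only = pt.max(x, axis=0)                    # builds a Max node
argmax_only = pt.argmax(x, axis=0)              # builds an Argmax node
```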
@@ -477,7 +477,8 @@ acceptable_ops = (
     Reshape,
     Unbroadcast,
     pt.math.Dot,
-    pt.math.MaxAndArgmax,
+    pt.math.Max,
+    pt.math.Argmax,
     pt.subtensor.Subtensor,
     pt.subtensor.IncSubtensor,
     pt.basic.Alloc,
@@ -2,7 +2,7 @@ import jax.numpy as jnp
 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Dot, MaxAndArgmax
+from pytensor.tensor.math import Argmax, Dot, Max
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
@@ -104,18 +104,28 @@ def jax_funcify_BatchedDot(op, **kwargs):
     return batched_dot


-@jax_funcify.register(MaxAndArgmax)
-def jax_funcify_MaxAndArgmax(op, **kwargs):
+@jax_funcify.register(Max)
+def jax_funcify_Max(op, **kwargs):
     axis = op.axis

-    def maxandargmax(x, axis=axis):
+    def max(x):
+        max_res = jnp.max(x, axis)
+        return max_res
+
+    return max
+
+
+@jax_funcify.register(Argmax)
+def jax_funcify_Argmax(op, **kwargs):
+    axis = op.axis
+
+    def argmax(x):
         if axis is None:
             axes = tuple(range(x.ndim))
         else:
             axes = tuple(int(ax) for ax in axis)

-        max_res = jnp.max(x, axis)
-
         # NumPy does not support multiple axes for argmax; this is a
         # work-around
         keep_axes = jnp.array(
@@ -138,6 +148,6 @@ def jax_funcify_MaxAndArgmax(op, **kwargs):
         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")

-        return max_res, max_idx_res
+        return max_idx_res

-    return maxandargmax
+    return argmax
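`jnp.argmax`, like NumPy's, reduces over a single axis only, so the Argmax dispatcher keeps the transpose-and-reshape work-around. A standalone sketch of the same trick (the helper name is hypothetical):

```python
# Sketch of the multi-axis argmax work-around: move the kept axes to the
# front, flatten the reduced axes into one, then argmax over the last
# axis. The helper name is hypothetical.
import jax.numpy as jnp

def argmax_multiaxis(x, axes):
    axes = tuple(int(ax) for ax in axes)
    keep_axes = tuple(i for i in range(x.ndim) if i not in axes)
    transposed = jnp.transpose(x, keep_axes + axes)
    new_shape = (*transposed.shape[: len(keep_axes)], -1)  # collapse reduced axes
    return jnp.argmax(transposed.reshape(new_shape), axis=-1).astype("int64")

x = jnp.arange(24.0).reshape(2, 3, 4)
print(argmax_multiaxis(x, (1, 2)))  # [11 11]: flat index within each 3x4 slice
```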
@@ -44,7 +44,7 @@ from pytensor.scalar.basic import (
 )
 from pytensor.scalar.basic import add as add_as
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
+from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from pytensor.tensor.type import scalar
@@ -827,8 +827,8 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
     return log_softmax


-@numba_funcify.register(MaxAndArgmax)
-def numba_funcify_MaxAndArgmax(op, node, **kwargs):
+@numba_funcify.register(Argmax)
+def numba_funcify_Argmax(op, node, **kwargs):
     axis = op.axis
     x_at = node.inputs[0]
     x_dtype = x_at.type.numpy_dtype
@@ -838,8 +838,8 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
     if x_ndim == 0:

         @numba_basic.numba_njit(inline="always")
-        def maxandargmax(x):
-            return x, 0
+        def argmax(x):
+            return 0

     else:
         axes = tuple(int(ax) for ax in axis)
@@ -848,20 +848,6 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
         # work-around
         keep_axes = tuple(i for i in range(x_ndim) if i not in axes)

-        reduce_max_py_fn = create_multiaxis_reducer(
-            scalar_maximum,
-            -np.inf,
-            axes,
-            x_ndim,
-            x_dtype,
-            return_scalar=False,
-        )
-        reduce_max = jit_compile_reducer(
-            Apply(node.op, node.inputs, [node.outputs[0].clone()]),
-            reduce_max_py_fn,
-            reduce_to_scalar=False,
-        )
-
         reduced_x_ndim = x_ndim - len(axes) + 1
         argmax_axis = create_axis_apply_fn(
             np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
@@ -872,9 +858,7 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
         sl2 = slice(len(keep_axes), None)

         @numba_basic.numba_njit
-        def maxandargmax(x):
-            max_res = reduce_max(x)
-
+        def argmax(x):
             # Not-reduced axes in front
             transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
             kept_shape = transposed_x.shape[sl1]
@@ -890,6 +874,6 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
         max_idx_res = argmax_axis(reshaped_x)

-        return max_res, max_idx_res
+        return max_idx_res

-    return maxandargmax
+    return argmax
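With the max value now produced by the separate `Max` op, the Numba kernel only has to produce indices: it applies the same transpose-and-flatten strategy and then takes a single-axis argmax inside `njit`. A minimal sketch of that last step (names are illustrative):

```python
# Minimal njit sketch of the final step above: argmax along the flattened
# trailing axis, row by row. Names are illustrative only.
import numba
import numpy as np

@numba.njit
def argmax_last_axis(x2d):
    out = np.empty(x2d.shape[0], dtype=np.int64)
    for i in range(x2d.shape[0]):
        out[i] = np.argmax(x2d[i])  # 1-D argmax is supported in nopython mode
    return out

x = np.random.rand(3, 20)  # e.g. a (3, 4 * 5) flattened view
print(argmax_last_axis(x))
```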
@@ -14,7 +14,6 @@ from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
-from pytensor.link.c.type import Generic
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.printing import pprint
 from pytensor.raise_op import Assert
@@ -29,6 +28,7 @@ from pytensor.tensor.basic import (
     constant,
     stack,
     switch,
+    zeros_like,
 )
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.elemwise import (
@@ -107,6 +107,14 @@ else:
     float64_atol = 1e-8


+def __getattr__(name):
+    if name == "MaxAndArgmax":
+        raise AttributeError(
+            "The class `MaxAndArgmax` has been deprecated. Call `Max` and `Argmax` separately as an alternative."
+        )
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
 def _get_atol_rtol(a, b):
     tiny = ("float16",)
     narrow = ("float32", "complex64")
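The module-level `__getattr__` (PEP 562) turns any leftover reference to the removed class into an immediate, descriptive error. A quick check of the behavior:

```python
# Accessing the removed name now fails loudly at attribute lookup time.
import pytensor.tensor.math as ptm

try:
    ptm.MaxAndArgmax
except AttributeError as err:
    print(err)  # The class `MaxAndArgmax` has been deprecated. ...
```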
@@ -134,215 +142,6 @@ def _allclose(a, b, rtol=None, atol=None):
     return np.allclose(a, b, atol=atol_, rtol=rtol_)


-class MaxAndArgmax(COp):
-    """
-    Calculate the max and argmax over a given axis or over all axes.
-
-    """
-
-    nin = 2  # tensor, axis
-    nout = 2  # max val, max idx
-    E_axis = "invalid axis"
-    params_type = Generic()
-    __props__ = ("axis",)
-    _f16_ok = True
-
-    def __init__(self, axis):
-        assert isinstance(axis, tuple | list)
-        self.axis = tuple(axis)
-
-    def get_params(self, node):
-        return self.axis
-
-    def make_node(self, x):
-        x = as_tensor_variable(x)
-
-        # Keep the original shapes for axes on which we do not perform the max/argmax.
-        all_axes = set(self.axis)
-        inputs = [x]
-        out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
-        outputs = [
-            tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
-            tensor(dtype="int64", shape=out_shape, name="argmax"),
-        ]
-        return Apply(self, inputs, outputs)
-
-    def perform(self, node, inp, outs):
-        x = inp[0]
-        axes = self.axis
-        max, max_idx = outs
-        if axes is None:
-            axes = tuple(range(x.ndim))
-        else:
-            axes = tuple(int(ax) for ax in axes)
-        max[0] = _asarray(np.max(x, axes), dtype=node.outputs[0].dtype)
-        # Numpy does not support multiple axes for argmax
-        # Work around
-        keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
-        # Not-reduced axes in front
-        transposed_x = np.transpose(x, np.concatenate((keep_axes, axes)))
-        kept_shape = transposed_x.shape[: len(keep_axes)]
-        reduced_shape = transposed_x.shape[len(keep_axes) :]
-
-        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
-        # Otherwise reshape would complain citing float arg
-        new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
-        reshaped_x = transposed_x.reshape(new_shape)
-
-        max_idx[0] = _asarray(np.argmax(reshaped_x, axis=-1), dtype="int64")
-
-    def c_code(self, node, name, inp, out, sub):
-        if len(self.axis) != 1 and len(self.axis) != node.inputs[0].ndim:
-            raise NotImplementedError(
-                "NumPy C-API can compute max and argmax only for 1 axis or for all axes."
-            )
-        x = inp[0]
-        axis = sub["params"]
-        max, argmax = out
-        fail = sub["fail"]
-        ret = """
-        #if PY_MAJOR_VERSION >= 3
-            #ifndef PyInt_AS_LONG
-                #define PyInt_AS_LONG PyLong_AS_LONG
-            #endif
-        #endif
-
-        int axis;
-
-        if (PyTuple_GET_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)) {
-            axis = NPY_MAXDIMS;
-        } else if(PyTuple_GET_SIZE(%(axis)s) == 1) {
-            PyObject* axis_object = PyTuple_GET_ITEM(%(axis)s, 0);
-            axis = (int)PyInt_AS_LONG(axis_object);
-            if (axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)) {
-                PyErr_SetString(PyExc_ValueError,
-                "MaxAndArgmax: bad axis argument");
-                %(fail)s
-            }
-        } else {
-            PyErr_SetString(PyExc_NotImplementedError,
-            "MaxAndArgmax: NumPy C-API can compute max and argmax only for 1 axis or for all axes.");
-            %(fail)s
-        }
-
-        Py_CLEAR(%(max)s);
-        Py_CLEAR(%(argmax)s);//todo pass them as out parameter.
-
-        %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
-        if (%(max)s == NULL) {
-            %(fail)s;
-        }
-        if (!PyArray_CheckExact(%(max)s)) {
-            %(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
-            if(%(max)s == NULL){
-                %(fail)s;
-            }
-        }
-
-        %(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
-        if (%(argmax)s == NULL) {
-            Py_CLEAR(%(max)s);
-            %(fail)s;
-        }
-        if (!PyArray_CheckExact(%(argmax)s)) {
-            %(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
-            if(%(argmax)s == NULL){
-                %(fail)s;
-            }
-        }
-        if (PyArray_TYPE(%(argmax)s) != NPY_INT64) {
-            PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64);
-            if (NULL == tmp){
-                %(fail)s;
-            }
-            Py_DECREF(%(argmax)s);
-            %(argmax)s = (PyArrayObject*)tmp;
-        }
-        """
-        return ret % locals()
-
-    def c_code_cache_version(self):
-        return (5,)
-
-    def infer_shape(self, fgraph, node, shapes):
-        ishape = shapes[0]
-        rval = tuple(
-            ishape[i]
-            for (i, b) in enumerate(node.inputs[0].type.broadcastable)
-            if i not in self.axis
-        )
-        return [rval, rval]
-
-    def R_op(self, inputs, eval_points):
-        if eval_points[0] is None:
-            return [None, None]
-        if len(self.axis) != 1:
-            raise ValueError("R_op supported for arg_max only for one axis!")
-        if self.axis[0] > 1:
-            raise ValueError("R_op supported for arg_max only when axis is 0 or 1")
-        if inputs[0].ndim != 2:
-            raise ValueError("R_op supported for arg_max only when input is a matrix")
-        max_vals, max_pos = self.make_node(*inputs).outputs
-        if self.axis[0] == 0:
-            return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None]
-        else:
-            return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None]
-
-    def grad(self, inp, grads):
-        # The strict sense mathematical gradient of the maximum function is
-        # not calculated here for it is not defined at every point where some
-        # coordinates are identical. However, since the latter set has null
-        # Lebesgue measure, the result may be interpreted as weak gradient.
-
-        # @note: This function should work correctly for L{vector}s.
-        # (x, y), (gz, gw)
-        # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
-        # gMax * dMax/dx + gArgMax * dArgMax/dx,
-        # gMax * dMax/daxis + gArgMax * dArgMax/daxis
-        # g_max has one less dimension than x, so you need to complete
-        # g_max to x's shape when axis=0 the broadcasting mechanism
-        # does it automatically
-        x = inp[0]
-        axis = as_tensor_variable(self.axis)
-        g_max, g_max_idx = grads
-
-        g_max_disconnected = isinstance(g_max.type, DisconnectedType)
-        g_max_idx_disconnected = isinstance(g_max_idx.type, DisconnectedType)
-
-        # if the op is totally disconnected, so are its inputs
-        if g_max_disconnected and g_max_idx_disconnected:
-            return [DisconnectedType()(), DisconnectedType()()]
-
-        # if the max is disconnected but the argmax is not,
-        # the gradient on its inputs is zero
-        if g_max_disconnected:
-            return [x.zeros_like()]
-        if NoneConst.equals(axis):
-            axis_ = list(range(x.ndim))
-        else:
-            axis_ = axis
-        xmax = max(x, axis_)
-
-        # Raise the g_max and xmax to the same number of dim as the input.
-        pattern = []
-        out_dim = 0
-        if NoneConst.equals(axis):
-            # We are taking the max/argmax over all dimensions.
-            axis = None
-        for i in range(x.ndim):
-            if axis is None or i in axis.data:
-                pattern.append("x")
-            else:
-                pattern.append(out_dim)
-                out_dim += 1
-        g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
-        xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
-
-        # Set the grad to the correct position.
-        g_x = eq(xmax_pad, x) * g_max_pad
-        return (g_x,)
-
-
 class Argmax(COp):
     """
     Calculate the argmax over a given axis or over all axes.
@@ -359,7 +158,7 @@ class Argmax(COp):
     def __init__(self, axis):
         if axis is not None:
             axis = tuple(axis)
-        self.axis = tuple(axis)
+        self.axis = axis

     def get_params(self, node):
         if self.axis is not None and len(self.axis) == 1:
@@ -395,7 +194,6 @@ class Argmax(COp):
         (max_idx,) = outs
         if axes is None:
             axes = tuple(range(x.ndim))
-
         # Numpy does not support multiple axes for argmax
         # Work around
         keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
@@ -403,7 +201,7 @@ class Argmax(COp):
         transposed_x = np.transpose(x, np.concatenate((keep_axes, axes)))
         kept_shape = transposed_x.shape[: len(keep_axes)]
         reduced_shape = transposed_x.shape[len(keep_axes) :]
-        new_shape = (*kept_shape, np.prod(reduced_shape))
+        new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
         reshaped_x = transposed_x.reshape(new_shape)
         max_idx[0] = _asarray(np.argmax(reshaped_x, axis=-1), dtype="int64")
@@ -470,6 +268,9 @@ class Argmax(COp):
         )
         return [rval]

+    def R_op(self, inputs, eval_points):
+        raise ValueError("Argmax is non-differentiable")
+
     def grad(self, inp, grads):
         (x,) = inp
@@ -477,7 +278,6 @@ class Argmax(COp):
 @_vectorize_node.register(Argmax)
-@_vectorize_node.register(MaxAndArgmax)
 def vectorize_argmax_node(op, node, batch_x):
     core_ndim = node.inputs[0].type.ndim
     batch_ndim = batch_x.type.ndim - core_ndim
@@ -595,12 +395,24 @@ def max_and_argmax(a, axis=None, keepdims=False):
     """
     # Check axis and convert it to a Python list of integers.
-    # Axis will be used as an op param of MaxAndArgmax.
+    # Axis will be used as an op param of Max and Argmax.
     a = as_tensor_variable(a)
+
+    is_axis_empty = False
+    if axis == ():
+        is_axis_empty = True
+
     axis = check_and_normalize_axes(a, axis)
-    if len(axis) == 0:
-        axis = list(range(a.type.ndim))
-    out, argout = MaxAndArgmax(axis)(a)
+
+    if len(axis) == 0 and not is_axis_empty:
+        axis = None
+
+    out = Max(axis)(a)
+
+    if not is_axis_empty:
+        argout = Argmax(axis)(a)
+    else:
+        argout = zeros_like(a, dtype="int64")
+
     if keepdims:
         out = makeKeepDims(a, out, axis)
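The `axis=()` special case mirrors NumPy: reducing over no axes is a no-op for the max, and each singleton slice trivially has argmax 0, hence the `zeros_like` output. A hedged sketch of the resulting behavior:

```python
# Sketch of the axis=() behavior: max returns x unchanged (as NumPy's
# np.max(x, axis=()) does) and argmax is all zeros.
import numpy as np
import pytensor.tensor as pt

x = pt.dmatrix("x")
m, am = pt.max_and_argmax(x, axis=())
x_val = np.array([[1.0, 2.0], [3.0, 4.0]])
print(m.eval({x: x_val}))   # identical to x_val
print(am.eval({x: x_val}))  # all zeros, dtype int64
```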
@@ -654,6 +466,74 @@ class Max(NonZeroDimsCAReduce):
         axis = kwargs.get("axis", self.axis)
         return type(self)(axis=axis)

+    def grad(self, inp, grads):
+        # The strict sense mathematical gradient of the maximum function is
+        # not calculated here for it is not defined at every point where some
+        # coordinates are identical. However, since the latter set has null
+        # Lebesgue measure, the result may be interpreted as weak gradient.
+
+        # @note: This function should work correctly for L{vector}s.
+        # (x, y), (gz, gw)
+        # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
+        # gMax * dMax/dx + gArgMax * dArgMax/dx,
+        # gMax * dMax/daxis + gArgMax * dArgMax/daxis
+        # g_max has one less dimension than x, so you need to complete
+        # g_max to x's shape when axis=0 the broadcasting mechanism
+        # does it automatically
+        x = inp[0]
+        if self.axis is None:
+            self.axis = tuple(range(x.ndim))
+        axis = as_tensor_variable(self.axis)
+        (g_max,) = grads
+
+        g_max_disconnected = isinstance(g_max.type, DisconnectedType)
+
+        # if the op is totally disconnected, so are its inputs
+        if g_max_disconnected:
+            return [DisconnectedType()()]
+
+        if axis is None:
+            axis_ = list(range(x.ndim))
+        else:
+            axis_ = axis
+        xmax = max(x, axis_)
+
+        # Raise the g_max and xmax to the same number of dim as the input.
+        pattern = []
+        out_dim = 0
+        if NoneConst.equals(axis):
+            # We are taking the max/argmax over all dimensions.
+            axis = None
+        for i in range(x.ndim):
+            if axis is None or i in axis.data:
+                pattern.append("x")
+            else:
+                pattern.append(out_dim)
+                out_dim += 1
+        g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
+        xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
+
+        # Set the grad to the correct position.
+        g_x = eq(xmax_pad, x) * g_max_pad
+        return (g_x,)
+
+    def R_op(self, inputs, eval_points):
+        if eval_points[0] is None:
+            return [None, None]
+        if len(self.axis) != 1:
+            raise ValueError("R_op supported for arg_max only for one axis!")
+        if self.axis[0] > 1:
+            raise ValueError("R_op supported for arg_max only when axis is 0 or 1")
+        if inputs[0].ndim != 2:
+            raise ValueError("R_op supported for arg_max only when input is a matrix")
+        max_pos = Argmax(self.axis).make_node(*inputs).outputs
+        if self.axis[0] == 0:
+            return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None]
+        else:
+            return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None]
+
 class Min(NonZeroDimsCAReduce):
     nfunc_spec = ("min", 1, 1)

@@ -685,16 +565,6 @@ def max(x, axis=None, keepdims=False):
     We return an error as numpy when we reduce a dim with a shape of 0.

     """
-    # We have a choice of implementing this call with the
-    # CAReduce op or the MaxAndArgmax op.
-    # MaxAndArgmax supports grad and Rop, so we prefer to use that.
-    # CAReduce is faster, but optimizations will replace MaxAndArgmax[0]
-    # with CAReduce at compile time, so at this stage the important
-    # thing is supporting all user interface features, not speed.
-    # Some cases can be implemented only with CAReduce.
-
     out = max_and_argmax(x, axis)[0]

     if keepdims:
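The `grad` that moved onto `Max` implements the weak-gradient convention spelled out in its comments: the incoming gradient flows only to entries that attain the maximum, via the `eq(xmax_pad, x)` mask. A quick sanity check of that convention:

```python
# Sketch: the gradient of max(x) is an indicator of the max position(s).
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
g = pytensor.grad(pt.max(x), x)
f = pytensor.function([x], g)
print(f(np.array([1.0, 3.0, 2.0])))  # [0. 1. 0.]
```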
@@ -35,31 +35,12 @@ from pytensor import scalar as ps
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
 from pytensor.tensor.basic import Alloc, alloc, constant
 from pytensor.tensor.elemwise import CAReduce, DimShuffle
-from pytensor.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg
+from pytensor.tensor.math import Min, neg
 from pytensor.tensor.rewriting.basic import register_uncanonicalize
 from pytensor.tensor.shape import Reshape, reshape
 from pytensor.tensor.subtensor import Subtensor


-@register_uncanonicalize
-@node_rewriter([MaxAndArgmax])
-def local_max_and_argmax(fgraph, node):
-    """
-    If we don't use the argmax, change it to a max only.
-    """
-    if isinstance(node.op, MaxAndArgmax):
-        axis = node.op.axis
-        if len(fgraph.clients[node.outputs[1]]) == 0:
-            new = Max(axis)(node.inputs[0])
-            copy_stack_trace(node.outputs[0], new)
-            return [new, None]
-
-        if len(fgraph.clients[node.outputs[0]]) == 0:
-            new = Argmax(axis)(node.inputs[0])
-            copy_stack_trace(node.outputs[0], new)
-            return [None, new]
-
-
 @register_uncanonicalize
 @node_rewriter([neg])
 def local_max_to_min(fgraph, node):
@@ -71,7 +52,7 @@ def local_max_to_min(fgraph, node):
     Notes
     -----
     We don't need an opt that will do the reverse as by default
-    the interface put only MaxAndArgmax into the graph.
+    the interface puts only Max into the graph.
     """
     if node.op == neg and node.inputs[0].owner:
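With the ops split, `pt.max` builds a `Max` node up front, so the deleted `local_max_and_argmax` rewrite (which swapped in `Max`/`Argmax` when one output of `MaxAndArgmax` was unused) has nothing left to do. A sketch of the new graph shape, assuming the default user-facing constructors:

```python
# Sketch: the graph now starts out with the single-output ops.
import pytensor.tensor as pt
from pytensor.tensor.math import Argmax, Max

x = pt.dmatrix("x")
assert isinstance(pt.max(x, axis=0).owner.op, Max)
assert isinstance(pt.argmax(x, axis=0).owner.op, Argmax)
```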
@@ -11,7 +11,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.link.jax import JAXLinker
 from pytensor.tensor import blas as pt_blas
 from pytensor.tensor import nlinalg as pt_nlinalg
-from pytensor.tensor.math import MaxAndArgmax, maximum
+from pytensor.tensor.math import Argmax, Max, maximum
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector
 from tests.link.jax.test_basic import compare_jax_and_py
@@ -88,7 +88,8 @@ def test_jax_basic_multiout_omni():
     # Test that a single output of a multi-output `Op` can be used as input to
     # another `Op`
     x = dvector()
-    mx, amx = MaxAndArgmax([0])(x)
+    mx = Max([0])(x)
+    amx = Argmax([0])(x)
     out = mx * amx
     out_fg = FunctionGraph([x], [out])
     compare_jax_and_py(out_fg, [np.r_[1, 2]])
@@ -552,8 +552,53 @@ def test_LogSoftmax(x, axis, exc):
         ),
     ],
 )
-def test_MaxAndArgmax(x, axes, exc):
-    g = ptm.MaxAndArgmax(axes)(x)
+def test_Max(x, axes, exc):
+    g = ptm.Max(axes)(x)
+
+    if isinstance(g, list):
+        g_fg = FunctionGraph(outputs=g)
+    else:
+        g_fg = FunctionGraph(outputs=[g])
+
+    cm = contextlib.suppress() if exc is None else pytest.warns(exc)
+    with cm:
+        compare_numba_and_py(
+            g_fg,
+            [
+                i.tag.test_value
+                for i in g_fg.inputs
+                if not isinstance(i, SharedVariable | Constant)
+            ],
+        )
+
+
+@pytest.mark.parametrize(
+    "x, axes, exc",
+    [
+        (
+            set_test_value(pt.dscalar(), np.array(0.0, dtype="float64")),
+            [],
+            None,
+        ),
+        (
+            set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")),
+            [0],
+            None,
+        ),
+        (
+            set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
+            [0],
+            None,
+        ),
+        (
+            set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
+            [0, 1],
+            None,
+        ),
+    ],
+)
+def test_Argmax(x, axes, exc):
+    g = ptm.Argmax(axes)(x)

     if isinstance(g, list):
         g_fg = FunctionGraph(outputs=g)
@@ -38,7 +38,7 @@ from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.math import (
     Dot,
-    MaxAndArgmax,
+    Max,
     Prod,
     Sum,
     _conj,
@@ -3730,8 +3730,8 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
         return

     # In mode FAST_COMPILE, the rewrites don't replace the
-    # `MaxAndArgmax` `Op`.
-    if isinstance(node.op, MaxAndArgmax):
+    # `Max` `Op`.
+    if isinstance(node.op, Max):
         return

     # TODO FIXME: Refactor this test so that it makes a direct assertion and
@@ -9,8 +9,6 @@ from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import out2in
 from pytensor.link.basic import PerformLinker
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import MaxAndArgmax, max_and_argmax
-from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.math import min as pt_min
 from pytensor.tensor.rewriting.uncanonicalize import (
     local_alloc_dimshuffle,
@@ -23,67 +21,12 @@ from pytensor.tensor.type import dtensor4, iscalar, matrix, tensor, vector
 from tests.link.test_link import make_function


-class TestMaxAndArgmax:
-    def test_optimization(self):
-        # If we use only the max output, we should replace this op with
-        # a faster one.
-        mode = pytensor.compile.mode.get_default_mode().including(
-            "canonicalize", "fast_run"
-        )
-
-        for axis in [0, 1, -1]:
-            n = matrix()
-
-            f = function([n], max_and_argmax(n, axis)[0], mode=mode)
-            topo = f.maker.fgraph.toposort()
-            assert len(topo) == 1
-            assert isinstance(topo[0].op, CAReduce)
-
-            f = function([n], max_and_argmax(n, axis), mode=mode)
-            topo = f.maker.fgraph.toposort()
-            assert len(topo) == 1
-            assert isinstance(topo[0].op, MaxAndArgmax)
-
-
 class TestMinMax:
     def setup_method(self):
         self.mode = pytensor.compile.mode.get_default_mode().including(
             "canonicalize", "fast_run"
         )

-    def test_optimization_max(self):
-        data = np.asarray(np.random.random((2, 3)), dtype=config.floatX)
-        n = matrix()
-
-        for axis in [0, 1, -1]:
-            f = function([n], pt_max(n, axis), mode=self.mode)
-            topo = f.maker.fgraph.toposort()
-            assert len(topo) == 1
-            assert isinstance(topo[0].op, CAReduce)
-            f(data)
-
-            f = function([n], pt_max(-n, axis), mode=self.mode)
-            topo = f.maker.fgraph.toposort()
-            assert len(topo) == 2
-            assert isinstance(topo[0].op, Elemwise)
-            assert isinstance(topo[0].op.scalar_op, ps.Neg)
-            assert isinstance(topo[1].op, CAReduce)
-            f(data)
-
-            f = function([n], -pt_max(n, axis), mode=self.mode)
-            topo = f.maker.fgraph.toposort()
-            assert len(topo) == 2
-            assert isinstance(topo[0].op, CAReduce)
-            assert isinstance(topo[1].op, Elemwise)
-            assert isinstance(topo[1].op.scalar_op, ps.Neg)
-            f(data)
-
-            f = function([n], -pt_max(-n, axis), mode=self.mode)
-            topo = f.maker.fgraph.toposort()
-            assert len(topo) == 1
-            assert isinstance(topo[0].op, CAReduce)  # min
-            f(data)
-
     def test_optimization_min(self):
         data = np.asarray(np.random.random((2, 3)), dtype=config.floatX)
         n = matrix()
@@ -11,6 +11,7 @@ import scipy.special
 from numpy.testing import assert_array_equal
 from scipy.special import logsumexp as scipy_logsumexp

+import pytensor
 import pytensor.scalar as ps
 from pytensor.compile.debugmode import DebugMode
 from pytensor.compile.function import function
@@ -39,7 +40,7 @@ from pytensor.tensor.elemwise import CAReduce, Elemwise
 from pytensor.tensor.math import (
     Argmax,
     Dot,
-    MaxAndArgmax,
+    Max,
     Mean,
     Prod,
     ProdWithoutZeros,
@@ -760,11 +761,12 @@ def test_isnan():
 class TestMaxAndArgmax:
     def setup_method(self):
-        MaxAndArgmax.debug = 0
+        Max.debug = 0
+        Argmax.debug = 0

     def test_basic(self):
-        n = as_tensor_variable(5.0)
-        v, i = eval_outputs(max_and_argmax(n))
+        n = as_tensor_variable(5)
+        v, i = eval_outputs(max_and_argmax(n, axis=()))
         assert v == 5.0
         assert i == 0
         assert i.dtype == "int64"
@@ -1030,31 +1032,45 @@ class TestMaxAndArgmax:
         x = tensor(shape=(5, 5, 5, 5))
         batch_x = tensor(shape=(3, 5, 5, 5, 5))

-        # Test MaxAndArgmax
-        max_x, argmax_x = max_and_argmax(x, axis=core_axis)
-        node = max_x.owner
-        assert isinstance(node.op, MaxAndArgmax)
-
-        new_node = vectorize_node(node, batch_x)
-        assert isinstance(new_node.op, MaxAndArgmax)
-        assert new_node.op.axis == batch_axis
-
-        # Test Argmax
-        # Argmax is not user-facing, so we have to create it manually
-        node = Argmax(axis=node.op.axis).make_node(x)
-
-        new_node = vectorize_node(node, batch_x)
+        argmax_x = argmax(x, axis=core_axis)
+
+        arg_max_node = argmax_x.owner
+        new_node = vectorize_node(arg_max_node, batch_x)
+
         assert isinstance(new_node.op, Argmax)
         assert new_node.op.axis == batch_axis

+    def test_max_empty_axis(self):
+        x = np.random.normal(size=(2, 3, 5, 7))
+        axis = ()
+
+        non_axis = tuple(i for i in range(x.ndim) if i not in axis)
+        shape_axis = tuple(x.shape[dim] for dim in axis)
+        shape_non_axis = tuple(x.shape[dim] for dim in non_axis)
+        x_transposed = x.transpose(*axis, *non_axis)
+        x_axis_raveled = x_transposed.reshape(
+            np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int)
+        )
+
+        max_x = max_and_argmax(x, axis=axis)[0].eval()
+        argmax_x = max_and_argmax(x, axis=axis)[1].eval()
+
+        raveled_max = x_axis_raveled[
+            argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int))
+        ]
+        indirect_max = raveled_max.reshape(shape_non_axis)
+
+        np.testing.assert_allclose(max_x, x.max(axis=axis))
+        np.testing.assert_allclose(indirect_max, x.max(axis=axis))
+

 class TestArgminArgmax:
     def setup_method(self):
-        MaxAndArgmax.debug = 0
+        Argmax.debug = 0

     def test_scalar(self):
         for fct in [argmin, argmax]:
-            n = as_tensor_variable(5.0)
+            n = as_tensor_variable([5.0])
             i = eval_outputs(fct(n))
             assert i == 0
             v = eval_outputs(fct(n).shape)
@@ -1212,7 +1228,7 @@ class TestArgminArgmax:
 class TestMinMax:
     def setup_method(self):
-        MaxAndArgmax.debug = 0
+        Max.debug = 0

     def test_scalar(self):
         for fct in [max, min]:
@@ -1379,6 +1395,7 @@ class TestMinMax:
         # check_grad_max(data, eval_outputs(grad(max_and_argmax(n,
         # axis=1)[0], n)),axis=1)

+    @pytest.mark.xfail(reason="Fails due to #770")
     def test_uint(self):
         for dtype in ("uint8", "uint16", "uint32", "uint64"):
             itype = np.iinfo(dtype)
@@ -1404,6 +1421,14 @@ class TestMinMax:
         assert np.all(i)


+def test_MaxAndArgmax_deprecated():
+    with pytest.raises(
+        AttributeError,
+        match="The class `MaxAndArgmax` has been deprecated. Call `Max` and `Argmax` separately as an alternative.",
+    ):
+        pytensor.tensor.math.MaxAndArgmax
+
+
 rng = np.random.default_rng(seed=utt.fetch_seed())

 TestClip1 = makeTester(
     name="ClipTester",
def test_MaxAndArgmax(self): def test_Max(self):
adtens3 = dtensor3()
adtens3_val = random(4, 5, 3)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, None), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 0), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 1), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 2), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], Max
)
def test_Argmax(self):
adtens3 = dtensor3() adtens3 = dtensor3()
adtens3_val = random(4, 5, 3) adtens3_val = random(4, 5, 3)
self._compile_and_check( self._compile_and_check(
[adtens3], max_and_argmax(adtens3, None), [adtens3_val], MaxAndArgmax [adtens3], max_and_argmax(adtens3, None), [adtens3_val], Argmax
) )
self._compile_and_check( self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 0), [adtens3_val], MaxAndArgmax [adtens3], max_and_argmax(adtens3, 0), [adtens3_val], Argmax
) )
self._compile_and_check( self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 1), [adtens3_val], MaxAndArgmax [adtens3], max_and_argmax(adtens3, 1), [adtens3_val], Argmax
) )
self._compile_and_check( self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 2), [adtens3_val], MaxAndArgmax [adtens3], max_and_argmax(adtens3, 2), [adtens3_val], Argmax
) )
self._compile_and_check( self._compile_and_check(
[adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], MaxAndArgmax [adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], Argmax
) )
def test_Dot(self): def test_Dot(self):
......