Unverified 提交 e6e6d69f authored 作者: Dhruvanshu-Joshi's avatar Dhruvanshu-Joshi 提交者: GitHub

Break MaxandArgmax Op to seperate TensorMax Op and Argmax Op (#731)

* Break MaxandArgmax to TensorMax and Argmax seperately * XFAIL pytensor tests for uint64 data type * Deprecate and raise AttributeError for MaxAndArgmax
上级 dbe0e09a
......@@ -477,7 +477,8 @@ acceptable_ops = (
Reshape,
Unbroadcast,
pt.math.Dot,
pt.math.MaxAndArgmax,
pt.math.Max,
pt.math.Argmax,
pt.subtensor.Subtensor,
pt.subtensor.IncSubtensor,
pt.basic.Alloc,
......
......@@ -2,7 +2,7 @@ import jax.numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot, MaxAndArgmax
from pytensor.tensor.math import Argmax, Dot, Max
from pytensor.tensor.nlinalg import (
SVD,
Det,
......@@ -104,18 +104,28 @@ def jax_funcify_BatchedDot(op, **kwargs):
return batched_dot
@jax_funcify.register(MaxAndArgmax)
def jax_funcify_MaxAndArgmax(op, **kwargs):
@jax_funcify.register(Max)
def jax_funcify_Max(op, **kwargs):
axis = op.axis
def maxandargmax(x, axis=axis):
def max(x):
max_res = jnp.max(x, axis)
return max_res
return max
@jax_funcify.register(Argmax)
def jax_funcify_Argmax(op, **kwargs):
axis = op.axis
def argmax(x):
if axis is None:
axes = tuple(range(x.ndim))
else:
axes = tuple(int(ax) for ax in axis)
max_res = jnp.max(x, axis)
# NumPy does not support multiple axes for argmax; this is a
# work-around
keep_axes = jnp.array(
......@@ -138,6 +148,6 @@ def jax_funcify_MaxAndArgmax(op, **kwargs):
max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
return max_res, max_idx_res
return max_idx_res
return maxandargmax
return argmax
......@@ -44,7 +44,7 @@ from pytensor.scalar.basic import (
)
from pytensor.scalar.basic import add as add_as
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from pytensor.tensor.type import scalar
......@@ -827,8 +827,8 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
return log_softmax
@numba_funcify.register(MaxAndArgmax)
def numba_funcify_MaxAndArgmax(op, node, **kwargs):
@numba_funcify.register(Argmax)
def numba_funcify_Argmax(op, node, **kwargs):
axis = op.axis
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
......@@ -838,8 +838,8 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
if x_ndim == 0:
@numba_basic.numba_njit(inline="always")
def maxandargmax(x):
return x, 0
def argmax(x):
return 0
else:
axes = tuple(int(ax) for ax in axis)
......@@ -848,20 +848,6 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
# work-around
keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max_py_fn = create_multiaxis_reducer(
scalar_maximum,
-np.inf,
axes,
x_ndim,
x_dtype,
return_scalar=False,
)
reduce_max = jit_compile_reducer(
Apply(node.op, node.inputs, [node.outputs[0].clone()]),
reduce_max_py_fn,
reduce_to_scalar=False,
)
reduced_x_ndim = x_ndim - len(axes) + 1
argmax_axis = create_axis_apply_fn(
np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
......@@ -872,9 +858,7 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
sl2 = slice(len(keep_axes), None)
@numba_basic.numba_njit
def maxandargmax(x):
max_res = reduce_max(x)
def argmax(x):
# Not-reduced axes in front
transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
kept_shape = transposed_x.shape[sl1]
......@@ -890,6 +874,6 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
max_idx_res = argmax_axis(reshaped_x)
return max_res, max_idx_res
return max_idx_res
return maxandargmax
return argmax
......@@ -14,7 +14,6 @@ from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import Generic
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import pprint
from pytensor.raise_op import Assert
......@@ -29,6 +28,7 @@ from pytensor.tensor.basic import (
constant,
stack,
switch,
zeros_like,
)
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.elemwise import (
......@@ -107,6 +107,14 @@ else:
float64_atol = 1e-8
def __getattr__(name):
if name == "MaxAndArgmax":
raise AttributeError(
"The class `MaxandArgmax` has been deprecated. Call `Max` and `Argmax` seperately as an alternative."
)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def _get_atol_rtol(a, b):
tiny = ("float16",)
narrow = ("float32", "complex64")
......@@ -134,215 +142,6 @@ def _allclose(a, b, rtol=None, atol=None):
return np.allclose(a, b, atol=atol_, rtol=rtol_)
class MaxAndArgmax(COp):
"""
Calculate the max and argmax over a given axis or over all axes.
"""
nin = 2 # tensor, axis
nout = 2 # max val, max idx
E_axis = "invalid axis"
params_type = Generic()
__props__ = ("axis",)
_f16_ok = True
def __init__(self, axis):
assert isinstance(axis, tuple | list)
self.axis = tuple(axis)
def get_params(self, node):
return self.axis
def make_node(self, x):
x = as_tensor_variable(x)
# Keep the original shapes for axes on which we do not perform the max/argmax.
all_axes = set(self.axis)
inputs = [x]
out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
outputs = [
tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
tensor(dtype="int64", shape=out_shape, name="argmax"),
]
return Apply(self, inputs, outputs)
def perform(self, node, inp, outs):
x = inp[0]
axes = self.axis
max, max_idx = outs
if axes is None:
axes = tuple(range(x.ndim))
else:
axes = tuple(int(ax) for ax in axes)
max[0] = _asarray(np.max(x, axes), dtype=node.outputs[0].dtype)
# Numpy does not support multiple axes for argmax
# Work around
keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
# Not-reduced axes in front
transposed_x = np.transpose(x, np.concatenate((keep_axes, axes)))
kept_shape = transposed_x.shape[: len(keep_axes)]
reduced_shape = transposed_x.shape[len(keep_axes) :]
# Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
# Otherwise reshape would complain citing float arg
new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
reshaped_x = transposed_x.reshape(new_shape)
max_idx[0] = _asarray(np.argmax(reshaped_x, axis=-1), dtype="int64")
def c_code(self, node, name, inp, out, sub):
if len(self.axis) != 1 and len(self.axis) != node.inputs[0].ndim:
raise NotImplementedError(
"NumPy C-API can compute max and argmax only for 1 axis or for all axes."
)
x = inp[0]
axis = sub["params"]
max, argmax = out
fail = sub["fail"]
ret = """
#if PY_MAJOR_VERSION >= 3
#ifndef PyInt_AS_LONG
#define PyInt_AS_LONG PyLong_AS_LONG
#endif
#endif
int axis;
if (PyTuple_GET_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)) {
axis = NPY_MAXDIMS;
} else if(PyTuple_GET_SIZE(%(axis)s) == 1) {
PyObject* axis_object = PyTuple_GET_ITEM(%(axis)s, 0);
axis = (int)PyInt_AS_LONG(axis_object);
if (axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)) {
PyErr_SetString(PyExc_ValueError,
"MaxAndArgmax: bad axis argument");
%(fail)s
}
} else {
PyErr_SetString(PyExc_NotImplementedError,
"MaxAndArgmax: NumPy C-API can compute max and argmax only for 1 axis or for all axes.");
%(fail)s
}
Py_CLEAR(%(max)s);
Py_CLEAR(%(argmax)s);//todo pass them as out parameter.
%(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
if (%(max)s == NULL) {
%(fail)s;
}
if (!PyArray_CheckExact(%(max)s)) {
%(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
if(%(max)s == NULL){
%(fail)s;
}
}
%(argmax)s = (PyArrayObject*)PyArray_ArgMax(%(x)s, axis, NULL);
if (%(argmax)s == NULL) {
Py_CLEAR(%(max)s);
%(fail)s;
}
if (!PyArray_CheckExact(%(argmax)s)) {
%(argmax)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(argmax)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
if(%(argmax)s == NULL){
%(fail)s;
}
}
if (PyArray_TYPE(%(argmax)s) != NPY_INT64) {
PyObject * tmp = PyArray_Cast(%(argmax)s, NPY_INT64);
if (NULL == tmp){
%(fail)s;
}
Py_DECREF(%(argmax)s);
%(argmax)s = (PyArrayObject*)tmp;
}
"""
return ret % locals()
def c_code_cache_version(self):
return (5,)
def infer_shape(self, fgraph, node, shapes):
ishape = shapes[0]
rval = tuple(
ishape[i]
for (i, b) in enumerate(node.inputs[0].type.broadcastable)
if i not in self.axis
)
return [rval, rval]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None, None]
if len(self.axis) != 1:
raise ValueError("R_op supported for arg_max only for one axis!")
if self.axis[0] > 1:
raise ValueError("R_op supported for arg_max only when axis is 0 or 1")
if inputs[0].ndim != 2:
raise ValueError("R_op supported for arg_max only when input is a matrix")
max_vals, max_pos = self.make_node(*inputs).outputs
if self.axis[0] == 0:
return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None]
else:
return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None]
def grad(self, inp, grads):
# The strict sense mathematical gradient of the maximum function is
# not calculated here for it is not defined at every point where some
# coordinates are identical. However, since the latter set has null
# Lebesgue measure, the result may be interpreted as weak gradient.
# @note: This function should work correctly for L{vector}s.
# (x, y), (gz, gw)
# gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
# gMax * dMax/dx + gArgMax * dArgMax/dx,
# gMax * dMax/daxis + gArgMax * dArgMax/daxis
# g_max has one less dimension than x, so you need to complete
# g_max to x's shape when axis=0 the broadcasting mechanism
# does it automatically
x = inp[0]
axis = as_tensor_variable(self.axis)
g_max, g_max_idx = grads
g_max_disconnected = isinstance(g_max.type, DisconnectedType)
g_max_idx_disconnected = isinstance(g_max_idx.type, DisconnectedType)
# if the op is totally disconnected, so are its inputs
if g_max_disconnected and g_max_idx_disconnected:
return [DisconnectedType()(), DisconnectedType()()]
# if the max is disconnected but the argmax is not,
# the gradient on its inputs is zero
if g_max_disconnected:
return [x.zeros_like()]
if NoneConst.equals(axis):
axis_ = list(range(x.ndim))
else:
axis_ = axis
xmax = max(x, axis_)
# Raise the g_max and xmax to the same number of dim as the input.
pattern = []
out_dim = 0
if NoneConst.equals(axis):
# We are taking the max/argmax over all dimensions.
axis = None
for i in range(x.ndim):
if axis is None or i in axis.data:
pattern.append("x")
else:
pattern.append(out_dim)
out_dim += 1
g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
# Set the grad to the correct position.
g_x = eq(xmax_pad, x) * g_max_pad
return (g_x,)
class Argmax(COp):
"""
Calculate the argmax over a given axis or over all axes.
......@@ -359,7 +158,7 @@ class Argmax(COp):
def __init__(self, axis):
if axis is not None:
axis = tuple(axis)
self.axis = tuple(axis)
self.axis = axis
def get_params(self, node):
if self.axis is not None and len(self.axis) == 1:
......@@ -395,7 +194,6 @@ class Argmax(COp):
(max_idx,) = outs
if axes is None:
axes = tuple(range(x.ndim))
# Numpy does not support multiple axes for argmax
# Work around
keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
......@@ -403,7 +201,7 @@ class Argmax(COp):
transposed_x = np.transpose(x, np.concatenate((keep_axes, axes)))
kept_shape = transposed_x.shape[: len(keep_axes)]
reduced_shape = transposed_x.shape[len(keep_axes) :]
new_shape = (*kept_shape, np.prod(reduced_shape))
new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
reshaped_x = transposed_x.reshape(new_shape)
max_idx[0] = _asarray(np.argmax(reshaped_x, axis=-1), dtype="int64")
......@@ -470,6 +268,9 @@ class Argmax(COp):
)
return [rval]
def R_op(self, inputs, eval_points):
raise ValueError("Argmax is non-diifferentiable")
def grad(self, inp, grads):
(x,) = inp
......@@ -477,7 +278,6 @@ class Argmax(COp):
@_vectorize_node.register(Argmax)
@_vectorize_node.register(MaxAndArgmax)
def vectorize_argmax_node(op, node, batch_x):
core_ndim = node.inputs[0].type.ndim
batch_ndim = batch_x.type.ndim - core_ndim
......@@ -595,12 +395,24 @@ def max_and_argmax(a, axis=None, keepdims=False):
"""
# Check axis and convert it to a Python list of integers.
# Axis will be used as an op param of MaxAndArgmax.
# Axis will be used as an op param of Max and Argmax.
a = as_tensor_variable(a)
is_axis_empty = False
if axis == ():
is_axis_empty = True
axis = check_and_normalize_axes(a, axis)
if len(axis) == 0:
axis = list(range(a.type.ndim))
out, argout = MaxAndArgmax(axis)(a)
if len(axis) == 0 and not is_axis_empty:
axis = None
out = Max(axis)(a)
if not is_axis_empty:
argout = Argmax(axis)(a)
else:
argout = zeros_like(a, dtype="int64")
if keepdims:
out = makeKeepDims(a, out, axis)
......@@ -654,6 +466,74 @@ class Max(NonZeroDimsCAReduce):
axis = kwargs.get("axis", self.axis)
return type(self)(axis=axis)
def grad(self, inp, grads):
# The strict sense mathematical gradient of the maximum function is
# not calculated here for it is not defined at every point where some
# coordinates are identical. However, since the latter set has null
# Lebesgue measure, the result may be interpreted as weak gradient.
# @note: This function should work correctly for L{vector}s.
# (x, y), (gz, gw)
# gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
# gMax * dMax/dx + gArgMax * dArgMax/dx,
# gMax * dMax/daxis + gArgMax * dArgMax/daxis
# g_max has one less dimension than x, so you need to complete
# g_max to x's shape when axis=0 the broadcasting mechanism
# does it automatically
x = inp[0]
if self.axis is None:
self.axis = tuple(range(x.ndim))
axis = as_tensor_variable(self.axis)
(g_max,) = grads
g_max_disconnected = isinstance(g_max.type, DisconnectedType)
# if the op is totally disconnected, so are its inputs
if g_max_disconnected:
return [DisconnectedType()()]
# if NoneConst.equals(axis):
if axis is None:
axis_ = list(range(x.ndim))
else:
axis_ = axis
xmax = max(x, axis_)
# Raise the g_max and xmax to the same number of dim as the input.
pattern = []
out_dim = 0
if NoneConst.equals(axis):
# We are taking the max/argmax over all dimensions.
axis = None
for i in range(x.ndim):
if axis is None or i in axis.data:
pattern.append("x")
else:
pattern.append(out_dim)
out_dim += 1
g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
# Set the grad to the correct position.
g_x = eq(xmax_pad, x) * g_max_pad
return (g_x,)
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None, None]
if len(self.axis) != 1:
raise ValueError("R_op supported for arg_max only for one axis!")
if self.axis[0] > 1:
raise ValueError("R_op supported for arg_max only when axis is 0 or 1")
if inputs[0].ndim != 2:
raise ValueError("R_op supported for arg_max only when input is a matrix")
max_pos = Argmax(self.axis).make_node(*inputs).outputs
# print(eval_points[0].eval())
if self.axis[0] == 0:
return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None]
else:
return [eval_points[0][arange(eval_points[0].shape[0]), max_pos], None]
class Min(NonZeroDimsCAReduce):
nfunc_spec = ("min", 1, 1)
......@@ -685,16 +565,6 @@ def max(x, axis=None, keepdims=False):
We return an error as numpy when we reduce a dim with a shape of 0.
"""
# We have a choice of implementing this call with the
# CAReduce op or the MaxAndArgmax op.
# MaxAndArgmax supports grad and Rop, so we prefer to use that.
# CAReduce is faster, but optimizations will replace MaxAndArgmax[0]
# with CAReduce at compile time, so at this stage the important
# thing is supporting all user interface features, not speed.
# Some cases can be implemented only with CAReduce.
out = max_and_argmax(x, axis)[0]
if keepdims:
......
......@@ -35,31 +35,12 @@ from pytensor import scalar as ps
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import Alloc, alloc, constant
from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg
from pytensor.tensor.math import Min, neg
from pytensor.tensor.rewriting.basic import register_uncanonicalize
from pytensor.tensor.shape import Reshape, reshape
from pytensor.tensor.subtensor import Subtensor
@register_uncanonicalize
@node_rewriter([MaxAndArgmax])
def local_max_and_argmax(fgraph, node):
"""
If we don't use the argmax, change it to a max only.
"""
if isinstance(node.op, MaxAndArgmax):
axis = node.op.axis
if len(fgraph.clients[node.outputs[1]]) == 0:
new = Max(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new)
return [new, None]
if len(fgraph.clients[node.outputs[0]]) == 0:
new = Argmax(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new)
return [None, new]
@register_uncanonicalize
@node_rewriter([neg])
def local_max_to_min(fgraph, node):
......@@ -71,7 +52,7 @@ def local_max_to_min(fgraph, node):
Notes
-----
We don't need an opt that will do the reverse as by default
the interface put only MaxAndArgmax into the graph.
the interface put only Max into the graph.
"""
if node.op == neg and node.inputs[0].owner:
......
......@@ -11,7 +11,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker
from pytensor.tensor import blas as pt_blas
from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor.math import MaxAndArgmax, maximum
from pytensor.tensor.math import Argmax, Max, maximum
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -88,7 +88,8 @@ def test_jax_basic_multiout_omni():
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
x = dvector()
mx, amx = MaxAndArgmax([0])(x)
mx = Max([0])(x)
amx = Argmax([0])(x)
out = mx * amx
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(out_fg, [np.r_[1, 2]])
......
......@@ -552,8 +552,53 @@ def test_LogSoftmax(x, axis, exc):
),
],
)
def test_MaxAndArgmax(x, axes, exc):
g = ptm.MaxAndArgmax(axes)(x)
def test_Max(x, axes, exc):
g = ptm.Max(axes)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)
@pytest.mark.parametrize(
"x, axes, exc",
[
(
set_test_value(pt.dscalar(), np.array(0.0, dtype="float64")),
[],
None,
),
(
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")),
[0],
None,
),
(
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0],
None,
),
(
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0, 1],
None,
),
],
)
def test_Argmax(x, axes, exc):
g = ptm.Argmax(axes)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
......
......@@ -38,7 +38,7 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import (
Dot,
MaxAndArgmax,
Max,
Prod,
Sum,
_conj,
......@@ -3730,8 +3730,8 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
return
# In mode FAST_COMPILE, the rewrites don't replace the
# `MaxAndArgmax` `Op`.
if isinstance(node.op, MaxAndArgmax):
# `Max` `Op`.
if isinstance(node.op, Max):
return
# TODO FIXME: Refactor this test so that it makes a direct assertion and
......
......@@ -9,8 +9,6 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import out2in
from pytensor.link.basic import PerformLinker
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import MaxAndArgmax, max_and_argmax
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import min as pt_min
from pytensor.tensor.rewriting.uncanonicalize import (
local_alloc_dimshuffle,
......@@ -23,67 +21,12 @@ from pytensor.tensor.type import dtensor4, iscalar, matrix, tensor, vector
from tests.link.test_link import make_function
class TestMaxAndArgmax:
def test_optimization(self):
# If we use only the max output, we should replace this op with
# a faster one.
mode = pytensor.compile.mode.get_default_mode().including(
"canonicalize", "fast_run"
)
for axis in [0, 1, -1]:
n = matrix()
f = function([n], max_and_argmax(n, axis)[0], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce)
f = function([n], max_and_argmax(n, axis), mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, MaxAndArgmax)
class TestMinMax:
def setup_method(self):
self.mode = pytensor.compile.mode.get_default_mode().including(
"canonicalize", "fast_run"
)
def test_optimization_max(self):
data = np.asarray(np.random.random((2, 3)), dtype=config.floatX)
n = matrix()
for axis in [0, 1, -1]:
f = function([n], pt_max(n, axis), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce)
f(data)
f = function([n], pt_max(-n, axis), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert isinstance(topo[0].op, Elemwise)
assert isinstance(topo[0].op.scalar_op, ps.Neg)
assert isinstance(topo[1].op, CAReduce)
f(data)
f = function([n], -pt_max(n, axis), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert isinstance(topo[0].op, CAReduce)
assert isinstance(topo[1].op, Elemwise)
assert isinstance(topo[1].op.scalar_op, ps.Neg)
f(data)
f = function([n], -pt_max(-n, axis), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce) # min
f(data)
def test_optimization_min(self):
data = np.asarray(np.random.random((2, 3)), dtype=config.floatX)
n = matrix()
......
......@@ -11,6 +11,7 @@ import scipy.special
from numpy.testing import assert_array_equal
from scipy.special import logsumexp as scipy_logsumexp
import pytensor
import pytensor.scalar as ps
from pytensor.compile.debugmode import DebugMode
from pytensor.compile.function import function
......@@ -39,7 +40,7 @@ from pytensor.tensor.elemwise import CAReduce, Elemwise
from pytensor.tensor.math import (
Argmax,
Dot,
MaxAndArgmax,
Max,
Mean,
Prod,
ProdWithoutZeros,
......@@ -760,11 +761,12 @@ def test_isnan():
class TestMaxAndArgmax:
def setup_method(self):
MaxAndArgmax.debug = 0
Max.debug = 0
Argmax.debug = 0
def test_basic(self):
n = as_tensor_variable(5.0)
v, i = eval_outputs(max_and_argmax(n))
n = as_tensor_variable(5)
v, i = eval_outputs(max_and_argmax(n, axis=()))
assert v == 5.0
assert i == 0
assert i.dtype == "int64"
......@@ -1030,31 +1032,45 @@ class TestMaxAndArgmax:
x = tensor(shape=(5, 5, 5, 5))
batch_x = tensor(shape=(3, 5, 5, 5, 5))
# Test MaxAndArgmax
max_x, argmax_x = max_and_argmax(x, axis=core_axis)
node = max_x.owner
assert isinstance(node.op, MaxAndArgmax)
new_node = vectorize_node(node, batch_x)
assert isinstance(new_node.op, MaxAndArgmax)
assert new_node.op.axis == batch_axis
argmax_x = argmax(x, axis=core_axis)
# Test Argmax
# Argmax is not user-facing, so we have to create it manually
node = Argmax(axis=node.op.axis).make_node(x)
arg_max_node = argmax_x.owner
new_node = vectorize_node(arg_max_node, batch_x)
new_node = vectorize_node(node, batch_x)
assert isinstance(new_node.op, Argmax)
assert new_node.op.axis == batch_axis
def test_max_empty_axis(self):
x = np.random.normal(size=(2, 3, 5, 7))
axis = ()
non_axis = tuple(i for i in range(x.ndim) if i not in axis)
shape_axis = tuple(x.shape[dim] for dim in axis)
shape_non_axis = tuple(x.shape[dim] for dim in non_axis)
x_transposed = x.transpose(*axis, *non_axis)
x_axis_raveled = x_transposed.reshape(
np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int)
)
max_x = max_and_argmax(x, axis=axis)[0].eval()
argmax_x = max_and_argmax(x, axis=axis)[1].eval()
raveled_max = x_axis_raveled[
argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int))
]
indirect_max = raveled_max.reshape(shape_non_axis)
np.testing.assert_allclose(max_x, x.max(axis=axis))
np.testing.assert_allclose(indirect_max, x.max(axis=axis))
class TestArgminArgmax:
def setup_method(self):
MaxAndArgmax.debug = 0
Argmax.debug = 0
def test_scalar(self):
for fct in [argmin, argmax]:
n = as_tensor_variable(5.0)
n = as_tensor_variable([5.0])
i = eval_outputs(fct(n))
assert i == 0
v = eval_outputs(fct(n).shape)
......@@ -1212,7 +1228,7 @@ class TestArgminArgmax:
class TestMinMax:
def setup_method(self):
MaxAndArgmax.debug = 0
Max.debug = 0
def test_scalar(self):
for fct in [max, min]:
......@@ -1379,6 +1395,7 @@ class TestMinMax:
# check_grad_max(data, eval_outputs(grad(max_and_argmax(n,
# axis=1)[0], n)),axis=1)
@pytest.mark.xfail(reason="Fails due to #770")
def test_uint(self):
for dtype in ("uint8", "uint16", "uint32", "uint64"):
itype = np.iinfo(dtype)
......@@ -1404,6 +1421,14 @@ class TestMinMax:
assert np.all(i)
def test_MaxAndArgmax_deprecated():
with pytest.raises(
AttributeError,
match="The class `MaxandArgmax` has been deprecated. Call `Max` and `Argmax` seperately as an alternative.",
):
pytensor.tensor.math.MaxAndArgmax
rng = np.random.default_rng(seed=utt.fetch_seed())
TestClip1 = makeTester(
name="ClipTester",
......@@ -2572,27 +2597,50 @@ class TestInferShape(utt.InferShapeTester):
[adtens3], [Mean(aiscal_val)(adtens3)], [adtens3_val], Mean
)
def test_MaxAndArgmax(self):
def test_Max(self):
adtens3 = dtensor3()
adtens3_val = random(4, 5, 3)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, None), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 0), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 1), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 2), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], Max
)
def test_Argmax(self):
adtens3 = dtensor3()
adtens3_val = random(4, 5, 3)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, None), [adtens3_val], MaxAndArgmax
[adtens3], max_and_argmax(adtens3, None), [adtens3_val], Argmax
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 0), [adtens3_val], MaxAndArgmax
[adtens3], max_and_argmax(adtens3, 0), [adtens3_val], Argmax
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 1), [adtens3_val], MaxAndArgmax
[adtens3], max_and_argmax(adtens3, 1), [adtens3_val], Argmax
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 2), [adtens3_val], MaxAndArgmax
[adtens3], max_and_argmax(adtens3, 2), [adtens3_val], Argmax
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], MaxAndArgmax
[adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], Argmax
)
def test_Dot(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论