提交 79862c63 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Cleanup Max and Argmax

上级 b675d135
...@@ -30,7 +30,11 @@ from pytensor.tensor.type import ( ...@@ -30,7 +30,11 @@ from pytensor.tensor.type import (
float_dtypes, float_dtypes,
lvector, lvector,
) )
from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string from pytensor.tensor.utils import (
broadcast_static_dim_lengths,
import_func_from_string,
normalize_reduce_axis,
)
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
from pytensor.utils import uniq from pytensor.utils import uniq
...@@ -1371,7 +1375,6 @@ class CAReduce(COp): ...@@ -1371,7 +1375,6 @@ class CAReduce(COp):
def make_node(self, input): def make_node(self, input):
input = as_tensor_variable(input) input = as_tensor_variable(input)
inp_dims = input.type.ndim
inp_dtype = input.type.dtype inp_dtype = input.type.dtype
# We need to redefine make_node so that, if self.dtype is None, # We need to redefine make_node so that, if self.dtype is None,
...@@ -1383,29 +1386,19 @@ class CAReduce(COp): ...@@ -1383,29 +1386,19 @@ class CAReduce(COp):
assert dtype is not None assert dtype is not None
assert acc_dtype is not None assert acc_dtype is not None
axis = self.axis axis = normalize_reduce_axis(self.axis, ndim=input.type.ndim)
# scalar inputs are treated as 1D regarding axis in this `Op` if axis != self.axis or dtype != self.dtype or acc_dtype != self.acc_dtype:
if axis is not None: op = self.clone(axis=axis, dtype=dtype, acc_dtype=acc_dtype)
try: else:
axis = normalize_axis_tuple(axis, ndim=max(1, inp_dims)) op = self
except np.AxisError:
raise np.AxisError(axis, ndim=inp_dims)
if axis is None:
out_shape = ()
else:
out_shape = tuple( out_shape = tuple(
s for i, s in enumerate(input.type.shape) if i not in axis s for i, s in enumerate(input.type.shape) if i not in axis
) )
else:
out_shape = ()
if (
(axis is not None and any(a < 0 for a in axis))
or dtype != self.dtype
or acc_dtype != self.acc_dtype
):
op = self.clone(axis=axis, dtype=dtype, acc_dtype=acc_dtype)
else:
op = self
output = TensorType(dtype=dtype, shape=out_shape)() output = TensorType(dtype=dtype, shape=out_shape)()
......
...@@ -8,7 +8,6 @@ from numpy.core.numeric import normalize_axis_tuple ...@@ -8,7 +8,6 @@ from numpy.core.numeric import normalize_axis_tuple
from pytensor import config, printing from pytensor import config, printing
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
...@@ -26,9 +25,9 @@ from pytensor.tensor.basic import ( ...@@ -26,9 +25,9 @@ from pytensor.tensor.basic import (
cast, cast,
concatenate, concatenate,
constant, constant,
expand_dims,
stack, stack,
switch, switch,
zeros_like,
) )
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.elemwise import ( from pytensor.tensor.elemwise import (
...@@ -45,14 +44,11 @@ from pytensor.tensor.type import ( ...@@ -45,14 +44,11 @@ from pytensor.tensor.type import (
continuous_dtypes, continuous_dtypes,
discrete_dtypes, discrete_dtypes,
int_dtypes, int_dtypes,
integer_dtypes,
tensor, tensor,
uint_dtypes, uint_dtypes,
) )
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.utils import as_list, normalize_reduce_axis
from pytensor.tensor.utils import as_list
from pytensor.tensor.variable import ( from pytensor.tensor.variable import (
TensorConstant,
TensorVariable, TensorVariable,
_tensor_py_operators, _tensor_py_operators,
) )
...@@ -157,7 +153,7 @@ class Argmax(COp): ...@@ -157,7 +153,7 @@ class Argmax(COp):
def __init__(self, axis): def __init__(self, axis):
if axis is not None: if axis is not None:
axis = tuple(axis) axis = tuple(sorted(axis))
self.axis = axis self.axis = axis
def get_params(self, node): def get_params(self, node):
...@@ -168,7 +164,7 @@ class Argmax(COp): ...@@ -168,7 +164,7 @@ class Argmax(COp):
c_axis = np.int64(-1) c_axis = np.int64(-1)
return self.params_type.get_params(c_axis=c_axis) return self.params_type.get_params(c_axis=c_axis)
def make_node(self, x, axis=None): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
if self.axis is None: if self.axis is None:
all_axes = list(range(x.ndim)) all_axes = list(range(x.ndim))
...@@ -198,7 +194,9 @@ class Argmax(COp): ...@@ -198,7 +194,9 @@ class Argmax(COp):
# Work around # Work around
keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64") keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
# Not-reduced axes in front # Not-reduced axes in front
transposed_x = np.transpose(x, np.concatenate((keep_axes, axes))) transposed_x = np.transpose(
x, np.concatenate((keep_axes, np.asarray(axes, dtype="int64")))
)
kept_shape = transposed_x.shape[: len(keep_axes)] kept_shape = transposed_x.shape[: len(keep_axes)]
reduced_shape = transposed_x.shape[len(keep_axes) :] reduced_shape = transposed_x.shape[len(keep_axes) :]
new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64")) new_shape = (*kept_shape, np.prod(reduced_shape, dtype="int64"))
...@@ -214,7 +212,7 @@ class Argmax(COp): ...@@ -214,7 +212,7 @@ class Argmax(COp):
if self.axis is None: if self.axis is None:
axis_code = "axis = NPY_MAXDIMS;" axis_code = "axis = NPY_MAXDIMS;"
else: else:
if len(self.axis) > 1: if len(self.axis) != 1:
raise NotImplementedError() raise NotImplementedError()
# params is only used here for now # params is only used here for now
axis_code = """ axis_code = """
...@@ -253,7 +251,7 @@ class Argmax(COp): ...@@ -253,7 +251,7 @@ class Argmax(COp):
return ret % locals() return ret % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
(ishape,) = shapes (ishape,) = shapes
...@@ -277,7 +275,7 @@ class Argmax(COp): ...@@ -277,7 +275,7 @@ class Argmax(COp):
return [x.zeros_like()] return [x.zeros_like()]
def argmax(x, axis=None, keepdims=False): def argmax(x: TensorLike, axis=None, keepdims: bool = False):
""" """
Returns indices of maximum elements obtained by iterating over given axis. Returns indices of maximum elements obtained by iterating over given axis.
...@@ -286,17 +284,29 @@ def argmax(x, axis=None, keepdims=False): ...@@ -286,17 +284,29 @@ def argmax(x, axis=None, keepdims=False):
Parameters Parameters
---------- ----------
x: TensorLike
Array on which to compute argmax
axis:
Axis along which to compute argmax. Unlike numpy multiple partial axis are supported.
keepdims : bool keepdims : bool
If this is set to True, the axes which are reduced are left in If this is set to True, the axes which are reduced are left in
the result as dimensions with size one. With this option, the result the result as dimensions with size one. With this option, the result
will broadcast correctly against the original tensor. will broadcast correctly against the original tensor.
Returns
-------
TensorVariable
TensorVariable representing the argmax operation
""" """
argout = max_and_argmax(x, axis)[1] x = as_tensor_variable(x)
axis = normalize_reduce_axis(axis, ndim=x.type.ndim)
out = Argmax(axis)(x)
if keepdims: if keepdims:
argout = makeKeepDims(x, argout, axis) out = makeKeepDims(x, out, axis)
return argout
return out
@_vectorize_node.register(Argmax) @_vectorize_node.register(Argmax)
...@@ -324,59 +334,6 @@ def makeKeepDims(x, y, axis): ...@@ -324,59 +334,6 @@ def makeKeepDims(x, y, axis):
return expand_dims(y, axis) return expand_dims(y, axis)
def check_and_normalize_axes(x, axis):
    """Check axes, normalize and convert them to a Python list of integers.

    Parameters
    ----------
    x: TensorVariable
    axis: int, tuple or list of integers

    Returns
    -------
    axis: list of integers
        Return an empty list if argument is None.

    """
    x = as_tensor_variable(x)

    # Coerce the many accepted axis formats into a plain Python list of ints.
    if axis is None:
        axes = []
    elif isinstance(axis, int | np.integer) or (
        isinstance(axis, np.ndarray) and axis.ndim == 0
    ):
        axes = [int(axis)]
    elif isinstance(axis, tuple | list | np.ndarray):
        axes = [int(i) for i in axis]
    elif isinstance(axis, Variable):
        # Symbolic axis is only allowed when it is a constant (None or ints).
        if NoneConst.equals(axis):
            axes = []
        elif not isinstance(axis, TensorConstant):
            raise TypeError(f"Computation needs a constant axis. Got {axis}")
        else:
            assert axis.dtype in integer_dtypes
            data = axis.data
            if isinstance(data, int | np.integer) or (
                isinstance(data, np.ndarray) and data.ndim == 0
            ):
                axes = [int(data)]
            elif isinstance(data, list | np.ndarray):
                axes = [int(i) for i in data]
    else:
        raise TypeError(
            f"Axis must be an integer, tuple, list of integers or a TensorVariable. Got {axis}"
        )

    if len(axes) > 0:
        # Wrap negative axes, validate bounds, then deduplicate and sort.
        wrapped = []
        for a in axes:
            if a < 0:
                a += x.type.ndim
            if a < 0 or a >= x.type.ndim:
                raise ValueError(
                    f"Computation needs a valid axis number for {int(x.type.ndim)}-D tensor. Got {int(a)}"
                )
            wrapped.append(a)
        axes = sorted(set(wrapped))
    return axes
def max_and_argmax(a, axis=None, keepdims=False): def max_and_argmax(a, axis=None, keepdims=False):
""" """
Returns maximum elements and their indices obtained by iterating over Returns maximum elements and their indices obtained by iterating over
...@@ -395,28 +352,10 @@ def max_and_argmax(a, axis=None, keepdims=False): ...@@ -395,28 +352,10 @@ def max_and_argmax(a, axis=None, keepdims=False):
""" """
# Check axis and convert it to a Python list of integers. # Check axis and convert it to a Python list of integers.
# Axis will be used as an op param of Max and Argmax. # Axis will be used as an op param of Max and Argmax.
a = as_tensor_variable(a) return [
max(a, axis=axis, keepdims=keepdims),
is_axis_empty = False argmax(a, axis=axis, keepdims=keepdims),
if axis == (): ]
is_axis_empty = True
axis = check_and_normalize_axes(a, axis)
if len(axis) == 0 and not is_axis_empty:
axis = None
out = Max(axis)(a)
if not is_axis_empty:
argout = Argmax(axis)(a)
else:
argout = zeros_like(a, dtype="int64")
if keepdims:
out = makeKeepDims(a, out, axis)
argout = makeKeepDims(a, argout, axis)
return [out, argout]
class FixedOpCAReduce(CAReduce): class FixedOpCAReduce(CAReduce):
...@@ -465,7 +404,7 @@ class Max(NonZeroDimsCAReduce): ...@@ -465,7 +404,7 @@ class Max(NonZeroDimsCAReduce):
axis = kwargs.get("axis", self.axis) axis = kwargs.get("axis", self.axis)
return type(self)(axis=axis) return type(self)(axis=axis)
def grad(self, inp, grads): def L_op(self, inputs, outputs, grads):
# The strict sense mathematical gradient of the maximum function is # The strict sense mathematical gradient of the maximum function is
# not calculated here for it is not defined at every point where some # not calculated here for it is not defined at every point where some
# coordinates are identical. However, since the latter set has null # coordinates are identical. However, since the latter set has null
...@@ -479,53 +418,27 @@ class Max(NonZeroDimsCAReduce): ...@@ -479,53 +418,27 @@ class Max(NonZeroDimsCAReduce):
# g_max has one less dimension than x, so you need to complete # g_max has one less dimension than x, so you need to complete
# g_max to x's shape when axis=0 the broadcasting mechanism # g_max to x's shape when axis=0 the broadcasting mechanism
# does it automatically # does it automatically
x = inp[0] [x] = inputs
if self.axis is None: [out] = outputs
self.axis = tuple(range(x.ndim)) [g_out] = grads
axis = as_tensor_variable(self.axis)
(g_max,) = grads
g_max_disconnected = isinstance(g_max.type, DisconnectedType)
# if the op is totally disconnected, so are its inputs axis = tuple(range(x.ndim)) if self.axis is None else self.axis
if g_max_disconnected: out_pad = expand_dims(out, axis)
return [DisconnectedType()()] g_out_pad = expand_dims(g_out, axis)
# if NoneConst.equals(axis):
if axis is None:
axis_ = list(range(x.ndim))
else:
axis_ = axis
xmax = max(x, axis_)
# Raise the g_max and xmax to the same number of dim as the input.
pattern = []
out_dim = 0
if NoneConst.equals(axis):
# We are taking the max/argmax over all dimensions.
axis = None
for i in range(x.ndim):
if axis is None or i in axis.data:
pattern.append("x")
else:
pattern.append(out_dim)
out_dim += 1
g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
# Set the grad to the correct position. # Set the grad to the correct position.
g_x = eq(xmax_pad, x) * g_max_pad g_x = eq(out_pad, x) * g_out_pad
return (g_x,) return (g_x,)
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None, None] return [None, None]
if len(self.axis) != 1: if len(self.axis) != 1:
raise ValueError("R_op supported for arg_max only for one axis!") raise ValueError("R_op supported for max only for one axis!")
if self.axis[0] > 1: if self.axis[0] > 1:
raise ValueError("R_op supported for arg_max only when axis is 0 or 1") raise ValueError("R_op supported for max only when axis is 0 or 1")
if inputs[0].ndim != 2: if inputs[0].ndim != 2:
raise ValueError("R_op supported for arg_max only when input is a matrix") raise ValueError("R_op supported for max only when input is a matrix")
max_pos = Argmax(self.axis).make_node(*inputs).outputs max_pos = Argmax(self.axis).make_node(*inputs).outputs
# print(eval_points[0].eval()) # print(eval_points[0].eval())
if self.axis[0] == 0: if self.axis[0] == 0:
...@@ -564,7 +477,7 @@ def max(x, axis=None, keepdims=False): ...@@ -564,7 +477,7 @@ def max(x, axis=None, keepdims=False):
We return an error as numpy when we reduce a dim with a shape of 0. We return an error as numpy when we reduce a dim with a shape of 0.
""" """
out = max_and_argmax(x, axis)[0] out = Max(axis=axis)(x)
if keepdims: if keepdims:
out = makeKeepDims(x, out, axis) out = makeKeepDims(x, out, axis)
......
import re import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import cast
import numpy as np import numpy as np
from numpy.core.numeric import normalize_axis_tuple # type: ignore
import pytensor import pytensor
from pytensor.utils import hash_from_code from pytensor.utils import hash_from_code
...@@ -223,3 +225,19 @@ def safe_signature( ...@@ -223,3 +225,19 @@ def safe_signature(
operand_sig(ndim, prefix=f"o{n}") for n, ndim in enumerate(core_outputs_ndim) operand_sig(ndim, prefix=f"o{n}") for n, ndim in enumerate(core_outputs_ndim)
) )
return f"{inputs_sig}->{outputs_sig}" return f"{inputs_sig}->{outputs_sig}"
def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None:
"""Normalize the axis parameter for reduce operations."""
if axis is None:
return None
# scalar inputs are treated as 1D regarding axis in reduce operations
if axis is not None:
try:
axis = normalize_axis_tuple(axis, ndim=max(1, ndim))
except np.AxisError:
raise np.AxisError(axis, ndim=ndim)
# TODO: If axis tuple is equivalent to None, return None for more canonicalization?
return cast(tuple, axis)
...@@ -154,7 +154,6 @@ from pytensor.tensor.type import ( ...@@ -154,7 +154,6 @@ from pytensor.tensor.type import (
vectors, vectors,
zvector, zvector,
) )
from pytensor.tensor.type_other import NoneConst
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.link.test_link import make_function from tests.link.test_link import make_function
from tests.tensor.utils import ( from tests.tensor.utils import (
...@@ -767,9 +766,10 @@ class TestMaxAndArgmax: ...@@ -767,9 +766,10 @@ class TestMaxAndArgmax:
Max.debug = 0 Max.debug = 0
Argmax.debug = 0 Argmax.debug = 0
def test_basic(self): @pytest.mark.parametrize("empty_axis", [(), None])
def test_empty_axis_scalar(self, empty_axis):
n = as_tensor_variable(5) n = as_tensor_variable(5)
v, i = eval_outputs(max_and_argmax(n, axis=())) v, i = eval_outputs(max_and_argmax(n, axis=empty_axis))
assert v == 5.0 assert v == 5.0
assert i == 0 assert i == 0
assert i.dtype == "int64" assert i.dtype == "int64"
...@@ -778,6 +778,29 @@ class TestMaxAndArgmax: ...@@ -778,6 +778,29 @@ class TestMaxAndArgmax:
v = eval_outputs(max_and_argmax(n)[1].shape) v = eval_outputs(max_and_argmax(n)[1].shape)
assert len(v) == 0 assert len(v) == 0
def test_empty_axis_tensor(self):
    """Reducing over an empty ``axis`` tuple must leave the tensor unchanged."""
    data = np.random.normal(size=(2, 3, 5, 7))
    axis = ()
    kept = tuple(dim for dim in range(data.ndim) if dim not in axis)

    reduced_sizes = tuple(data.shape[dim] for dim in axis)
    kept_sizes = tuple(data.shape[dim] for dim in kept)

    # Collapse the (empty set of) reduced axes into a single leading dimension.
    flat = data.transpose(*axis, *kept).reshape(
        np.prod(reduced_sizes, dtype=int), np.prod(kept_sizes, dtype=int)
    )

    max_x = max_and_argmax(data, axis=axis)[0].eval()
    argmax_x = max_and_argmax(data, axis=axis)[1].eval()

    # Gather values at the argmax indices and check they agree with the max.
    picked = flat[argmax_x.ravel(), np.arange(np.prod(kept_sizes, dtype=int))]
    indirect_max = picked.reshape(kept_sizes)

    np.testing.assert_allclose(max_x, data.max(axis=axis))
    np.testing.assert_allclose(indirect_max, data.max(axis=axis))
def test_basic_1(self): def test_basic_1(self):
n = as_tensor_variable([1, 2, 3, 2, -6]) n = as_tensor_variable([1, 2, 3, 2, -6])
v, i = eval_outputs(max_and_argmax(n)) v, i = eval_outputs(max_and_argmax(n))
...@@ -796,8 +819,6 @@ class TestMaxAndArgmax: ...@@ -796,8 +819,6 @@ class TestMaxAndArgmax:
(None, None), (None, None),
([0, 1], None), ([0, 1], None),
([1, 0], None), ([1, 0], None),
(NoneConst.clone(), None),
(constant(0), 0),
], ],
) )
def test_basic_2(self, axis, np_axis): def test_basic_2(self, axis, np_axis):
...@@ -826,8 +847,6 @@ class TestMaxAndArgmax: ...@@ -826,8 +847,6 @@ class TestMaxAndArgmax:
(None, None), (None, None),
([0, 1], None), ([0, 1], None),
([1, 0], None), ([1, 0], None),
(NoneConst.clone(), None),
(constant(0), 0),
], ],
) )
def test_basic_2_float16(self, axis, np_axis): def test_basic_2_float16(self, axis, np_axis):
...@@ -986,7 +1005,7 @@ class TestMaxAndArgmax: ...@@ -986,7 +1005,7 @@ class TestMaxAndArgmax:
safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data]) safe_verify_grad(lambda v: max_and_argmax(v, axis=[i])[1], [data])
# Test grad with multiple axes # Test grad with multiple axes
for i in [[0, 1], [0, 0]]: for i in [[0, 1], [0, 2, 3]]:
safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[0], [data]) safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[0], [data])
safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[1], [data]) safe_verify_grad(lambda v: max_and_argmax(v, axis=i)[1], [data])
...@@ -1043,29 +1062,6 @@ class TestMaxAndArgmax: ...@@ -1043,29 +1062,6 @@ class TestMaxAndArgmax:
assert isinstance(new_node.op, Argmax) assert isinstance(new_node.op, Argmax)
assert new_node.op.axis == batch_axis assert new_node.op.axis == batch_axis
def test_max_empty_axis(self):
    # With an empty ``axis`` tuple, max/argmax reduce over no dimensions,
    # so the max should equal the input array itself.
    x = np.random.normal(size=(2, 3, 5, 7))
    axis = ()
    non_axis = tuple(i for i in range(x.ndim) if i not in axis)
    shape_axis = tuple(x.shape[dim] for dim in axis)
    shape_non_axis = tuple(x.shape[dim] for dim in non_axis)
    # Move the (empty set of) reduced axes to the front and flatten them
    # into one leading dimension, so argmax indices can index it directly.
    x_transposed = x.transpose(*axis, *non_axis)
    x_axis_raveled = x_transposed.reshape(
        np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int)
    )

    max_x = max_and_argmax(x, axis=axis)[0].eval()
    argmax_x = max_and_argmax(x, axis=axis)[1].eval()

    # Gather the values selected by argmax and verify they match the max.
    raveled_max = x_axis_raveled[
        argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int))
    ]
    indirect_max = raveled_max.reshape(shape_non_axis)

    np.testing.assert_allclose(max_x, x.max(axis=axis))
    np.testing.assert_allclose(indirect_max, x.max(axis=axis))
class TestArgminArgmax: class TestArgminArgmax:
def setup_method(self): def setup_method(self):
......
...@@ -192,9 +192,7 @@ class RopLopChecker: ...@@ -192,9 +192,7 @@ class RopLopChecker:
class TestRopLop(RopLopChecker): class TestRopLop(RopLopChecker):
def test_max(self): def test_max(self):
# If we call max directly, we will return an CAReduce object # self.check_mat_rop_lop(pt_max(self.mx, axis=[0,1])[0], ())
# which doesn't have R_op implemented!
# self.check_mat_rop_lop(at_max(self.mx, axis=[0,1])[0], ())
self.check_mat_rop_lop(pt_max(self.mx, axis=0), (self.mat_in_shape[1],)) self.check_mat_rop_lop(pt_max(self.mx, axis=0), (self.mat_in_shape[1],))
self.check_mat_rop_lop(pt_max(self.mx, axis=1), (self.mat_in_shape[0],)) self.check_mat_rop_lop(pt_max(self.mx, axis=1), (self.mat_in_shape[0],))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论