Unverified 提交 e6e6d69f authored 作者: Dhruvanshu-Joshi's avatar Dhruvanshu-Joshi 提交者: GitHub

Break MaxandArgmax Op to separate TensorMax Op and Argmax Op (#731)

* Break MaxandArgmax to TensorMax and Argmax separately * XFAIL pytensor tests for uint64 data type * Deprecate and raise AttributeError for MaxAndArgmax
上级 dbe0e09a
......@@ -477,7 +477,8 @@ acceptable_ops = (
Reshape,
Unbroadcast,
pt.math.Dot,
pt.math.MaxAndArgmax,
pt.math.Max,
pt.math.Argmax,
pt.subtensor.Subtensor,
pt.subtensor.IncSubtensor,
pt.basic.Alloc,
......
......@@ -2,7 +2,7 @@ import jax.numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot, MaxAndArgmax
from pytensor.tensor.math import Argmax, Dot, Max
from pytensor.tensor.nlinalg import (
SVD,
Det,
......@@ -104,18 +104,28 @@ def jax_funcify_BatchedDot(op, **kwargs):
return batched_dot
@jax_funcify.register(MaxAndArgmax)
def jax_funcify_MaxAndArgmax(op, **kwargs):
@jax_funcify.register(Max)
def jax_funcify_Max(op, **kwargs):
axis = op.axis
def maxandargmax(x, axis=axis):
def max(x):
max_res = jnp.max(x, axis)
return max_res
return max
@jax_funcify.register(Argmax)
def jax_funcify_Argmax(op, **kwargs):
axis = op.axis
def argmax(x):
if axis is None:
axes = tuple(range(x.ndim))
else:
axes = tuple(int(ax) for ax in axis)
max_res = jnp.max(x, axis)
# NumPy does not support multiple axes for argmax; this is a
# work-around
keep_axes = jnp.array(
......@@ -138,6 +148,6 @@ def jax_funcify_MaxAndArgmax(op, **kwargs):
max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
return max_res, max_idx_res
return max_idx_res
return maxandargmax
return argmax
......@@ -44,7 +44,7 @@ from pytensor.scalar.basic import (
)
from pytensor.scalar.basic import add as add_as
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from pytensor.tensor.type import scalar
......@@ -827,8 +827,8 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
return log_softmax
@numba_funcify.register(MaxAndArgmax)
def numba_funcify_MaxAndArgmax(op, node, **kwargs):
@numba_funcify.register(Argmax)
def numba_funcify_Argmax(op, node, **kwargs):
axis = op.axis
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
......@@ -838,8 +838,8 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
if x_ndim == 0:
@numba_basic.numba_njit(inline="always")
def maxandargmax(x):
return x, 0
def argmax(x):
return 0
else:
axes = tuple(int(ax) for ax in axis)
......@@ -848,20 +848,6 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
# work-around
keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max_py_fn = create_multiaxis_reducer(
scalar_maximum,
-np.inf,
axes,
x_ndim,
x_dtype,
return_scalar=False,
)
reduce_max = jit_compile_reducer(
Apply(node.op, node.inputs, [node.outputs[0].clone()]),
reduce_max_py_fn,
reduce_to_scalar=False,
)
reduced_x_ndim = x_ndim - len(axes) + 1
argmax_axis = create_axis_apply_fn(
np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
......@@ -872,9 +858,7 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
sl2 = slice(len(keep_axes), None)
@numba_basic.numba_njit
def maxandargmax(x):
max_res = reduce_max(x)
def argmax(x):
# Not-reduced axes in front
transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
kept_shape = transposed_x.shape[sl1]
......@@ -890,6 +874,6 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
max_idx_res = argmax_axis(reshaped_x)
return max_res, max_idx_res
return max_idx_res
return maxandargmax
return argmax
差异被折叠。
......@@ -35,31 +35,12 @@ from pytensor import scalar as ps
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import Alloc, alloc, constant
from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg
from pytensor.tensor.math import Min, neg
from pytensor.tensor.rewriting.basic import register_uncanonicalize
from pytensor.tensor.shape import Reshape, reshape
from pytensor.tensor.subtensor import Subtensor
@register_uncanonicalize
@node_rewriter([MaxAndArgmax])
def local_max_and_argmax(fgraph, node):
"""
If we don't use the argmax, change it to a max only.
"""
if isinstance(node.op, MaxAndArgmax):
axis = node.op.axis
if len(fgraph.clients[node.outputs[1]]) == 0:
new = Max(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new)
return [new, None]
if len(fgraph.clients[node.outputs[0]]) == 0:
new = Argmax(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new)
return [None, new]
@register_uncanonicalize
@node_rewriter([neg])
def local_max_to_min(fgraph, node):
......@@ -71,7 +52,7 @@ def local_max_to_min(fgraph, node):
Notes
-----
We don't need an opt that will do the reverse as by default
the interface put only MaxAndArgmax into the graph.
the interface put only Max into the graph.
"""
if node.op == neg and node.inputs[0].owner:
......
......@@ -11,7 +11,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker
from pytensor.tensor import blas as pt_blas
from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor.math import MaxAndArgmax, maximum
from pytensor.tensor.math import Argmax, Max, maximum
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -88,7 +88,8 @@ def test_jax_basic_multiout_omni():
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
x = dvector()
mx, amx = MaxAndArgmax([0])(x)
mx = Max([0])(x)
amx = Argmax([0])(x)
out = mx * amx
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(out_fg, [np.r_[1, 2]])
......
......@@ -552,8 +552,53 @@ def test_LogSoftmax(x, axis, exc):
),
],
)
def test_MaxAndArgmax(x, axes, exc):
g = ptm.MaxAndArgmax(axes)(x)
def test_Max(x, axes, exc):
g = ptm.Max(axes)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)
@pytest.mark.parametrize(
"x, axes, exc",
[
(
set_test_value(pt.dscalar(), np.array(0.0, dtype="float64")),
[],
None,
),
(
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")),
[0],
None,
),
(
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0],
None,
),
(
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0, 1],
None,
),
],
)
def test_Argmax(x, axes, exc):
g = ptm.Argmax(axes)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
......
......@@ -38,7 +38,7 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import (
Dot,
MaxAndArgmax,
Max,
Prod,
Sum,
_conj,
......@@ -3730,8 +3730,8 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
return
# In mode FAST_COMPILE, the rewrites don't replace the
# `MaxAndArgmax` `Op`.
if isinstance(node.op, MaxAndArgmax):
# `Max` `Op`.
if isinstance(node.op, Max):
return
# TODO FIXME: Refactor this test so that it makes a direct assertion and
......
......@@ -9,8 +9,6 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import out2in
from pytensor.link.basic import PerformLinker
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import MaxAndArgmax, max_and_argmax
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import min as pt_min
from pytensor.tensor.rewriting.uncanonicalize import (
local_alloc_dimshuffle,
......@@ -23,67 +21,12 @@ from pytensor.tensor.type import dtensor4, iscalar, matrix, tensor, vector
from tests.link.test_link import make_function
class TestMaxAndArgmax:
def test_optimization(self):
# If we use only the max output, we should replace this op with
# a faster one.
mode = pytensor.compile.mode.get_default_mode().including(
"canonicalize", "fast_run"
)
for axis in [0, 1, -1]:
n = matrix()
f = function([n], max_and_argmax(n, axis)[0], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce)
f = function([n], max_and_argmax(n, axis), mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, MaxAndArgmax)
class TestMinMax:
def setup_method(self):
self.mode = pytensor.compile.mode.get_default_mode().including(
"canonicalize", "fast_run"
)
def test_optimization_max(self):
data = np.asarray(np.random.random((2, 3)), dtype=config.floatX)
n = matrix()
for axis in [0, 1, -1]:
f = function([n], pt_max(n, axis), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce)
f(data)
f = function([n], pt_max(-n, axis), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert isinstance(topo[0].op, Elemwise)
assert isinstance(topo[0].op.scalar_op, ps.Neg)
assert isinstance(topo[1].op, CAReduce)
f(data)
f = function([n], -pt_max(n, axis), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert isinstance(topo[0].op, CAReduce)
assert isinstance(topo[1].op, Elemwise)
assert isinstance(topo[1].op.scalar_op, ps.Neg)
f(data)
f = function([n], -pt_max(-n, axis), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce) # min
f(data)
def test_optimization_min(self):
data = np.asarray(np.random.random((2, 3)), dtype=config.floatX)
n = matrix()
......
......@@ -11,6 +11,7 @@ import scipy.special
from numpy.testing import assert_array_equal
from scipy.special import logsumexp as scipy_logsumexp
import pytensor
import pytensor.scalar as ps
from pytensor.compile.debugmode import DebugMode
from pytensor.compile.function import function
......@@ -39,7 +40,7 @@ from pytensor.tensor.elemwise import CAReduce, Elemwise
from pytensor.tensor.math import (
Argmax,
Dot,
MaxAndArgmax,
Max,
Mean,
Prod,
ProdWithoutZeros,
......@@ -760,11 +761,12 @@ def test_isnan():
class TestMaxAndArgmax:
def setup_method(self):
MaxAndArgmax.debug = 0
Max.debug = 0
Argmax.debug = 0
def test_basic(self):
n = as_tensor_variable(5.0)
v, i = eval_outputs(max_and_argmax(n))
n = as_tensor_variable(5)
v, i = eval_outputs(max_and_argmax(n, axis=()))
assert v == 5.0
assert i == 0
assert i.dtype == "int64"
......@@ -1030,31 +1032,45 @@ class TestMaxAndArgmax:
x = tensor(shape=(5, 5, 5, 5))
batch_x = tensor(shape=(3, 5, 5, 5, 5))
# Test MaxAndArgmax
max_x, argmax_x = max_and_argmax(x, axis=core_axis)
node = max_x.owner
assert isinstance(node.op, MaxAndArgmax)
new_node = vectorize_node(node, batch_x)
assert isinstance(new_node.op, MaxAndArgmax)
assert new_node.op.axis == batch_axis
argmax_x = argmax(x, axis=core_axis)
# Test Argmax
# Argmax is not user-facing, so we have to create it manually
node = Argmax(axis=node.op.axis).make_node(x)
arg_max_node = argmax_x.owner
new_node = vectorize_node(arg_max_node, batch_x)
new_node = vectorize_node(node, batch_x)
assert isinstance(new_node.op, Argmax)
assert new_node.op.axis == batch_axis
def test_max_empty_axis(self):
x = np.random.normal(size=(2, 3, 5, 7))
axis = ()
non_axis = tuple(i for i in range(x.ndim) if i not in axis)
shape_axis = tuple(x.shape[dim] for dim in axis)
shape_non_axis = tuple(x.shape[dim] for dim in non_axis)
x_transposed = x.transpose(*axis, *non_axis)
x_axis_raveled = x_transposed.reshape(
np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int)
)
max_x = max_and_argmax(x, axis=axis)[0].eval()
argmax_x = max_and_argmax(x, axis=axis)[1].eval()
raveled_max = x_axis_raveled[
argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int))
]
indirect_max = raveled_max.reshape(shape_non_axis)
np.testing.assert_allclose(max_x, x.max(axis=axis))
np.testing.assert_allclose(indirect_max, x.max(axis=axis))
class TestArgminArgmax:
def setup_method(self):
MaxAndArgmax.debug = 0
Argmax.debug = 0
def test_scalar(self):
for fct in [argmin, argmax]:
n = as_tensor_variable(5.0)
n = as_tensor_variable([5.0])
i = eval_outputs(fct(n))
assert i == 0
v = eval_outputs(fct(n).shape)
......@@ -1212,7 +1228,7 @@ class TestArgminArgmax:
class TestMinMax:
def setup_method(self):
MaxAndArgmax.debug = 0
Max.debug = 0
def test_scalar(self):
for fct in [max, min]:
......@@ -1379,6 +1395,7 @@ class TestMinMax:
# check_grad_max(data, eval_outputs(grad(max_and_argmax(n,
# axis=1)[0], n)),axis=1)
@pytest.mark.xfail(reason="Fails due to #770")
def test_uint(self):
for dtype in ("uint8", "uint16", "uint32", "uint64"):
itype = np.iinfo(dtype)
......@@ -1404,6 +1421,14 @@ class TestMinMax:
assert np.all(i)
def test_MaxAndArgmax_deprecated():
with pytest.raises(
AttributeError,
match="The class `MaxandArgmax` has been deprecated. Call `Max` and `Argmax` seperately as an alternative.",
):
pytensor.tensor.math.MaxAndArgmax
rng = np.random.default_rng(seed=utt.fetch_seed())
TestClip1 = makeTester(
name="ClipTester",
......@@ -2572,27 +2597,50 @@ class TestInferShape(utt.InferShapeTester):
[adtens3], [Mean(aiscal_val)(adtens3)], [adtens3_val], Mean
)
def test_MaxAndArgmax(self):
def test_Max(self):
adtens3 = dtensor3()
adtens3_val = random(4, 5, 3)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, None), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 0), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 1), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 2), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], Max
)
def test_Argmax(self):
adtens3 = dtensor3()
adtens3_val = random(4, 5, 3)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, None), [adtens3_val], MaxAndArgmax
[adtens3], max_and_argmax(adtens3, None), [adtens3_val], Argmax
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 0), [adtens3_val], MaxAndArgmax
[adtens3], max_and_argmax(adtens3, 0), [adtens3_val], Argmax
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 1), [adtens3_val], MaxAndArgmax
[adtens3], max_and_argmax(adtens3, 1), [adtens3_val], Argmax
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 2), [adtens3_val], MaxAndArgmax
[adtens3], max_and_argmax(adtens3, 2), [adtens3_val], Argmax
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], MaxAndArgmax
[adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], Argmax
)
def test_Dot(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论