Unverified 提交 e6e6d69f authored 作者: Dhruvanshu-Joshi's avatar Dhruvanshu-Joshi 提交者: GitHub

Break MaxandArgmax Op to seperate TensorMax Op and Argmax Op (#731)

* Break MaxandArgmax to TensorMax and Argmax seperately * XFAIL pytensor tests for uint64 data type * Deprecate and raise AttributeError for MaxAndArgmax
上级 dbe0e09a
...@@ -477,7 +477,8 @@ acceptable_ops = ( ...@@ -477,7 +477,8 @@ acceptable_ops = (
Reshape, Reshape,
Unbroadcast, Unbroadcast,
pt.math.Dot, pt.math.Dot,
pt.math.MaxAndArgmax, pt.math.Max,
pt.math.Argmax,
pt.subtensor.Subtensor, pt.subtensor.Subtensor,
pt.subtensor.IncSubtensor, pt.subtensor.IncSubtensor,
pt.basic.Alloc, pt.basic.Alloc,
......
...@@ -2,7 +2,7 @@ import jax.numpy as jnp ...@@ -2,7 +2,7 @@ import jax.numpy as jnp
from pytensor.link.jax.dispatch import jax_funcify from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot, MaxAndArgmax from pytensor.tensor.math import Argmax, Dot, Max
from pytensor.tensor.nlinalg import ( from pytensor.tensor.nlinalg import (
SVD, SVD,
Det, Det,
...@@ -104,18 +104,28 @@ def jax_funcify_BatchedDot(op, **kwargs): ...@@ -104,18 +104,28 @@ def jax_funcify_BatchedDot(op, **kwargs):
return batched_dot return batched_dot
@jax_funcify.register(MaxAndArgmax) @jax_funcify.register(Max)
def jax_funcify_MaxAndArgmax(op, **kwargs): def jax_funcify_Max(op, **kwargs):
axis = op.axis axis = op.axis
def maxandargmax(x, axis=axis): def max(x):
max_res = jnp.max(x, axis)
return max_res
return max
@jax_funcify.register(Argmax)
def jax_funcify_Argmax(op, **kwargs):
axis = op.axis
def argmax(x):
if axis is None: if axis is None:
axes = tuple(range(x.ndim)) axes = tuple(range(x.ndim))
else: else:
axes = tuple(int(ax) for ax in axis) axes = tuple(int(ax) for ax in axis)
max_res = jnp.max(x, axis)
# NumPy does not support multiple axes for argmax; this is a # NumPy does not support multiple axes for argmax; this is a
# work-around # work-around
keep_axes = jnp.array( keep_axes = jnp.array(
...@@ -138,6 +148,6 @@ def jax_funcify_MaxAndArgmax(op, **kwargs): ...@@ -138,6 +148,6 @@ def jax_funcify_MaxAndArgmax(op, **kwargs):
max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64") max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
return max_res, max_idx_res return max_idx_res
return maxandargmax return argmax
...@@ -44,7 +44,7 @@ from pytensor.scalar.basic import ( ...@@ -44,7 +44,7 @@ from pytensor.scalar.basic import (
) )
from pytensor.scalar.basic import add as add_as from pytensor.scalar.basic import add as add_as
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from pytensor.tensor.type import scalar from pytensor.tensor.type import scalar
...@@ -827,8 +827,8 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): ...@@ -827,8 +827,8 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
return log_softmax return log_softmax
@numba_funcify.register(MaxAndArgmax) @numba_funcify.register(Argmax)
def numba_funcify_MaxAndArgmax(op, node, **kwargs): def numba_funcify_Argmax(op, node, **kwargs):
axis = op.axis axis = op.axis
x_at = node.inputs[0] x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype x_dtype = x_at.type.numpy_dtype
...@@ -838,8 +838,8 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs): ...@@ -838,8 +838,8 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
if x_ndim == 0: if x_ndim == 0:
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def maxandargmax(x): def argmax(x):
return x, 0 return 0
else: else:
axes = tuple(int(ax) for ax in axis) axes = tuple(int(ax) for ax in axis)
...@@ -848,20 +848,6 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs): ...@@ -848,20 +848,6 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
# work-around # work-around
keep_axes = tuple(i for i in range(x_ndim) if i not in axes) keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max_py_fn = create_multiaxis_reducer(
scalar_maximum,
-np.inf,
axes,
x_ndim,
x_dtype,
return_scalar=False,
)
reduce_max = jit_compile_reducer(
Apply(node.op, node.inputs, [node.outputs[0].clone()]),
reduce_max_py_fn,
reduce_to_scalar=False,
)
reduced_x_ndim = x_ndim - len(axes) + 1 reduced_x_ndim = x_ndim - len(axes) + 1
argmax_axis = create_axis_apply_fn( argmax_axis = create_axis_apply_fn(
np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64 np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
...@@ -872,9 +858,7 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs): ...@@ -872,9 +858,7 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
sl2 = slice(len(keep_axes), None) sl2 = slice(len(keep_axes), None)
@numba_basic.numba_njit @numba_basic.numba_njit
def maxandargmax(x): def argmax(x):
max_res = reduce_max(x)
# Not-reduced axes in front # Not-reduced axes in front
transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order)) transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
kept_shape = transposed_x.shape[sl1] kept_shape = transposed_x.shape[sl1]
...@@ -890,6 +874,6 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs): ...@@ -890,6 +874,6 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
max_idx_res = argmax_axis(reshaped_x) max_idx_res = argmax_axis(reshaped_x)
return max_res, max_idx_res return max_idx_res
return maxandargmax return argmax
差异被折叠。
...@@ -35,31 +35,12 @@ from pytensor import scalar as ps ...@@ -35,31 +35,12 @@ from pytensor import scalar as ps
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import Alloc, alloc, constant from pytensor.tensor.basic import Alloc, alloc, constant
from pytensor.tensor.elemwise import CAReduce, DimShuffle from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg from pytensor.tensor.math import Min, neg
from pytensor.tensor.rewriting.basic import register_uncanonicalize from pytensor.tensor.rewriting.basic import register_uncanonicalize
from pytensor.tensor.shape import Reshape, reshape from pytensor.tensor.shape import Reshape, reshape
from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.subtensor import Subtensor
@register_uncanonicalize
@node_rewriter([MaxAndArgmax])
def local_max_and_argmax(fgraph, node):
"""
If we don't use the argmax, change it to a max only.
"""
if isinstance(node.op, MaxAndArgmax):
axis = node.op.axis
if len(fgraph.clients[node.outputs[1]]) == 0:
new = Max(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new)
return [new, None]
if len(fgraph.clients[node.outputs[0]]) == 0:
new = Argmax(axis)(node.inputs[0])
copy_stack_trace(node.outputs[0], new)
return [None, new]
@register_uncanonicalize @register_uncanonicalize
@node_rewriter([neg]) @node_rewriter([neg])
def local_max_to_min(fgraph, node): def local_max_to_min(fgraph, node):
...@@ -71,7 +52,7 @@ def local_max_to_min(fgraph, node): ...@@ -71,7 +52,7 @@ def local_max_to_min(fgraph, node):
Notes Notes
----- -----
We don't need an opt that will do the reverse as by default We don't need an opt that will do the reverse as by default
the interface put only MaxAndArgmax into the graph. the interface put only Max into the graph.
""" """
if node.op == neg and node.inputs[0].owner: if node.op == neg and node.inputs[0].owner:
......
...@@ -11,7 +11,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -11,7 +11,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker from pytensor.link.jax import JAXLinker
from pytensor.tensor import blas as pt_blas from pytensor.tensor import blas as pt_blas
from pytensor.tensor import nlinalg as pt_nlinalg from pytensor.tensor import nlinalg as pt_nlinalg
from pytensor.tensor.math import MaxAndArgmax, maximum from pytensor.tensor.math import Argmax, Max, maximum
from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import max as pt_max
from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -88,7 +88,8 @@ def test_jax_basic_multiout_omni(): ...@@ -88,7 +88,8 @@ def test_jax_basic_multiout_omni():
# Test that a single output of a multi-output `Op` can be used as input to # Test that a single output of a multi-output `Op` can be used as input to
# another `Op` # another `Op`
x = dvector() x = dvector()
mx, amx = MaxAndArgmax([0])(x) mx = Max([0])(x)
amx = Argmax([0])(x)
out = mx * amx out = mx * amx
out_fg = FunctionGraph([x], [out]) out_fg = FunctionGraph([x], [out])
compare_jax_and_py(out_fg, [np.r_[1, 2]]) compare_jax_and_py(out_fg, [np.r_[1, 2]])
......
...@@ -552,8 +552,53 @@ def test_LogSoftmax(x, axis, exc): ...@@ -552,8 +552,53 @@ def test_LogSoftmax(x, axis, exc):
), ),
], ],
) )
def test_MaxAndArgmax(x, axes, exc): def test_Max(x, axes, exc):
g = ptm.MaxAndArgmax(axes)(x) g = ptm.Max(axes)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)
@pytest.mark.parametrize(
"x, axes, exc",
[
(
set_test_value(pt.dscalar(), np.array(0.0, dtype="float64")),
[],
None,
),
(
set_test_value(pt.dvector(), rng.random(size=(3,)).astype("float64")),
[0],
None,
),
(
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0],
None,
),
(
set_test_value(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
[0, 1],
None,
),
],
)
def test_Argmax(x, axes, exc):
g = ptm.Argmax(axes)(x)
if isinstance(g, list): if isinstance(g, list):
g_fg = FunctionGraph(outputs=g) g_fg = FunctionGraph(outputs=g)
......
...@@ -38,7 +38,7 @@ from pytensor.tensor.blockwise import Blockwise ...@@ -38,7 +38,7 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import ( from pytensor.tensor.math import (
Dot, Dot,
MaxAndArgmax, Max,
Prod, Prod,
Sum, Sum,
_conj, _conj,
...@@ -3730,8 +3730,8 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None): ...@@ -3730,8 +3730,8 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None):
return return
# In mode FAST_COMPILE, the rewrites don't replace the # In mode FAST_COMPILE, the rewrites don't replace the
# `MaxAndArgmax` `Op`. # `Max` `Op`.
if isinstance(node.op, MaxAndArgmax): if isinstance(node.op, Max):
return return
# TODO FIXME: Refactor this test so that it makes a direct assertion and # TODO FIXME: Refactor this test so that it makes a direct assertion and
......
...@@ -9,8 +9,6 @@ from pytensor.graph.fg import FunctionGraph ...@@ -9,8 +9,6 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import out2in from pytensor.graph.rewriting.basic import out2in
from pytensor.link.basic import PerformLinker from pytensor.link.basic import PerformLinker
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import MaxAndArgmax, max_and_argmax
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import min as pt_min from pytensor.tensor.math import min as pt_min
from pytensor.tensor.rewriting.uncanonicalize import ( from pytensor.tensor.rewriting.uncanonicalize import (
local_alloc_dimshuffle, local_alloc_dimshuffle,
...@@ -23,67 +21,12 @@ from pytensor.tensor.type import dtensor4, iscalar, matrix, tensor, vector ...@@ -23,67 +21,12 @@ from pytensor.tensor.type import dtensor4, iscalar, matrix, tensor, vector
from tests.link.test_link import make_function from tests.link.test_link import make_function
class TestMaxAndArgmax:
def test_optimization(self):
# If we use only the max output, we should replace this op with
# a faster one.
mode = pytensor.compile.mode.get_default_mode().including(
"canonicalize", "fast_run"
)
for axis in [0, 1, -1]:
n = matrix()
f = function([n], max_and_argmax(n, axis)[0], mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce)
f = function([n], max_and_argmax(n, axis), mode=mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, MaxAndArgmax)
class TestMinMax: class TestMinMax:
def setup_method(self): def setup_method(self):
self.mode = pytensor.compile.mode.get_default_mode().including( self.mode = pytensor.compile.mode.get_default_mode().including(
"canonicalize", "fast_run" "canonicalize", "fast_run"
) )
def test_optimization_max(self):
data = np.asarray(np.random.random((2, 3)), dtype=config.floatX)
n = matrix()
for axis in [0, 1, -1]:
f = function([n], pt_max(n, axis), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce)
f(data)
f = function([n], pt_max(-n, axis), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert isinstance(topo[0].op, Elemwise)
assert isinstance(topo[0].op.scalar_op, ps.Neg)
assert isinstance(topo[1].op, CAReduce)
f(data)
f = function([n], -pt_max(n, axis), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 2
assert isinstance(topo[0].op, CAReduce)
assert isinstance(topo[1].op, Elemwise)
assert isinstance(topo[1].op.scalar_op, ps.Neg)
f(data)
f = function([n], -pt_max(-n, axis), mode=self.mode)
topo = f.maker.fgraph.toposort()
assert len(topo) == 1
assert isinstance(topo[0].op, CAReduce) # min
f(data)
def test_optimization_min(self): def test_optimization_min(self):
data = np.asarray(np.random.random((2, 3)), dtype=config.floatX) data = np.asarray(np.random.random((2, 3)), dtype=config.floatX)
n = matrix() n = matrix()
......
...@@ -11,6 +11,7 @@ import scipy.special ...@@ -11,6 +11,7 @@ import scipy.special
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
from scipy.special import logsumexp as scipy_logsumexp from scipy.special import logsumexp as scipy_logsumexp
import pytensor
import pytensor.scalar as ps import pytensor.scalar as ps
from pytensor.compile.debugmode import DebugMode from pytensor.compile.debugmode import DebugMode
from pytensor.compile.function import function from pytensor.compile.function import function
...@@ -39,7 +40,7 @@ from pytensor.tensor.elemwise import CAReduce, Elemwise ...@@ -39,7 +40,7 @@ from pytensor.tensor.elemwise import CAReduce, Elemwise
from pytensor.tensor.math import ( from pytensor.tensor.math import (
Argmax, Argmax,
Dot, Dot,
MaxAndArgmax, Max,
Mean, Mean,
Prod, Prod,
ProdWithoutZeros, ProdWithoutZeros,
...@@ -760,11 +761,12 @@ def test_isnan(): ...@@ -760,11 +761,12 @@ def test_isnan():
class TestMaxAndArgmax: class TestMaxAndArgmax:
def setup_method(self): def setup_method(self):
MaxAndArgmax.debug = 0 Max.debug = 0
Argmax.debug = 0
def test_basic(self): def test_basic(self):
n = as_tensor_variable(5.0) n = as_tensor_variable(5)
v, i = eval_outputs(max_and_argmax(n)) v, i = eval_outputs(max_and_argmax(n, axis=()))
assert v == 5.0 assert v == 5.0
assert i == 0 assert i == 0
assert i.dtype == "int64" assert i.dtype == "int64"
...@@ -1030,31 +1032,45 @@ class TestMaxAndArgmax: ...@@ -1030,31 +1032,45 @@ class TestMaxAndArgmax:
x = tensor(shape=(5, 5, 5, 5)) x = tensor(shape=(5, 5, 5, 5))
batch_x = tensor(shape=(3, 5, 5, 5, 5)) batch_x = tensor(shape=(3, 5, 5, 5, 5))
# Test MaxAndArgmax argmax_x = argmax(x, axis=core_axis)
max_x, argmax_x = max_and_argmax(x, axis=core_axis)
node = max_x.owner
assert isinstance(node.op, MaxAndArgmax)
new_node = vectorize_node(node, batch_x) arg_max_node = argmax_x.owner
assert isinstance(new_node.op, MaxAndArgmax) new_node = vectorize_node(arg_max_node, batch_x)
assert new_node.op.axis == batch_axis
# Test Argmax
# Argmax is not user-facing, so we have to create it manually
node = Argmax(axis=node.op.axis).make_node(x)
new_node = vectorize_node(node, batch_x)
assert isinstance(new_node.op, Argmax) assert isinstance(new_node.op, Argmax)
assert new_node.op.axis == batch_axis assert new_node.op.axis == batch_axis
def test_max_empty_axis(self):
x = np.random.normal(size=(2, 3, 5, 7))
axis = ()
non_axis = tuple(i for i in range(x.ndim) if i not in axis)
shape_axis = tuple(x.shape[dim] for dim in axis)
shape_non_axis = tuple(x.shape[dim] for dim in non_axis)
x_transposed = x.transpose(*axis, *non_axis)
x_axis_raveled = x_transposed.reshape(
np.prod(shape_axis, dtype=int), np.prod(shape_non_axis, dtype=int)
)
max_x = max_and_argmax(x, axis=axis)[0].eval()
argmax_x = max_and_argmax(x, axis=axis)[1].eval()
raveled_max = x_axis_raveled[
argmax_x.ravel(), np.arange(np.prod(shape_non_axis, dtype=int))
]
indirect_max = raveled_max.reshape(shape_non_axis)
np.testing.assert_allclose(max_x, x.max(axis=axis))
np.testing.assert_allclose(indirect_max, x.max(axis=axis))
class TestArgminArgmax: class TestArgminArgmax:
def setup_method(self): def setup_method(self):
MaxAndArgmax.debug = 0 Argmax.debug = 0
def test_scalar(self): def test_scalar(self):
for fct in [argmin, argmax]: for fct in [argmin, argmax]:
n = as_tensor_variable(5.0) n = as_tensor_variable([5.0])
i = eval_outputs(fct(n)) i = eval_outputs(fct(n))
assert i == 0 assert i == 0
v = eval_outputs(fct(n).shape) v = eval_outputs(fct(n).shape)
...@@ -1212,7 +1228,7 @@ class TestArgminArgmax: ...@@ -1212,7 +1228,7 @@ class TestArgminArgmax:
class TestMinMax: class TestMinMax:
def setup_method(self): def setup_method(self):
MaxAndArgmax.debug = 0 Max.debug = 0
def test_scalar(self): def test_scalar(self):
for fct in [max, min]: for fct in [max, min]:
...@@ -1379,6 +1395,7 @@ class TestMinMax: ...@@ -1379,6 +1395,7 @@ class TestMinMax:
# check_grad_max(data, eval_outputs(grad(max_and_argmax(n, # check_grad_max(data, eval_outputs(grad(max_and_argmax(n,
# axis=1)[0], n)),axis=1) # axis=1)[0], n)),axis=1)
@pytest.mark.xfail(reason="Fails due to #770")
def test_uint(self): def test_uint(self):
for dtype in ("uint8", "uint16", "uint32", "uint64"): for dtype in ("uint8", "uint16", "uint32", "uint64"):
itype = np.iinfo(dtype) itype = np.iinfo(dtype)
...@@ -1404,6 +1421,14 @@ class TestMinMax: ...@@ -1404,6 +1421,14 @@ class TestMinMax:
assert np.all(i) assert np.all(i)
def test_MaxAndArgmax_deprecated():
with pytest.raises(
AttributeError,
match="The class `MaxandArgmax` has been deprecated. Call `Max` and `Argmax` seperately as an alternative.",
):
pytensor.tensor.math.MaxAndArgmax
rng = np.random.default_rng(seed=utt.fetch_seed()) rng = np.random.default_rng(seed=utt.fetch_seed())
TestClip1 = makeTester( TestClip1 = makeTester(
name="ClipTester", name="ClipTester",
...@@ -2572,27 +2597,50 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2572,27 +2597,50 @@ class TestInferShape(utt.InferShapeTester):
[adtens3], [Mean(aiscal_val)(adtens3)], [adtens3_val], Mean [adtens3], [Mean(aiscal_val)(adtens3)], [adtens3_val], Mean
) )
def test_MaxAndArgmax(self): def test_Max(self):
adtens3 = dtensor3()
adtens3_val = random(4, 5, 3)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, None), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 0), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 1), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 2), [adtens3_val], Max
)
self._compile_and_check(
[adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], Max
)
def test_Argmax(self):
adtens3 = dtensor3() adtens3 = dtensor3()
adtens3_val = random(4, 5, 3) adtens3_val = random(4, 5, 3)
self._compile_and_check( self._compile_and_check(
[adtens3], max_and_argmax(adtens3, None), [adtens3_val], MaxAndArgmax [adtens3], max_and_argmax(adtens3, None), [adtens3_val], Argmax
) )
self._compile_and_check( self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 0), [adtens3_val], MaxAndArgmax [adtens3], max_and_argmax(adtens3, 0), [adtens3_val], Argmax
) )
self._compile_and_check( self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 1), [adtens3_val], MaxAndArgmax [adtens3], max_and_argmax(adtens3, 1), [adtens3_val], Argmax
) )
self._compile_and_check( self._compile_and_check(
[adtens3], max_and_argmax(adtens3, 2), [adtens3_val], MaxAndArgmax [adtens3], max_and_argmax(adtens3, 2), [adtens3_val], Argmax
) )
self._compile_and_check( self._compile_and_check(
[adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], MaxAndArgmax [adtens3], max_and_argmax(adtens3, [0, 1, 2]), [adtens3_val], Argmax
) )
def test_Dot(self): def test_Dot(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论