提交 53cad9b9 作者: Cove Geary 提交者: Ricardo Vieira

Propagate static shape in MaxAndArgmax

上级 e37497fd
......@@ -142,15 +142,10 @@ class MaxAndArgmax(COp):
def make_node(self, x):
x = as_tensor_variable(x)
# We keep the original broadcastable flags for dimensions on which
# we do not perform the max / argmax.
# Keep the original shapes for axes on which we do not perform the max/argmax.
all_axes = set(self.axis)
inputs = [x]
out_shape = tuple(
1 if s == 1 else None
for i, s in enumerate(x.type.shape)
if i not in all_axes
)
out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
outputs = [
tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
tensor(dtype="int64", shape=out_shape, name="argmax"),
......@@ -1521,7 +1516,6 @@ class Mean(CAReduce):
output[0] = np.asarray(np.mean(input, dtype="float64", axis=axis))
def c_code(self, node, name, inames, onames, sub):
ret = super().c_code(node, name, inames, onames, sub)
if self.axis is not None:
......@@ -1940,7 +1934,6 @@ class Dot(Op):
z[0] = np.asarray(np.dot(x, y))
def grad(self, inp, grads):
x, y = inp
(gz,) = grads
xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
......@@ -2631,7 +2624,6 @@ class Prod(CAReduce):
# this handles inputs with zeros, but only certain input shapes
return [grad_case_without_zeros]
else:
where_zeros = eq(prod_in, 0.0)
sum_where_zeros = sum(where_zeros, axis=self.axis)
groups_with_single_zero = eq(sum_where_zeros, 1).dimshuffle(new_dims)
......@@ -2924,7 +2916,6 @@ class MatMul(Op):
)
return x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
else:
if validate:
from pytensor.tensor.random.basic import broadcast_shapes
......
......@@ -771,10 +771,9 @@ class TestMaxAndArgmax:
v = eval_outputs(max_and_argmax(n)[0].shape)
assert len(v) == 0
def test_basic_2(self):
data = random(2, 3)
n = as_tensor_variable(data)
for (axis, np_axis) in [
@pytest.mark.parametrize(
"axis,np_axis",
[
(-1, -1),
(0, 0),
(1, 1),
......@@ -783,19 +782,28 @@ class TestMaxAndArgmax:
([1, 0], None),
(NoneConst.clone(), None),
(constant(0), 0),
]:
v, i = eval_outputs(max_and_argmax(n, axis))
],
)
def test_basic_2(self, axis, np_axis):
data = random(2, 3)
n = as_tensor_variable(data)
# Test shape propagates (static & eval)
vt, it = max_and_argmax(n, axis)
np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis)
assert vt.type.shape == np_max.shape
assert it.type.shape == np_argm.shape
v_shape, i_shape = eval_outputs([vt.shape, it.shape])
assert tuple(v_shape) == vt.type.shape
assert tuple(i_shape) == it.type.shape
# Test values
v, i = eval_outputs([vt, it])
assert i.dtype == "int64"
assert np.all(v == np.max(data, np_axis))
assert np.all(i == np.argmax(data, np_axis))
v_shape = eval_outputs(max_and_argmax(n, axis)[0].shape)
assert tuple(v_shape) == np.max(data, np_axis).shape
assert np.all(v == np_max)
assert np.all(i == np_argm)
def test_basic_2_float16(self):
# Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
data = (random(20, 30).astype("float16") - 0.5) * 20
n = shared(data)
for (axis, np_axis) in [
@pytest.mark.parametrize(
"axis,np_axis",
[
(-1, -1),
(0, 0),
(1, 1),
......@@ -804,13 +812,25 @@ class TestMaxAndArgmax:
([1, 0], None),
(NoneConst.clone(), None),
(constant(0), 0),
]:
v, i = eval_outputs(max_and_argmax(n, axis), (MaxAndArgmax,))
],
)
def test_basic_2_float16(self, axis, np_axis):
# Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
data = (random(20, 30).astype("float16") - 0.5) * 20
n = as_tensor_variable(data)
# Test shape propagates (static & eval)
vt, it = max_and_argmax(n, axis)
np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis)
assert vt.type.shape == np_max.shape
assert it.type.shape == np_argm.shape
v_shape, i_shape = eval_outputs([vt.shape, it.shape])
assert tuple(v_shape) == vt.type.shape
assert tuple(i_shape) == it.type.shape
# Test values
v, i = eval_outputs([vt, it])
assert i.dtype == "int64"
assert np.all(v == np.max(data, np_axis))
assert np.all(i == np.argmax(data, np_axis))
v_shape = eval_outputs(max_and_argmax(n, axis)[0].shape)
assert tuple(v_shape) == np.max(data, np_axis).shape
assert np.all(v == np_max)
assert np.all(i == np_argm)
def test_basic_2_invalid(self):
n = as_tensor_variable(random(2, 3))
......@@ -840,23 +860,33 @@ class TestMaxAndArgmax:
v = eval_outputs(max_and_argmax(n, -2)[0].shape)
assert v == (3)
def test_basic_3(self):
data = random(2, 3, 4)
n = as_tensor_variable(data)
for (axis, np_axis) in [
@pytest.mark.parametrize(
"axis,np_axis",
[
(-1, -1),
(0, 0),
(1, 1),
(None, None),
([0, 1, 2], None),
([1, 2, 0], None),
]:
v, i = eval_outputs(max_and_argmax(n, axis))
],
)
def test_basic_3(self, axis, np_axis):
data = random(2, 3, 4)
n = as_tensor_variable(data)
# Test shape propagates (static & eval)
vt, it = max_and_argmax(n, axis)
np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis)
assert vt.type.shape == np_max.shape
assert it.type.shape == np_argm.shape
v_shape, i_shape = eval_outputs([vt.shape, it.shape])
assert tuple(v_shape) == vt.type.shape
assert tuple(i_shape) == it.type.shape
# Test values
v, i = eval_outputs([vt, it])
assert i.dtype == "int64"
assert np.all(v == np.max(data, np_axis))
assert np.all(i == np.argmax(data, np_axis))
v = eval_outputs(max_and_argmax(n, axis)[0].shape)
assert tuple(v) == np.max(data, np_axis).shape
assert np.all(v == np_max)
assert np.all(i == np_argm)
def test_arg_grad(self):
# The test checks that the gradient of argmax(x).sum() is 0
......@@ -948,17 +978,19 @@ class TestMaxAndArgmax:
# Ensure the original broadcastable flags are preserved by Max/Argmax.
x = matrix().dimshuffle("x", 0, "x", 1, "x")
y = x.max(axis=1)
assert y.type.shape == (1, 1, None, 1)
assert y.type.broadcastable == (True, True, False, True)
def test_multiple_axes(self):
data = np.arange(24).reshape(3, 2, 4)
x = as_tensor_variable(data)
v, i = eval_outputs(max_and_argmax(x, [1, -1]))
vt, it = max_and_argmax(x, [1, -1])
assert vt.type.shape == it.type.shape == (3,)
v, i = eval_outputs([vt, it])
assert np.all(v == np.array([7, 15, 23]))
assert np.all(i == np.array([7, 7, 7]))
v = eval_outputs(max_and_argmax(x, [1, -1])[0].shape)
assert tuple(v) == np.max(data, (1, -1)).shape
v = eval_outputs(vt.shape)
assert tuple(v) == vt.type.shape
def test_zero_shape(self):
x = matrix()
......@@ -972,8 +1004,8 @@ class TestMaxAndArgmax:
def test_numpy_input(self):
ar = np.array([1, 2, 3])
max_at, argmax_at = max_and_argmax(ar, axis=None)
assert max_at.eval(), 3
assert argmax_at.eval(), 2
assert max_at.eval() == 3
assert argmax_at.eval() == 2
class TestArgminArgmax:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论