提交 0758de4b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Sort and Argsort: Check axis are integers

上级 22cda11a
......@@ -42,13 +42,17 @@ class SortOp(Op):
def make_node(self, input, axis=-1):
input = as_tensor_variable(input)
axis = as_tensor_variable(axis, ndim=0, dtype=int)
if axis.type.numpy_dtype.kind != "i":
raise ValueError(
f"Sort axis must have an integer dtype, got {axis.type.dtype}"
)
out_type = input.type()
return Apply(self, [input, axis], [out_type])
def perform(self, node, inputs, output_storage):
a, axis = inputs
z = output_storage[0]
z[0] = np.sort(a, int(axis), self.kind)
z[0] = np.sort(a, axis, self.kind)
def infer_shape(self, fgraph, node, inputs_shapes):
assert node.inputs[0].ndim == node.outputs[0].ndim
......@@ -163,6 +167,10 @@ class ArgSortOp(Op):
def make_node(self, input, axis=-1):
input = as_tensor_variable(input)
axis = as_tensor_variable(axis, ndim=0, dtype=int)
if axis.type.numpy_dtype.kind != "i":
raise ValueError(
f"ArgSort axis must have an integer dtype, got {axis.type.dtype}"
)
return Apply(
self,
[input, axis],
......@@ -173,7 +181,7 @@ class ArgSortOp(Op):
a, axis = inputs
z = output_storage[0]
z[0] = np.asarray(
np.argsort(a, int(axis), self.kind),
np.argsort(a, axis, self.kind),
dtype=node.outputs[0].dtype,
)
......
......@@ -7,6 +7,7 @@ from pytensor.tensor.type import (
dmatrix,
dvector,
float_dtypes,
fscalar,
integer_dtypes,
lscalar,
matrix,
......@@ -31,6 +32,12 @@ class TestSort:
self.m_val = self.rng.random((3, 2))
self.v_val = self.rng.random(4)
def test_invalid_axis_dtype(self):
with pytest.raises(
ValueError, match="Sort axis must have an integer dtype, got float32"
):
sort(dmatrix(), fscalar())
def test1(self):
a = dmatrix()
w = sort(a)
......@@ -39,7 +46,7 @@ class TestSort:
def test2(self):
a = dmatrix()
axis = scalar()
axis = scalar(dtype="int64")
w = sort(a, axis)
f = pytensor.function([a, axis], w)
for axis_val in 0, 1:
......@@ -57,12 +64,12 @@ class TestSort:
def test4(self):
a = dmatrix()
axis = scalar()
axis = scalar(dtype="int8")
l = sort(a, axis, "mergesort")
f = pytensor.function([a, axis], l)
for axis_val in 0, 1:
gv = f(self.m_val, axis_val)
gt = np.sort(self.m_val, axis_val)
gv = f(self.m_val, np.array(axis_val, dtype="int8"))
gt = np.sort(self.m_val, np.array(axis_val, dtype="int8"))
utt.assert_allclose(gv, gt)
def test5(self):
......@@ -199,12 +206,12 @@ def test_argsort():
# Example 4
a = dmatrix()
axis = lscalar()
axis = scalar(dtype="int8")
l = argsort(a, axis, "mergesort")
f = pytensor.function([a, axis], l)
for axis_val in 0, 1:
gv = f(m_val, axis_val)
gt = np.argsort(m_val, axis_val)
gv = f(m_val, np.array(axis_val, dtype="int8"))
gt = np.argsort(m_val, np.array(axis_val, dtype="int8"))
utt.assert_allclose(gv, gt)
# Example 5
......@@ -222,6 +229,11 @@ def test_argsort():
gt = np.argsort(m_val, None)
utt.assert_allclose(gv, gt)
with pytest.raises(
ValueError, match="ArgSort axis must have an integer dtype, got float32"
):
argsort(dmatrix(), fscalar())
def test_argsort_grad():
rng = np.random.default_rng(seed=utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论