提交 0758de4b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Sort and Argsort: Check axis are integers

上级 22cda11a
...@@ -42,13 +42,17 @@ class SortOp(Op): ...@@ -42,13 +42,17 @@ class SortOp(Op):
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
input = as_tensor_variable(input) input = as_tensor_variable(input)
axis = as_tensor_variable(axis, ndim=0, dtype=int) axis = as_tensor_variable(axis, ndim=0, dtype=int)
if axis.type.numpy_dtype.kind != "i":
raise ValueError(
f"Sort axis must have an integer dtype, got {axis.type.dtype}"
)
out_type = input.type() out_type = input.type()
return Apply(self, [input, axis], [out_type]) return Apply(self, [input, axis], [out_type])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
a, axis = inputs a, axis = inputs
z = output_storage[0] z = output_storage[0]
z[0] = np.sort(a, int(axis), self.kind) z[0] = np.sort(a, axis, self.kind)
def infer_shape(self, fgraph, node, inputs_shapes): def infer_shape(self, fgraph, node, inputs_shapes):
assert node.inputs[0].ndim == node.outputs[0].ndim assert node.inputs[0].ndim == node.outputs[0].ndim
...@@ -163,6 +167,10 @@ class ArgSortOp(Op): ...@@ -163,6 +167,10 @@ class ArgSortOp(Op):
def make_node(self, input, axis=-1): def make_node(self, input, axis=-1):
input = as_tensor_variable(input) input = as_tensor_variable(input)
axis = as_tensor_variable(axis, ndim=0, dtype=int) axis = as_tensor_variable(axis, ndim=0, dtype=int)
if axis.type.numpy_dtype.kind != "i":
raise ValueError(
f"ArgSort axis must have an integer dtype, got {axis.type.dtype}"
)
return Apply( return Apply(
self, self,
[input, axis], [input, axis],
...@@ -173,7 +181,7 @@ class ArgSortOp(Op): ...@@ -173,7 +181,7 @@ class ArgSortOp(Op):
a, axis = inputs a, axis = inputs
z = output_storage[0] z = output_storage[0]
z[0] = np.asarray( z[0] = np.asarray(
np.argsort(a, int(axis), self.kind), np.argsort(a, axis, self.kind),
dtype=node.outputs[0].dtype, dtype=node.outputs[0].dtype,
) )
......
...@@ -7,6 +7,7 @@ from pytensor.tensor.type import ( ...@@ -7,6 +7,7 @@ from pytensor.tensor.type import (
dmatrix, dmatrix,
dvector, dvector,
float_dtypes, float_dtypes,
fscalar,
integer_dtypes, integer_dtypes,
lscalar, lscalar,
matrix, matrix,
...@@ -31,6 +32,12 @@ class TestSort: ...@@ -31,6 +32,12 @@ class TestSort:
self.m_val = self.rng.random((3, 2)) self.m_val = self.rng.random((3, 2))
self.v_val = self.rng.random(4) self.v_val = self.rng.random(4)
def test_invalid_axis_dtype(self):
with pytest.raises(
ValueError, match="Sort axis must have an integer dtype, got float32"
):
sort(dmatrix(), fscalar())
def test1(self): def test1(self):
a = dmatrix() a = dmatrix()
w = sort(a) w = sort(a)
...@@ -39,7 +46,7 @@ class TestSort: ...@@ -39,7 +46,7 @@ class TestSort:
def test2(self): def test2(self):
a = dmatrix() a = dmatrix()
axis = scalar() axis = scalar(dtype="int64")
w = sort(a, axis) w = sort(a, axis)
f = pytensor.function([a, axis], w) f = pytensor.function([a, axis], w)
for axis_val in 0, 1: for axis_val in 0, 1:
...@@ -57,12 +64,12 @@ class TestSort: ...@@ -57,12 +64,12 @@ class TestSort:
def test4(self): def test4(self):
a = dmatrix() a = dmatrix()
axis = scalar() axis = scalar(dtype="int8")
l = sort(a, axis, "mergesort") l = sort(a, axis, "mergesort")
f = pytensor.function([a, axis], l) f = pytensor.function([a, axis], l)
for axis_val in 0, 1: for axis_val in 0, 1:
gv = f(self.m_val, axis_val) gv = f(self.m_val, np.array(axis_val, dtype="int8"))
gt = np.sort(self.m_val, axis_val) gt = np.sort(self.m_val, np.array(axis_val, dtype="int8"))
utt.assert_allclose(gv, gt) utt.assert_allclose(gv, gt)
def test5(self): def test5(self):
...@@ -199,12 +206,12 @@ def test_argsort(): ...@@ -199,12 +206,12 @@ def test_argsort():
# Example 4 # Example 4
a = dmatrix() a = dmatrix()
axis = lscalar() axis = scalar(dtype="int8")
l = argsort(a, axis, "mergesort") l = argsort(a, axis, "mergesort")
f = pytensor.function([a, axis], l) f = pytensor.function([a, axis], l)
for axis_val in 0, 1: for axis_val in 0, 1:
gv = f(m_val, axis_val) gv = f(m_val, np.array(axis_val, dtype="int8"))
gt = np.argsort(m_val, axis_val) gt = np.argsort(m_val, np.array(axis_val, dtype="int8"))
utt.assert_allclose(gv, gt) utt.assert_allclose(gv, gt)
# Example 5 # Example 5
...@@ -222,6 +229,11 @@ def test_argsort(): ...@@ -222,6 +229,11 @@ def test_argsort():
gt = np.argsort(m_val, None) gt = np.argsort(m_val, None)
utt.assert_allclose(gv, gt) utt.assert_allclose(gv, gt)
with pytest.raises(
ValueError, match="ArgSort axis must have an integer dtype, got float32"
):
argsort(dmatrix(), fscalar())
def test_argsort_grad(): def test_argsort_grad():
rng = np.random.default_rng(seed=utt.fetch_seed()) rng = np.random.default_rng(seed=utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论