提交 8f96d930 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Validate axis in `CAReduce.make_node`

Closes #677
上级 39a455d3
......@@ -1298,21 +1298,13 @@ class CAReduce(COp):
inp_dims = input.type.ndim
inp_bdcast = input.type.broadcastable
inp_dtype = input.type.dtype
copy_op = False
axis = self.axis
if axis is None:
axis = list(range(len(inp_bdcast)))
axis = list(range(inp_dims))
axis = list(axis)
for i, a in enumerate(axis):
if a >= inp_dims or a < -inp_dims:
raise ValueError(
f"Not enough dimensions on {input} to reduce on axis {a}"
)
if a < 0:
copy_op = True
axis[i] = a + inp_dims
copy_op = any(a < 0 for a in axis)
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=inp_dims)
# We can't call self.__class__() as there is a class that
# inherits from CAReduce that doesn't have the same signature
......
......@@ -633,6 +633,11 @@ class TestCAReduce(unittest_tools.InferShapeTester):
op = CAReduceDtype(aes.add, axis=(1,), acc_dtype="float64")
assert str(op) == "CAReduceDtype{add}{axis=[1], acc_dtype=float64}"
def test_repeated_axis(self):
x = vector("x")
with pytest.raises(ValueError, match="repeated axis"):
self.op(aes.add, axis=(0, 0))(x)
class TestBitOpReduceGrad:
def setup_method(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论