提交 75b7233e authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Fix error with scalar inputs in CAReduce and squeeze

上级 00e0d806
......@@ -1304,7 +1304,11 @@ class CAReduce(COp):
axis = list(range(inp_dims))
copy_op = any(a < 0 for a in axis)
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=inp_dims)
# scalar inputs are treated as 1D regarding axis in this `Op`
try:
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, inp_dims))
except np.AxisError:
raise np.AxisError(axis, ndim=inp_dims)
# We can't call self.__class__() as there is a class that
# inherits from CAReduce that doesn't have the same signature
......
......@@ -640,7 +640,11 @@ def squeeze(x, axis=None):
elif not isinstance(axis, Collection):
axis = (axis,)
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=x.ndim)
# scalar inputs are treated as 1D regarding axis in this `Op`
try:
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, x.ndim))
except np.AxisError:
raise np.AxisError(axis, ndim=x.ndim)
return x.dimshuffle([i for i in range(x.ndim) if i not in axis])
......
import math
import re
import tracemalloc
from copy import copy
......@@ -638,6 +639,17 @@ class TestCAReduce(unittest_tools.InferShapeTester):
with pytest.raises(ValueError, match="repeated axis"):
self.op(aes.add, axis=(0, 0))(x)
def test_scalar_input(self):
x = scalar("x")
assert self.op(aes.add, axis=(-1,))(x).eval({x: 5}) == 5
with pytest.raises(
np.AxisError,
match=re.escape("axis (-2,) is out of bounds for array of dimension 0"),
):
self.op(aes.add, axis=(-2,))(x)
class TestBitOpReduceGrad:
def setup_method(self):
......
import re
import numpy as np
import pytest
......@@ -448,6 +450,17 @@ class TestSqueeze(utt.InferShapeTester):
):
squeeze(variable, axis=1)
def test_scalar_input(self):
x = at.scalar("x")
assert squeeze(x, axis=(0,)).eval({x: 5}) == 5
with pytest.raises(
np.AxisError,
match=re.escape("axis (1,) is out of bounds for array of dimension 0"),
):
squeeze(x, axis=1)
class TestCompress(utt.InferShapeTester):
axis_list = [None, -1, 0, 0, 0, 1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论