提交 75b7233e authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Fix error with scalar inputs in CAReduce and squeeze

上级 00e0d806
...@@ -1304,7 +1304,11 @@ class CAReduce(COp): ...@@ -1304,7 +1304,11 @@ class CAReduce(COp):
axis = list(range(inp_dims)) axis = list(range(inp_dims))
copy_op = any(a < 0 for a in axis) copy_op = any(a < 0 for a in axis)
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=inp_dims) # scalar inputs are treated as 1D regarding axis in this `Op`
try:
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, inp_dims))
except np.AxisError:
raise np.AxisError(axis, ndim=inp_dims)
# We can't call self.__class__() as there is a class that # We can't call self.__class__() as there is a class that
# inherits from CAReduce that doesn't have the same signature # inherits from CAReduce that doesn't have the same signature
......
...@@ -640,7 +640,11 @@ def squeeze(x, axis=None): ...@@ -640,7 +640,11 @@ def squeeze(x, axis=None):
elif not isinstance(axis, Collection): elif not isinstance(axis, Collection):
axis = (axis,) axis = (axis,)
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=x.ndim) # scalar inputs are treated as 1D regarding axis in this `Op`
try:
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, x.ndim))
except np.AxisError:
raise np.AxisError(axis, ndim=x.ndim)
return x.dimshuffle([i for i in range(x.ndim) if i not in axis]) return x.dimshuffle([i for i in range(x.ndim) if i not in axis])
......
import math import math
import re
import tracemalloc import tracemalloc
from copy import copy from copy import copy
...@@ -638,6 +639,17 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -638,6 +639,17 @@ class TestCAReduce(unittest_tools.InferShapeTester):
with pytest.raises(ValueError, match="repeated axis"): with pytest.raises(ValueError, match="repeated axis"):
self.op(aes.add, axis=(0, 0))(x) self.op(aes.add, axis=(0, 0))(x)
def test_scalar_input(self):
x = scalar("x")
assert self.op(aes.add, axis=(-1,))(x).eval({x: 5}) == 5
with pytest.raises(
np.AxisError,
match=re.escape("axis (-2,) is out of bounds for array of dimension 0"),
):
self.op(aes.add, axis=(-2,))(x)
class TestBitOpReduceGrad: class TestBitOpReduceGrad:
def setup_method(self): def setup_method(self):
......
import re
import numpy as np import numpy as np
import pytest import pytest
...@@ -448,6 +450,17 @@ class TestSqueeze(utt.InferShapeTester): ...@@ -448,6 +450,17 @@ class TestSqueeze(utt.InferShapeTester):
): ):
squeeze(variable, axis=1) squeeze(variable, axis=1)
def test_scalar_input(self):
x = at.scalar("x")
assert squeeze(x, axis=(0,)).eval({x: 5}) == 5
with pytest.raises(
np.AxisError,
match=re.escape("axis (1,) is out of bounds for array of dimension 0"),
):
squeeze(x, axis=1)
class TestCompress(utt.InferShapeTester): class TestCompress(utt.InferShapeTester):
axis_list = [None, -1, 0, 0, 0, 1] axis_list = [None, -1, 0, 0, 0, 1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论