提交 7ede010a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Eager check for invalid ndim in TensorType

上级 611af883
...@@ -120,7 +120,12 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -120,7 +120,12 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
self.shape = _shape = tuple(parse_bcast_and_shape(s) for s in shape) self.shape = _shape = tuple(parse_bcast_and_shape(s) for s in shape)
self.broadcastable = tuple(s == 1 for s in _shape) self.broadcastable = tuple(s == 1 for s in _shape)
self.ndim = len(_shape) self.ndim = _ndim = len(_shape)
if _ndim > 64:
# Message mimicks that of numpy
raise ValueError(
f"maximum supported dimension for a TensorType is currently 64, found {_ndim}"
)
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
self.name = name self.name = name
self.numpy_dtype = np.dtype(self.dtype) self.numpy_dtype = np.dtype(self.dtype)
......
...@@ -181,12 +181,12 @@ class TestDimShuffle(unittest_tools.InferShapeTester): ...@@ -181,12 +181,12 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
) )
def test_too_big_rank(self): def test_too_big_rank(self):
numpy_maxdims = 64
x = self.type(self.dtype, shape=())() x = self.type(self.dtype, shape=())()
y = x.dimshuffle(("x",) * (numpy_maxdims + 1)) with pytest.raises(
ValueError,
with pytest.raises((ValueError, SystemError)): match="maximum supported dimension for a TensorType is currently 64, found 65",
y.eval({x: 0}) ):
x.dimshuffle(("x",) * 65)
def test_c_views(self): def test_c_views(self):
x_pt = vector() x_pt = vector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论