提交 7ede010a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Eager check for invalid ndim in TensorType

上级 611af883
......@@ -120,7 +120,12 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
self.shape = _shape = tuple(parse_bcast_and_shape(s) for s in shape)
self.broadcastable = tuple(s == 1 for s in _shape)
self.ndim = len(_shape)
self.ndim = _ndim = len(_shape)
if _ndim > 64:
# Message mimicks that of numpy
raise ValueError(
f"maximum supported dimension for a TensorType is currently 64, found {_ndim}"
)
self.dtype_specs() # error checking is done there
self.name = name
self.numpy_dtype = np.dtype(self.dtype)
......
......@@ -181,12 +181,12 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
)
def test_too_big_rank(self):
numpy_maxdims = 64
x = self.type(self.dtype, shape=())()
y = x.dimshuffle(("x",) * (numpy_maxdims + 1))
with pytest.raises((ValueError, SystemError)):
y.eval({x: 0})
with pytest.raises(
ValueError,
match="maximum supported dimension for a TensorType is currently 64, found 65",
):
x.dimshuffle(("x",) * 65)
def test_c_views(self):
x_pt = vector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论