提交 72e67d83 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Brandon T. Willard

Check Split axis is a scalar

上级 e5373112
...@@ -1892,7 +1892,7 @@ class Split(COp): ...@@ -1892,7 +1892,7 @@ class Split(COp):
if splits.type.ndim == 1 and splits.type.dtype not in integer_dtypes: if splits.type.ndim == 1 and splits.type.dtype not in integer_dtypes:
raise TypeError("`splits` parameter must be tensors of integer type") raise TypeError("`splits` parameter must be tensors of integer type")
if axis.type.dtype not in integer_dtypes: if axis.type.dtype not in integer_dtypes or axis.ndim != 0:
raise TypeError("`axis` parameter must be an integer scalar") raise TypeError("`axis` parameter must be an integer scalar")
inputs = [x, axis, splits] inputs = [x, axis, splits]
......
...@@ -1213,6 +1213,9 @@ class TestJoinAndSplit: ...@@ -1213,6 +1213,9 @@ class TestJoinAndSplit:
with pytest.raises(TypeError, match=".*integer.*"): with pytest.raises(TypeError, match=".*integer.*"):
Split(2)(matrix(), dscalar(), [1, 1]) Split(2)(matrix(), dscalar(), [1, 1])
with pytest.raises(TypeError, match=".*integer.*"):
Split(2)(matrix(), ivector(), [1, 1])
with pytest.raises(TypeError, match=".*integer.*"): with pytest.raises(TypeError, match=".*integer.*"):
join(dscalar(), matrix(), matrix()) join(dscalar(), matrix(), matrix())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论