提交 72e67d83 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Brandon T. Willard

Check Split axis is a scalar

上级 e5373112
......@@ -1892,7 +1892,7 @@ class Split(COp):
if splits.type.ndim == 1 and splits.type.dtype not in integer_dtypes:
raise TypeError("`splits` parameter must be tensors of integer type")
if axis.type.dtype not in integer_dtypes:
if axis.type.dtype not in integer_dtypes or axis.ndim != 0:
raise TypeError("`axis` parameter must be an integer scalar")
inputs = [x, axis, splits]
......
......@@ -1213,6 +1213,9 @@ class TestJoinAndSplit:
with pytest.raises(TypeError, match=".*integer.*"):
Split(2)(matrix(), dscalar(), [1, 1])
with pytest.raises(TypeError, match=".*integer.*"):
Split(2)(matrix(), ivector(), [1, 1])
with pytest.raises(TypeError, match=".*integer.*"):
join(dscalar(), matrix(), matrix())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论