提交 cd3a3ce2 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Brandon T. Willard

Fix Split output type

上级 72e67d83
......@@ -1896,7 +1896,8 @@ class Split(COp):
raise TypeError("`axis` parameter must be an integer scalar")
inputs = [x, axis, splits]
outputs = [x.type() for i in range(self.len_splits)]
out_type = TensorType(dtype=x.dtype, shape=[None] * x.type.ndim)
outputs = [out_type() for i in range(self.len_splits)]
return Apply(self, inputs, outputs)
......
......@@ -1937,6 +1937,12 @@ class TestJoinAndSplit:
with pytest.raises(ValueError):
f()
def test_split_static_shape(self):
x = TensorType("floatX", shape=(5,))("x")
s = iscalar("s")
y = Split(2)(x, 0, [s, 5 - s])[0]
assert y.type.shape == (None,)
def test_join_inplace():
# Test join to work inplace.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论