提交 3425b9d3 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op Split

上级 c3ce2054
......@@ -4653,6 +4653,18 @@ class Split(Op):
outputs[i][0] = x.__getitem__(general_key).copy()
lower_idx = upper_idx
def infer_shape(self, node, in_shapes):
axis = node.inputs[1]
splits = node.inputs[2]
shp_x, shp_axis, shp_splits = in_shapes
out_shapes = []
for i in range(self.len_splits):
temp = as_tensor_variable(shp_x)
temp = set_subtensor(temp[axis], splits[i])
temp = [temp[i] for i in range(len(shp_x))]
out_shapes.append(temp)
return out_shapes
def grad(self, inputs, g_outputs):
"""Join the gradients along the axis that was used to split x."""
_, axis, _ = inputs
......
......@@ -6102,6 +6102,14 @@ class TestInferShape(utt.InferShapeTester):
[Dot()(admat, bdmat)],
[admat_val, bdmat_val], (Dot, tensor.blas.Gemm,
tensor.blas.Dot22))
# Split
aivec = ivector()
adtens_val = rand(4, 10, 3)
aivec_val = [2, 5, 3]
self._compile_and_check([adtens, aiscal, aivec],
[Split(3)(adtens, aiscal, aivec)[0]],
[adtens_val, 1, aivec_val], (Split))
if __name__ == '__main__':
t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论