提交 17c4dbdd authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Added new utility function 'split'

上级 921d129c
...@@ -1358,6 +1358,10 @@ class SetSubtensor(Subtensor): ...@@ -1358,6 +1358,10 @@ class SetSubtensor(Subtensor):
x.__setitem__(cdata, y) x.__setitem__(cdata, y)
out[0] = x out[0] = x
def split(x, splits_size, n_splits, axis=0):
the_split = Split(n_splits)
return the_split(x, axis, splits_size)
class Split(Op): class Split(Op):
"""Partition a `TensorResult` along some axis. """Partition a `TensorResult` along some axis.
...@@ -1366,9 +1370,9 @@ class Split(Op): ...@@ -1366,9 +1370,9 @@ class Split(Op):
x = vector() x = vector()
splits = lvector() splits = lvector()
# you have to declare right away how many split_points there will be. # you have to declare right away how many split_points there will be.
ra, rb, rc = split(x, axis=0, points=splits, n_splits=3) ra, rb, rc = split(x, splits, n_splits = 3, axis = 0)
f = compile([x, splits], [ra, rb, rc]) f = function([x, splits], [ra, rb, rc])
a, b, c = f([0,1,2,3,4,5,6], [3, 2, 1]) a, b, c = f([0,1,2,3,4,5,6], [3, 2, 1])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论