提交 de6aca85 作者: Ricardo Vieira 提交者: Ricardo Vieira

Infer n_splits in split helper

上级 3a7f2141
...@@ -2176,9 +2176,13 @@ def matrix_transpose(x: "TensorLike") -> TensorVariable: ...@@ -2176,9 +2176,13 @@ def matrix_transpose(x: "TensorLike") -> TensorVariable:
return swapaxes(x, -1, -2) return swapaxes(x, -1, -2)
def split(x, splits_size, n_splits, axis=0): def split(x, splits_size, *, n_splits=None, axis=0):
the_split = Split(n_splits) if n_splits is None:
return the_split(x, axis, splits_size) if isinstance(splits_size, Variable):
n_splits = get_vector_length(splits_size)
else:
n_splits = len(splits_size)
return Split(n_splits)(x, axis, splits_size)
class Split(COp): class Split(COp):
......
...@@ -116,8 +116,8 @@ class Fourier(Op): ...@@ -116,8 +116,8 @@ class Fourier(Op):
l = len(shape_a) l = len(shape_a)
shape_a = stack(shape_a) shape_a = stack(shape_a)
out_shape = concatenate((shape_a[0:axis], [n], shape_a[axis + 1 :])) out_shape = concatenate((shape_a[0:axis], [n], shape_a[axis + 1 :]))
n_splits = [1] * l splits = [1] * l
out_shape = split(out_shape, n_splits, l) out_shape = split(out_shape, splits, n_splits=l)
out_shape = [a[0] for a in out_shape] out_shape = [a[0] for a in out_shape]
return [out_shape] return [out_shape]
......
...@@ -148,7 +148,7 @@ def test_empty_dynamic_shape(): ...@@ -148,7 +148,7 @@ def test_empty_dynamic_shape():
def test_split_const_axis_const_splits_compiled(): def test_split_const_axis_const_splits_compiled():
x = pt.vector("x") x = pt.vector("x")
splits = [2, 3] splits = [2, 3]
outs = pt.split(x, splits, len(splits), axis=0) outs = pt.split(x, splits, n_splits=len(splits), axis=0)
compare_mlx_and_py([x], outs, [np.arange(5, dtype="float32")]) compare_mlx_and_py([x], outs, [np.arange(5, dtype="float32")])
...@@ -156,7 +156,7 @@ def test_split_dynamic_axis_const_splits(): ...@@ -156,7 +156,7 @@ def test_split_dynamic_axis_const_splits():
x = pt.matrix("x") x = pt.matrix("x")
axis = pt.scalar("axis", dtype="int64") axis = pt.scalar("axis", dtype="int64")
splits = [1, 2, 3] splits = [1, 2, 3]
outs = pt.split(x, splits, len(splits), axis=axis) outs = pt.split(x, splits, n_splits=len(splits), axis=axis)
test_input = np.arange(12).astype(config.floatX).reshape(2, 6) test_input = np.arange(12).astype(config.floatX).reshape(2, 6)
......
...@@ -209,7 +209,7 @@ def test_Join(vals, axis): ...@@ -209,7 +209,7 @@ def test_Join(vals, axis):
def test_Split(n_splits, axis, values, sizes): def test_Split(n_splits, axis, values, sizes):
values, values_test = values values, values_test = values
sizes, sizes_test = sizes sizes, sizes_test = sizes
g = pt.split(values, sizes, n_splits, axis=axis) g = pt.split(values, sizes, n_splits=n_splits, axis=axis)
assert len(g) == n_splits assert len(g) == n_splits
if n_splits == 0: if n_splits == 0:
return return
......
...@@ -518,7 +518,7 @@ rng = np.random.default_rng(42849) ...@@ -518,7 +518,7 @@ rng = np.random.default_rng(42849)
def test_Split(n_splits, axis, values, sizes): def test_Split(n_splits, axis, values, sizes):
i = pt.tensor("i", shape=values.shape, dtype=config.floatX) i = pt.tensor("i", shape=values.shape, dtype=config.floatX)
s = pt.vector("s", dtype="int64") s = pt.vector("s", dtype="int64")
g = pt.split(i, s, n_splits, axis=axis) g = pt.split(i, s, n_splits=n_splits, axis=axis)
assert len(g) == n_splits assert len(g) == n_splits
if n_splits == 0: if n_splits == 0:
return return
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论