提交 de6aca85 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Infer n_splits in split helper

上级 3a7f2141
......@@ -2176,9 +2176,13 @@ def matrix_transpose(x: "TensorLike") -> TensorVariable:
return swapaxes(x, -1, -2)
def split(x, splits_size, n_splits, axis=0):
the_split = Split(n_splits)
return the_split(x, axis, splits_size)
def split(x, splits_size, *, n_splits=None, axis=0):
if n_splits is None:
if isinstance(splits_size, Variable):
n_splits = get_vector_length(splits_size)
else:
n_splits = len(splits_size)
return Split(n_splits)(x, axis, splits_size)
class Split(COp):
......
......@@ -116,8 +116,8 @@ class Fourier(Op):
l = len(shape_a)
shape_a = stack(shape_a)
out_shape = concatenate((shape_a[0:axis], [n], shape_a[axis + 1 :]))
n_splits = [1] * l
out_shape = split(out_shape, n_splits, l)
splits = [1] * l
out_shape = split(out_shape, splits, n_splits=l)
out_shape = [a[0] for a in out_shape]
return [out_shape]
......
......@@ -148,7 +148,7 @@ def test_empty_dynamic_shape():
def test_split_const_axis_const_splits_compiled():
x = pt.vector("x")
splits = [2, 3]
outs = pt.split(x, splits, len(splits), axis=0)
outs = pt.split(x, splits, n_splits=len(splits), axis=0)
compare_mlx_and_py([x], outs, [np.arange(5, dtype="float32")])
......@@ -156,7 +156,7 @@ def test_split_dynamic_axis_const_splits():
x = pt.matrix("x")
axis = pt.scalar("axis", dtype="int64")
splits = [1, 2, 3]
outs = pt.split(x, splits, len(splits), axis=axis)
outs = pt.split(x, splits, n_splits=len(splits), axis=axis)
test_input = np.arange(12).astype(config.floatX).reshape(2, 6)
......
......@@ -209,7 +209,7 @@ def test_Join(vals, axis):
def test_Split(n_splits, axis, values, sizes):
values, values_test = values
sizes, sizes_test = sizes
g = pt.split(values, sizes, n_splits, axis=axis)
g = pt.split(values, sizes, n_splits=n_splits, axis=axis)
assert len(g) == n_splits
if n_splits == 0:
return
......
......@@ -518,7 +518,7 @@ rng = np.random.default_rng(42849)
def test_Split(n_splits, axis, values, sizes):
i = pt.tensor("i", shape=values.shape, dtype=config.floatX)
s = pt.vector("s", dtype="int64")
g = pt.split(i, s, n_splits, axis=axis)
g = pt.split(i, s, n_splits=n_splits, axis=axis)
assert len(g) == n_splits
if n_splits == 0:
return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论