提交 d50db11c 作者: Ricardo Vieira 提交者: Ricardo Vieira

Provide static shape in output of Split

上级 5308dddb
@@ -2201,8 +2201,28 @@ class Split(COp):
raise TypeError("`axis` parameter must be an integer scalar") raise TypeError("`axis` parameter must be an integer scalar")
inputs = [x, axis, splits] inputs = [x, axis, splits]
out_type = TensorType(dtype=x.dtype, shape=(None,) * x.type.ndim)
outputs = [out_type() for i in range(self.len_splits)] x_dtype = x.type.dtype
if isinstance(axis, Constant):
# In this case we can preserve more static shape info
static_axis = axis.data.item()
outputs = []
x_static_shape = list(x.type.shape)
for i in range(self.len_splits):
try:
static_split_size = int(get_scalar_constant_value(splits[i]))
except NotScalarConstantError:
static_split_size = None
except IndexError:
raise ValueError("Number of splits is larger than splits size")
static_out_shape = x_static_shape.copy()
static_out_shape[static_axis] = static_split_size
outputs.append(tensor(shape=tuple(static_out_shape), dtype=x_dtype))
else:
outputs = [
tensor(shape=(None,) * x.type.ndim, dtype=x_dtype)
for i in range(self.len_splits)
]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
......
@@ -150,12 +150,14 @@ class TestJaxSplit:
): ):
fn(np.zeros((6, 4), dtype=pytensor.config.floatX)) fn(np.zeros((6, 4), dtype=pytensor.config.floatX))
a_splits = ptb.split(a, splits_size=[2, 4], n_splits=3, axis=0) # This check is triggered at compile time if splits_size has incompatible static length
fn = pytensor.function([a], a_splits, mode="JAX") splits_size = vector("splits_size", shape=(None,), dtype=int)
a_splits = ptb.split(a, splits_size=splits_size, n_splits=3, axis=0)
fn = pytensor.function([a, splits_size], a_splits, mode="JAX")
with pytest.raises( with pytest.raises(
ValueError, match="Length of splits is not equal to n_splits" ValueError, match="Length of splits is not equal to n_splits"
): ):
fn(np.zeros((6, 4), dtype=pytensor.config.floatX)) fn(np.zeros((6, 4), dtype=pytensor.config.floatX), [2, 2])
a_splits = ptb.split(a, splits_size=[2, 4], n_splits=2, axis=0) a_splits = ptb.split(a, splits_size=[2, 4], n_splits=2, axis=0)
fn = pytensor.function([a], a_splits, mode="JAX") fn = pytensor.function([a], a_splits, mode="JAX")
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论