提交 734009ae authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Improve static output shape of AdvancedSubtensor1

上级 3db127e9
......@@ -1992,8 +1992,7 @@ class AdvancedSubtensor1(COp):
raise TypeError("index must be vector")
if x_.type.ndim == 0:
raise TypeError("cannot index into a scalar")
out_shape = (ilist_.type.shape[0],) + x_.type.shape[1:]
out_shape = tuple(1 if s == 1 else None for s in out_shape)
out_shape = (ilist_.type.shape[0], *x_.type.shape[1:])
return Apply(self, [x_, ilist_], [TensorType(dtype=x.dtype, shape=out_shape)()])
def perform(self, node, inp, out_):
......
......@@ -34,6 +34,7 @@ from pytensor.tensor.subtensor import (
advanced_inc_subtensor1,
advanced_set_subtensor,
advanced_set_subtensor1,
advanced_subtensor1,
as_index_literal,
basic_shape,
get_canonical_form_slice,
......@@ -2707,12 +2708,26 @@ def test_index_vars_to_types():
[(7, 13), (slice(None, None, 2), slice(-1, 1, -1)), (4, 11)],
],
)
def test_static_shapes(x_shape, indices, expected):
def test_subtensor_static_shapes(x_shape, indices, expected):
x = ptb.tensor(dtype="float64", shape=x_shape)
y = x[indices]
assert y.type.shape == expected
@pytest.mark.parametrize(
"x_shape, indices, expected",
[
[(None, 5, None, 3), vector(shape=(1,)), (1, 5, None, 3)],
[(None, 5, None, 3), vector(shape=(2,)), (2, 5, None, 3)],
[(None, 5, None, 3), vector(shape=(None,)), (None, 5, None, 3)],
],
)
def test_advanced_subtensor1_static_shapes(x_shape, indices, expected):
x = ptb.tensor(dtype="float64", shape=x_shape)
y = advanced_subtensor1(x, indices.astype(int))
assert y.type.shape == expected
def test_vectorize_subtensor_without_batch_indices():
signature = "(t1,t2,t3),()->(t1,t3)"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论