提交 b721b669 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in shape inference of AdvancedSubtensor with slices

上级 773da102
......@@ -59,6 +59,7 @@ from pytensor.tensor.type import (
zscalar,
)
from pytensor.tensor.type_other import (
MakeSlice,
NoneConst,
NoneTypeT,
SliceConstant,
......@@ -527,11 +528,20 @@ def basic_shape(shape, indices):
if isinstance(idx, slice):
res_shape += (slice_len(idx, n),)
elif isinstance(getattr(idx, "type", None), SliceType):
if idx.owner:
idx_inputs = idx.owner.inputs
if idx.owner is None:
if not isinstance(idx, Constant):
# This is an input slice, we can't reason symbolically on it.
# We don't even know if we will get None entries or integers
res_shape += (None,)
continue
else:
sl: slice = idx.data
slice_inputs = (sl.start, sl.stop, sl.step)
elif isinstance(idx.owner.op, MakeSlice):
slice_inputs = idx.owner.inputs
else:
idx_inputs = (None,)
res_shape += (slice_len(slice(*idx_inputs), n),)
raise ValueError(f"Unexpected Slice producing Op {idx.owner.op}")
res_shape += (slice_len(slice(*slice_inputs), n),)
elif idx is None:
res_shape += (ps.ScalarConstant(ps.int64, 1),)
elif isinstance(getattr(idx, "type", None), NoneTypeT):
......@@ -2728,6 +2738,11 @@ class AdvancedSubtensor(Op):
res_shape = list(
indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
)
for i, res_dim_length in enumerate(res_shape):
if res_dim_length is None:
# This can happen when we have a Slice provided by the user (not a constant nor the result of MakeSlice)
# We must compute the Op to find its shape
res_shape[i] = Shape_i(i)(node.out)
adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
......
......@@ -15,6 +15,7 @@ from pytensor.compile.io import In
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph import Constant
from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph
from pytensor.printing import pprint
......@@ -37,6 +38,7 @@ from pytensor.tensor.subtensor import (
advanced_inc_subtensor1,
advanced_set_subtensor,
advanced_set_subtensor1,
advanced_subtensor,
advanced_subtensor1,
as_index_literal,
basic_shape,
......@@ -2145,7 +2147,17 @@ class TestAdvancedSubtensor:
slc = slicetype()
f = pytensor.function([slc], var[slc], mode=self.mode)
s = slice(1, 3)
f(s)
assert f(s).shape == (2, 3)
f_shape0 = pytensor.function([slc], var[slc].shape[0], mode=self.mode)
assert f_shape0(s) == 2
f_shape1 = pytensor.function([slc], var[slc].shape[1], mode=self.mode)
assert not any(
isinstance(node.op, AdvancedSubtensor)
for node in f_shape1.maker.fgraph.toposort()
)
assert f_shape1(s) == 3
def test_adv_grouped(self):
# Reported in https://github.com/Theano/Theano/issues/6152
......@@ -2611,6 +2623,14 @@ class TestInferShape(utt.InferShapeTester):
AdvancedSubtensor,
)
def test_advanced_subtensor_constant_slice(self):
x = dmatrix("x")
constant_slice = pytensor.as_symbolic(slice(1, None, None))
assert isinstance(constant_slice, Constant)
adv_indices = ptb.constant(np.zeros((2, 3)), dtype="int")
y = advanced_subtensor(x, constant_slice, adv_indices)
assert tuple(y.shape.eval({x: np.zeros((10, 10))})) == (9, 2, 3)
@config.change_flags(compute_test_value="raise")
def test_basic_shape():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论