提交 b721b669 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in shape inference of AdvancedSubtensor with slices

上级 773da102
...@@ -59,6 +59,7 @@ from pytensor.tensor.type import ( ...@@ -59,6 +59,7 @@ from pytensor.tensor.type import (
zscalar, zscalar,
) )
from pytensor.tensor.type_other import ( from pytensor.tensor.type_other import (
MakeSlice,
NoneConst, NoneConst,
NoneTypeT, NoneTypeT,
SliceConstant, SliceConstant,
...@@ -527,11 +528,20 @@ def basic_shape(shape, indices): ...@@ -527,11 +528,20 @@ def basic_shape(shape, indices):
if isinstance(idx, slice): if isinstance(idx, slice):
res_shape += (slice_len(idx, n),) res_shape += (slice_len(idx, n),)
elif isinstance(getattr(idx, "type", None), SliceType): elif isinstance(getattr(idx, "type", None), SliceType):
if idx.owner: if idx.owner is None:
idx_inputs = idx.owner.inputs if not isinstance(idx, Constant):
# This is an input slice, we can't reason symbolically on it.
# We don't even know if we will get None entries or integers
res_shape += (None,)
continue
else:
sl: slice = idx.data
slice_inputs = (sl.start, sl.stop, sl.step)
elif isinstance(idx.owner.op, MakeSlice):
slice_inputs = idx.owner.inputs
else: else:
idx_inputs = (None,) raise ValueError(f"Unexpected Slice producing Op {idx.owner.op}")
res_shape += (slice_len(slice(*idx_inputs), n),) res_shape += (slice_len(slice(*slice_inputs), n),)
elif idx is None: elif idx is None:
res_shape += (ps.ScalarConstant(ps.int64, 1),) res_shape += (ps.ScalarConstant(ps.int64, 1),)
elif isinstance(getattr(idx, "type", None), NoneTypeT): elif isinstance(getattr(idx, "type", None), NoneTypeT):
...@@ -2728,6 +2738,11 @@ class AdvancedSubtensor(Op): ...@@ -2728,6 +2738,11 @@ class AdvancedSubtensor(Op):
res_shape = list( res_shape = list(
indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True) indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
) )
for i, res_dim_length in enumerate(res_shape):
if res_dim_length is None:
# This can happen when we have a Slice provided by the user (not a constant nor the result of MakeSlice)
# We must compute the Op to find its shape
res_shape[i] = Shape_i(i)(node.out)
adv_indices = [idx for idx in indices if not is_basic_idx(idx)] adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
bool_indices = [idx for idx in adv_indices if is_bool_index(idx)] bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
......
...@@ -15,6 +15,7 @@ from pytensor.compile.io import In ...@@ -15,6 +15,7 @@ from pytensor.compile.io import In
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph import Constant
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph from pytensor.graph.rewriting.utils import is_same_graph
from pytensor.printing import pprint from pytensor.printing import pprint
...@@ -37,6 +38,7 @@ from pytensor.tensor.subtensor import ( ...@@ -37,6 +38,7 @@ from pytensor.tensor.subtensor import (
advanced_inc_subtensor1, advanced_inc_subtensor1,
advanced_set_subtensor, advanced_set_subtensor,
advanced_set_subtensor1, advanced_set_subtensor1,
advanced_subtensor,
advanced_subtensor1, advanced_subtensor1,
as_index_literal, as_index_literal,
basic_shape, basic_shape,
...@@ -2145,7 +2147,17 @@ class TestAdvancedSubtensor: ...@@ -2145,7 +2147,17 @@ class TestAdvancedSubtensor:
slc = slicetype() slc = slicetype()
f = pytensor.function([slc], var[slc], mode=self.mode) f = pytensor.function([slc], var[slc], mode=self.mode)
s = slice(1, 3) s = slice(1, 3)
f(s) assert f(s).shape == (2, 3)
f_shape0 = pytensor.function([slc], var[slc].shape[0], mode=self.mode)
assert f_shape0(s) == 2
f_shape1 = pytensor.function([slc], var[slc].shape[1], mode=self.mode)
assert not any(
isinstance(node.op, AdvancedSubtensor)
for node in f_shape1.maker.fgraph.toposort()
)
assert f_shape1(s) == 3
def test_adv_grouped(self): def test_adv_grouped(self):
# Reported in https://github.com/Theano/Theano/issues/6152 # Reported in https://github.com/Theano/Theano/issues/6152
...@@ -2611,6 +2623,14 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2611,6 +2623,14 @@ class TestInferShape(utt.InferShapeTester):
AdvancedSubtensor, AdvancedSubtensor,
) )
def test_advanced_subtensor_constant_slice(self):
x = dmatrix("x")
constant_slice = pytensor.as_symbolic(slice(1, None, None))
assert isinstance(constant_slice, Constant)
adv_indices = ptb.constant(np.zeros((2, 3)), dtype="int")
y = advanced_subtensor(x, constant_slice, adv_indices)
assert tuple(y.shape.eval({x: np.zeros((10, 10))})) == (9, 2, 3)
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
def test_basic_shape(): def test_basic_shape():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论