提交 790b46fd authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Ricardo Vieira

enh: Improve static shape of subtensor

上级 1d9fa843
......@@ -221,6 +221,16 @@ def get_canonical_form_slice(
step, is_step_constant = analyze(theslice.step)
length, is_length_constant = analyze(length)
if (
is_start_constant
and is_stop_constant
and is_step_constant
and is_length_constant
):
_start, _stop, _step = slice(start, stop, step).indices(length)
if _start <= _stop and _step >= 1:
return slice(_start, _stop, _step), 1
if step is None:
step = 1
is_step_constant = True
......@@ -722,32 +732,51 @@ class Subtensor(COp):
f"Incompatible types for Subtensor template. Expected {input.type}, got {expected_type}."
)
# infer the broadcasting pattern
padded = get_constant_idx(
self.idx_list, (None,) + inputs, allow_partial=True
) + [slice(None, None, None)] * (x.type.ndim - len(idx_list))
padded = [
*get_idx_list((None,) + inputs, self.idx_list),
*[slice(None, None, None)] * (x.type.ndim - len(idx_list)),
]
out_shape = []
for i, (p, s) in enumerate(zip(padded, x.type.shape)):
if isinstance(p, slice):
if s == 1:
start = p.start
def extract_const(value):
if value is None:
return value, True
try:
start = get_underlying_scalar_constant_value(start)
value = get_underlying_scalar_constant_value(value)
return value, True
except NotScalarConstantError:
pass
if start is None or start == 0:
start = p.start
if start is None:
start = 0
if p.stop is None or (
isinstance(p.stop, (int, np.integer, np.ndarray))
and p.stop > start
):
out_shape.append(1)
return value, False
for the_slice, length in zip(padded, x.type.shape):
if not isinstance(the_slice, slice):
continue
if length is None:
out_shape.append(None)
continue
start = the_slice.start
stop = the_slice.stop
step = the_slice.step
is_slice_const = True
start, is_const = extract_const(start)
is_slice_const = is_slice_const and is_const
stop, is_const = extract_const(stop)
is_slice_const = is_slice_const and is_const
step, is_const = extract_const(step)
is_slice_const = is_slice_const and is_const
if not is_slice_const:
out_shape.append(None)
continue
slice_length = len(range(*slice(start, stop, step).indices(length)))
out_shape.append(slice_length)
return Apply(
self,
......
......@@ -2693,3 +2693,18 @@ def test_index_vars_to_types():
assert isinstance(x.type, scal.ScalarType)
res = index_vars_to_types(x)
assert res == x.type
@pytest.mark.parametrize(
"x_shape, indices, expected",
[
[(None,), (slice(None, None),), (None,)],
[(13,), (slice(None, 100),), (13,)],
[(13,), (slice(-1, None),), (1,)],
[(7, 13), (slice(None, None, 2), slice(-1, 1, -1)), (4, 11)],
],
)
def test_static_shapes(x_shape, indices, expected):
x = at.tensor(dtype="float64", shape=x_shape)
y = x[indices]
assert y.type.shape == expected
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论