Commit 790b46fd authored by Adrian Seyboldt, committed by Ricardo Vieira

enh: Improve static shape of subtensor

Parent 1d9fa843
@@ -221,6 +221,16 @@ def get_canonical_form_slice(
     step, is_step_constant = analyze(theslice.step)
     length, is_length_constant = analyze(length)

+    if (
+        is_start_constant
+        and is_stop_constant
+        and is_step_constant
+        and is_length_constant
+    ):
+        _start, _stop, _step = slice(start, stop, step).indices(length)
+        if _start <= _stop and _step >= 1:
+            return slice(_start, _stop, _step), 1
+
     if step is None:
         step = 1
         is_step_constant = True
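
The early return added above leans on Python's built-in slice.indices, which resolves None and negative bounds against a known length and clips out-of-range stops, so a fully constant slice can be canonicalized immediately (with a direction flag of 1). A minimal standalone sketch of that behavior:

    # Pure-Python illustration (not part of the commit) of slice.indices.
    length = 13
    print(slice(None, 100).indices(length))  # (0, 13, 1): stop clipped to length
    print(slice(-1, None).indices(length))   # (12, 13, 1): negative start resolved
    print(slice(None, None, 2).indices(7))   # (0, 7, 2): None bounds filled in
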
@@ -722,32 +732,51 @@ class Subtensor(COp):
                 f"Incompatible types for Subtensor template. Expected {input.type}, got {expected_type}."
             )

-        # infer the broadcasting pattern
-        padded = get_constant_idx(
-            self.idx_list, (None,) + inputs, allow_partial=True
-        ) + [slice(None, None, None)] * (x.type.ndim - len(idx_list))
+        padded = [
+            *get_idx_list((None,) + inputs, self.idx_list),
+            *[slice(None, None, None)] * (x.type.ndim - len(idx_list)),
+        ]

         out_shape = []
-        for i, (p, s) in enumerate(zip(padded, x.type.shape)):
-            if isinstance(p, slice):
-                if s == 1:
-                    start = p.start
-                    try:
-                        start = get_underlying_scalar_constant_value(start)
-                    except NotScalarConstantError:
-                        pass
-                    if start is None or start == 0:
-                        start = p.start
-                        if start is None:
-                            start = 0
-                        if p.stop is None or (
-                            isinstance(p.stop, (int, np.integer, np.ndarray))
-                            and p.stop > start
-                        ):
-                            out_shape.append(1)
-                            continue
+
+        def extract_const(value):
+            if value is None:
+                return value, True
+            try:
+                value = get_underlying_scalar_constant_value(value)
+                return value, True
+            except NotScalarConstantError:
+                return value, False
+
+        for the_slice, length in zip(padded, x.type.shape):
+            if not isinstance(the_slice, slice):
+                continue
+
+            if length is None:
                 out_shape.append(None)
+                continue
+
+            start = the_slice.start
+            stop = the_slice.stop
+            step = the_slice.step
+
+            is_slice_const = True
+
+            start, is_const = extract_const(start)
+            is_slice_const = is_slice_const and is_const
+
+            stop, is_const = extract_const(stop)
+            is_slice_const = is_slice_const and is_const
+
+            step, is_const = extract_const(step)
+            is_slice_const = is_slice_const and is_const
+
+            if not is_slice_const:
+                out_shape.append(None)
+                continue
+
+            slice_length = len(range(*slice(start, stop, step).indices(length)))
+            out_shape.append(slice_length)

         return Apply(
             self,
             ...
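
With all slice bounds resolved to constants, the new make_node body computes each output dimension as len(range(*slice(start, stop, step).indices(length))), which counts the selected elements and handles clipping and negative steps uniformly. A standalone sketch of that formula, checked against the expectations in the tests below:

    # Pure-Python illustration (not part of the commit) of the length formula.
    assert len(range(*slice(None, None, 2).indices(7))) == 4   # elements 0, 2, 4, 6
    assert len(range(*slice(-1, 1, -1).indices(13))) == 11     # 12 down to 2
    assert len(range(*slice(None, 100).indices(13))) == 13     # stop clipped to 13
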
@@ -2693,3 +2693,18 @@ def test_index_vars_to_types():
     assert isinstance(x.type, scal.ScalarType)
     res = index_vars_to_types(x)
     assert res == x.type
+
+
+@pytest.mark.parametrize(
+    "x_shape, indices, expected",
+    [
+        [(None,), (slice(None, None),), (None,)],
+        [(13,), (slice(None, 100),), (13,)],
+        [(13,), (slice(-1, None),), (1,)],
+        [(7, 13), (slice(None, None, 2), slice(-1, 1, -1)), (4, 11)],
+    ],
+)
+def test_static_shapes(x_shape, indices, expected):
+    x = at.tensor(dtype="float64", shape=x_shape)
+    y = x[indices]
+    assert y.type.shape == expected
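
The user-visible effect, sketched under the assumption that `at` is the library's tensor module as imported in the test file above: slicing with constant bounds now yields concrete static output shapes, where the removed code could only infer 1 for length-1 axes and otherwise fell back to None.

    # Hypothetical usage; `at` is assumed to be the tensor module used above.
    x = at.tensor(dtype="float64", shape=(7, 13))
    y = x[::2, -1:1:-1]
    print(y.type.shape)  # (4, 11) after this change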