提交 fd4c5d91 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor get_canonical_form_slice so that it uses as_index_literal

上级 0cbf8557
......@@ -168,7 +168,9 @@ def get_idx_list(inputs, idx_list):
return indices_from_subtensor(inputs[1:], idx_list)
def get_canonical_form_slice(theslice, length):
def get_canonical_form_slice(
theslice: Union[slice, Variable], length: Variable
) -> Tuple[Variable, int]:
"""Convert slices to canonical form.
Given a slice [start:stop:step] transform it into a canonical form
......@@ -179,16 +181,24 @@ def get_canonical_form_slice(theslice, length):
if the resulting set of numbers needs to be reversed or not.
"""
from aesara.tensor import extract_constant, ge, lt, sgn, switch
from aesara.tensor import ge, lt, sgn, switch
if isinstance(theslice, slice):
if not isinstance(theslice, slice):
try:
value = as_index_literal(theslice)
except NotScalarConstantError:
value = theslice
value = switch(lt(value, 0), (value + length), value)
return value, 1
def analyze(x):
try:
x_constant = get_scalar_constant_value(x)
x_constant = as_index_literal(x)
is_constant = True
except NotScalarConstantError:
x_constant = extract_constant(x)
x_constant = x
is_constant = False
return x_constant, is_constant
......@@ -298,9 +308,7 @@ def get_canonical_form_slice(theslice, length):
else:
start = switch(lt(start, 0), start + length, start)
start = switch(lt(start, 0), switch_neg_step(-1, 0), start)
start = switch(
ge(start, length), switch_neg_step(length - 1, length), start
)
start = switch(ge(start, length), switch_neg_step(length - 1, length), start)
if stop is None or stop == sys.maxsize:
# The special "maxsize" case is probably not needed here,
# as slices containing maxsize are not generated by
......@@ -328,11 +336,6 @@ def get_canonical_form_slice(theslice, length):
return slice(nw_start, nw_stop, nw_step), reverse
else:
return slice(nw_start, nw_stop, nw_step), 1
else:
value = extract_constant(theslice)
value = switch(lt(value, 0), (value + length), value)
return value, 1
def range_len(slc):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论