提交 fd4c5d91 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor get_canonical_form_slice so that it uses as_index_literal

上级 0cbf8557
...@@ -168,7 +168,9 @@ def get_idx_list(inputs, idx_list): ...@@ -168,7 +168,9 @@ def get_idx_list(inputs, idx_list):
return indices_from_subtensor(inputs[1:], idx_list) return indices_from_subtensor(inputs[1:], idx_list)
def get_canonical_form_slice(theslice, length): def get_canonical_form_slice(
theslice: Union[slice, Variable], length: Variable
) -> Tuple[Variable, int]:
"""Convert slices to canonical form. """Convert slices to canonical form.
Given a slice [start:stop:step] transform it into a canonical form Given a slice [start:stop:step] transform it into a canonical form
...@@ -179,160 +181,161 @@ def get_canonical_form_slice(theslice, length): ...@@ -179,160 +181,161 @@ def get_canonical_form_slice(theslice, length):
if the resulting set of numbers needs to be reversed or not. if the resulting set of numbers needs to be reversed or not.
""" """
from aesara.tensor import extract_constant, ge, lt, sgn, switch from aesara.tensor import ge, lt, sgn, switch
if isinstance(theslice, slice): if not isinstance(theslice, slice):
try:
value = as_index_literal(theslice)
except NotScalarConstantError:
value = theslice
def analyze(x): value = switch(lt(value, 0), (value + length), value)
try:
x_constant = get_scalar_constant_value(x)
is_constant = True
except NotScalarConstantError:
x_constant = extract_constant(x)
is_constant = False
return x_constant, is_constant
start, is_start_constant = analyze(theslice.start)
stop, is_stop_constant = analyze(theslice.stop)
step, is_step_constant = analyze(theslice.step)
length, is_length_constant = analyze(length)
if step is None:
step = 1
is_step_constant = True
# First handle the easier and common case where `step` is 1 and
# either `start` or `stop` is a range boundary. More specializations
# could be added later. This makes the resulting graph smaller than
# in the generic case below.
if step == 1:
is_start_0 = (
start is None
or start == 0
or (
is_start_constant
and is_length_constant
and start < 0
and start + length <= 0
)
)
is_stop_length = (
stop is None
or stop in [length, sys.maxsize]
or (is_stop_constant and is_length_constant and stop >= length)
)
if is_start_0:
# 0:stop:1
if is_stop_length:
# Full slice.
return slice(0, length, 1), 1
if is_stop_constant and stop >= 0:
return (slice(0, switch(lt(stop, length), stop, length), 1), 1)
stop_plus_len = stop + length
stop = switch(
lt(stop, 0),
# stop < 0
switch(
lt(stop_plus_len, 0),
# stop + len < 0
0,
# stop + len >= 0
stop_plus_len,
),
# stop >= 0: use min(stop, length)
switch(lt(stop, length), stop, length),
)
return slice(0, stop, 1), 1
elif is_stop_length:
# start:length:1
if is_start_constant and start >= 0:
return slice(switch(lt(start, length), start, length), length, 1), 1
start_plus_len = start + length
start = switch(
lt(start, 0),
# start < 0
switch(
lt(start_plus_len, 0),
# start + len < 0
0,
# start + len >= 0
start_plus_len,
),
# start >= 0: use min(start, length)
switch(lt(start, length), start, length),
)
return slice(start, length, 1), 1
# This is the generic case. return value, 1
if is_step_constant: def analyze(x):
# When we know the sign of `step`, the graph can be made simpler. try:
assert step != 0 x_constant = as_index_literal(x)
if step > 0: is_constant = True
except NotScalarConstantError:
x_constant = x
is_constant = False
return x_constant, is_constant
start, is_start_constant = analyze(theslice.start)
stop, is_stop_constant = analyze(theslice.stop)
step, is_step_constant = analyze(theslice.step)
length, is_length_constant = analyze(length)
if step is None:
step = 1
is_step_constant = True
# First handle the easier and common case where `step` is 1 and
# either `start` or `stop` is a range boundary. More specializations
# could be added later. This makes the resulting graph smaller than
# in the generic case below.
if step == 1:
is_start_0 = (
start is None
or start == 0
or (
is_start_constant
and is_length_constant
and start < 0
and start + length <= 0
)
)
is_stop_length = (
stop is None
or stop in [length, sys.maxsize]
or (is_stop_constant and is_length_constant and stop >= length)
)
if is_start_0:
# 0:stop:1
if is_stop_length:
# Full slice.
return slice(0, length, 1), 1
if is_stop_constant and stop >= 0:
return (slice(0, switch(lt(stop, length), stop, length), 1), 1)
stop_plus_len = stop + length
stop = switch(
lt(stop, 0),
# stop < 0
switch(
lt(stop_plus_len, 0),
# stop + len < 0
0,
# stop + len >= 0
stop_plus_len,
),
# stop >= 0: use min(stop, length)
switch(lt(stop, length), stop, length),
)
return slice(0, stop, 1), 1
elif is_stop_length:
# start:length:1
if is_start_constant and start >= 0:
return slice(switch(lt(start, length), start, length), length, 1), 1
start_plus_len = start + length
start = switch(
lt(start, 0),
# start < 0
switch(
lt(start_plus_len, 0),
# start + len < 0
0,
# start + len >= 0
start_plus_len,
),
# start >= 0: use min(start, length)
switch(lt(start, length), start, length),
)
return slice(start, length, 1), 1
def switch_neg_step(a, b): # This is the generic case.
return b
abs_step = step if is_step_constant:
sgn_step = 1 # When we know the sign of `step`, the graph can be made simpler.
else: assert step != 0
if step > 0:
def switch_neg_step(a, b): def switch_neg_step(a, b):
return a return b
abs_step = -step abs_step = step
sgn_step = -1 sgn_step = 1
else: else:
is_step_neg = lt(step, 0)
def switch_neg_step(a, b): def switch_neg_step(a, b):
return switch(is_step_neg, a, b) return a
abs_step = abs(step)
sgn_step = sgn(step)
defstart = switch_neg_step(length - 1, 0) abs_step = -step
defstop = switch_neg_step(-1, length) sgn_step = -1
if start is None:
start = defstart
else:
start = switch(lt(start, 0), start + length, start)
start = switch(lt(start, 0), switch_neg_step(-1, 0), start)
start = switch(
ge(start, length), switch_neg_step(length - 1, length), start
)
if stop is None or stop == sys.maxsize:
# The special "maxsize" case is probably not needed here,
# as slices containing maxsize are not generated by
# __getslice__ anymore.
stop = defstop
else:
stop = switch(lt(stop, 0), stop + length, stop)
stop = switch(lt(stop, 0), -1, stop)
stop = switch(ge(stop, length), length, stop)
nw_stop = switch_neg_step(start + 1, stop)
slice_len = (start - stop - 1) // abs_step + 1
slice_len = switch(lt(slice_len, 0), 0, slice_len)
neg_start = nw_stop - (slice_len - 1) * abs_step - 1
neg_start = switch(lt(neg_start, 0), (nw_stop - 1), neg_start)
nw_start = switch_neg_step(neg_start, start)
nw_start = switch(lt(nw_start, 0), 0, nw_start)
nw_stop = switch(lt(nw_stop, 0), 0, nw_stop)
# Ensure start <= stop.
nw_start = switch(lt(nw_start, nw_stop), nw_start, nw_stop)
nw_step = abs_step
if step != 1:
reverse = sgn_step
return slice(nw_start, nw_stop, nw_step), reverse
else:
return slice(nw_start, nw_stop, nw_step), 1
else: else:
value = extract_constant(theslice) is_step_neg = lt(step, 0)
value = switch(lt(value, 0), (value + length), value)
return value, 1 def switch_neg_step(a, b):
return switch(is_step_neg, a, b)
abs_step = abs(step)
sgn_step = sgn(step)
defstart = switch_neg_step(length - 1, 0)
defstop = switch_neg_step(-1, length)
if start is None:
start = defstart
else:
start = switch(lt(start, 0), start + length, start)
start = switch(lt(start, 0), switch_neg_step(-1, 0), start)
start = switch(ge(start, length), switch_neg_step(length - 1, length), start)
if stop is None or stop == sys.maxsize:
# The special "maxsize" case is probably not needed here,
# as slices containing maxsize are not generated by
# __getslice__ anymore.
stop = defstop
else:
stop = switch(lt(stop, 0), stop + length, stop)
stop = switch(lt(stop, 0), -1, stop)
stop = switch(ge(stop, length), length, stop)
nw_stop = switch_neg_step(start + 1, stop)
slice_len = (start - stop - 1) // abs_step + 1
slice_len = switch(lt(slice_len, 0), 0, slice_len)
neg_start = nw_stop - (slice_len - 1) * abs_step - 1
neg_start = switch(lt(neg_start, 0), (nw_stop - 1), neg_start)
nw_start = switch_neg_step(neg_start, start)
nw_start = switch(lt(nw_start, 0), 0, nw_start)
nw_stop = switch(lt(nw_stop, 0), 0, nw_stop)
# Ensure start <= stop.
nw_start = switch(lt(nw_start, nw_stop), nw_start, nw_stop)
nw_step = abs_step
if step != 1:
reverse = sgn_step
return slice(nw_start, nw_stop, nw_step), reverse
else:
return slice(nw_start, nw_stop, nw_step), 1
def range_len(slc): def range_len(slc):
......
...@@ -14,6 +14,7 @@ from aesara.compile.io import In ...@@ -14,6 +14,7 @@ from aesara.compile.io import In
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.graph.opt_utils import is_same_graph from aesara.graph.opt_utils import is_same_graph
from aesara.scalar.basic import as_scalar
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.math import exp, isinf from aesara.tensor.math import exp, isinf
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
...@@ -96,6 +97,186 @@ def test_as_index_literal(): ...@@ -96,6 +97,186 @@ def test_as_index_literal():
assert res is np.newaxis assert res is np.newaxis
class TestGetCanonicalFormSlice:
def test_scalar_constant(self):
a = as_scalar(0)
length = lscalar()
res = get_canonical_form_slice(a, length)
assert res[0].owner.op == aet.switch
assert res[1] == 1
def test_all_symbolic(self):
start = iscalar("b")
stop = iscalar("e")
step = iscalar("s")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(start, stop, step), length)
f = aesara.function(
[start, stop, step, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
)
length = 5
a = np.arange(length)
for start in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
for stop in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
for step in [-6, -3, -1, 2, 5]:
out = f(start, stop, step, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[start:stop:step]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def test_start_None(self):
stop = iscalar("e")
step = iscalar("s")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(None, stop, step), length)
f = aesara.function(
[stop, step, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
)
length = 5
a = np.arange(length)
for stop in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
for step in [-6, -3, -1, 2, 5]:
out = f(stop, step, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[:stop:step]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def test_stop_None(self):
start = iscalar("b")
step = iscalar("s")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(start, None, step), length)
f = aesara.function(
[start, step, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
)
length = 5
a = np.arange(length)
for start in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
for step in [-6, -3, -1, 2, 5]:
out = f(start, step, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[start:None:step]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def test_step_None(self):
start = iscalar("b")
stop = iscalar("e")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(start, stop, None), length)
f = aesara.function(
[start, stop, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
)
length = 5
a = np.arange(length)
for start in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
for stop in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
out = f(start, stop, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[start:stop:None]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def test_start_stop_None(self):
step = iscalar("s")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(None, None, step), length)
f = aesara.function(
[step, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
)
length = 5
a = np.arange(length)
for step in [-6, -3, -1, 2, 5]:
out = f(step, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[None:None:step]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def test_stop_step_None(self):
start = iscalar("b")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(start, None, None), length)
f = aesara.function(
[start, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
)
length = 5
a = np.arange(length)
for start in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
out = f(start, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[start:None:None]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def test_start_step_None(self):
stop = iscalar("e")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(None, stop, None), length)
f = aesara.function(
[stop, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
)
length = 5
a = np.arange(length)
for stop in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
out = f(stop, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[None:stop:None]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
class TestSubtensor(utt.OptimizationTestMixin): class TestSubtensor(utt.OptimizationTestMixin):
""" """
This is designed to be sub-classed (e.g. by the GPU tests). This is designed to be sub-classed (e.g. by the GPU tests).
...@@ -846,191 +1027,6 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -846,191 +1027,6 @@ class TestSubtensor(utt.OptimizationTestMixin):
for step in [-3, -1, 2, 5]: for step in [-3, -1, 2, 5]:
assert np.all(f(start, stop, step) == v_data[start:stop:step].shape) assert np.all(f(start, stop, step) == v_data[start:stop:step].shape)
def test_slice_canonical_form_0(self):
start = iscalar("b")
stop = iscalar("e")
step = iscalar("s")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(start, stop, step), length)
f = self.function(
[start, stop, step, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
N=0,
op=subtensor_ops,
)
length = 5
a = np.arange(length)
for start in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
for stop in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
for step in [-6, -3, -1, 2, 5]:
out = f(start, stop, step, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[start:stop:step]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def test_slice_canonical_form_1(self):
stop = iscalar("e")
step = iscalar("s")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(None, stop, step), length)
f = self.function(
[stop, step, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
N=0,
op=subtensor_ops,
)
length = 5
a = np.arange(length)
for stop in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
for step in [-6, -3, -1, 2, 5]:
out = f(stop, step, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[:stop:step]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def test_slice_canonical_form_2(self):
start = iscalar("b")
step = iscalar("s")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(start, None, step), length)
f = self.function(
[start, step, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
N=0,
op=subtensor_ops,
)
length = 5
a = np.arange(length)
for start in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
for step in [-6, -3, -1, 2, 5]:
out = f(start, step, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[start:None:step]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def test_slice_canonical_form_3(self):
start = iscalar("b")
stop = iscalar("e")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(start, stop, None), length)
f = self.function(
[start, stop, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
N=0,
op=subtensor_ops,
)
length = 5
a = np.arange(length)
for start in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
for stop in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
out = f(start, stop, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[start:stop:None]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def test_slice_canonical_form_4(self):
step = iscalar("s")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(None, None, step), length)
f = self.function(
[step, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
N=0,
op=subtensor_ops,
)
length = 5
a = np.arange(length)
for step in [-6, -3, -1, 2, 5]:
out = f(step, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[None:None:step]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def test_slice_canonical_form_5(self):
start = iscalar("b")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(start, None, None), length)
f = self.function(
[start, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
N=0,
op=subtensor_ops,
)
length = 5
a = np.arange(length)
for start in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
out = f(start, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[start:None:None]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def test_slice_canonical_form_6(self):
stop = iscalar("e")
length = iscalar("l")
cnf = get_canonical_form_slice(slice(None, stop, None), length)
f = self.function(
[stop, length],
[
aet.as_tensor_variable(cnf[0].start),
aet.as_tensor_variable(cnf[0].stop),
aet.as_tensor_variable(cnf[0].step),
aet.as_tensor_variable(cnf[1]),
],
N=0,
op=subtensor_ops,
)
length = 5
a = np.arange(length)
for stop in [-8, -5, -4, -1, 0, 1, 4, 5, 8]:
out = f(stop, length)
t_out = a[out[0] : out[1] : out[2]][:: out[3]]
v_out = a[None:stop:None]
assert np.all(t_out == v_out)
assert np.all(t_out.shape == v_out.shape)
def grad_list_(self, idxs, data): def grad_list_(self, idxs, data):
n = self.shared(data) n = self.shared(data)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论