Commit ea5401cc authored by Ricardo, committed by Brandon T. Willard

Add rewrite to merge consecutive joined subtensors

Parent 14482079
......@@ -12,6 +12,7 @@ from aesara.raise_op import Assert
from aesara.tensor.basic import (
Alloc,
ARange,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
......@@ -19,6 +20,7 @@ from aesara.tensor.basic import (
alloc,
as_tensor,
cast,
concatenate,
extract_constant,
get_scalar_constant_value,
patternbroadcast,
......@@ -1661,3 +1663,106 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
if new_obj_arg.ndim == 0:
return [new_obj_arg]
return [specify_shape(new_obj_arg, shape_arg[len(indices) :])]
@register_specialize
@local_optimizer([Join])
def local_join_subtensors(fgraph, node):
    r"""Simplify contiguous :class:`Subtensor`\s inside a :class:`Join`.

    `join((x[:3], x[3:5]), axis=0) -> x[:5]`

    Parameters
    ----------
    fgraph
        The function graph being rewritten (required by the rewriter
        interface; unused here).
    node
        A ``Join`` apply node.

    Returns
    -------
    list of Variable or None
        A single-element list with the simplified output when two adjacent
        inputs could be merged, or ``None`` when nothing was changed.
    """
    # TODO: Generalize to AdvancedSubtensors

    axis, tensors = node.inputs[0], node.inputs[1:]

    try:
        # The join axis must be a compile-time constant, otherwise we cannot
        # locate and compare the per-axis slices below.
        axis = get_scalar_constant_value(axis)
    except NotScalarConstantError:
        return

    def _step_is_unit(step):
        # A slice step of ``None`` means 1; otherwise it must be a constant 1.
        # For non-unit steps (perhaps except for -1) we would need to know the
        # exact values of start and stop to know if slices can be merged.
        if step is None:
            return True
        try:
            return get_scalar_constant_value(step, only_process_constants=True) == 1
        except NotScalarConstantError:
            return False

    for subtensor1_idx, (subtensor1, subtensor2) in enumerate(
        zip(tensors[:-1], tensors[1:])
    ):
        # Check that two consecutive Subtensors are operating on the same base tensor
        if not (
            (
                subtensor1.owner is not None
                and isinstance(subtensor1.owner.op, Subtensor)
            )
            and (
                subtensor2.owner is not None
                and isinstance(subtensor2.owner.op, Subtensor)
            )
            and (subtensor1.owner.inputs[0] is subtensor2.owner.inputs[0])
        ):
            continue

        # Check that subtensors have consecutive indexes across the join axis
        idxs_subtensor1 = indices_from_subtensor(
            subtensor1.owner.inputs[1:], subtensor1.owner.op.idx_list
        )
        idxs_subtensor2 = indices_from_subtensor(
            subtensor2.owner.inputs[1:], subtensor2.owner.op.idx_list
        )
        try:
            idxs_axis_subtensor1 = idxs_subtensor1[axis]
            idxs_axis_subtensor2 = idxs_subtensor2[axis]
        except IndexError:
            # The subtensors do not index the join axis at all.
            continue
        if not (
            isinstance(idxs_axis_subtensor1, slice)
            and isinstance(idxs_axis_subtensor2, slice)
        ):
            continue

        start_subtensor1, stop_subtensor1, step_subtensor1 = (
            idxs_axis_subtensor1.start,
            idxs_axis_subtensor1.stop,
            idxs_axis_subtensor1.step,
        )
        start_subtensor2, stop_subtensor2, step_subtensor2 = (
            idxs_axis_subtensor2.start,
            idxs_axis_subtensor2.stop,
            idxs_axis_subtensor2.step,
        )

        # The slices are contiguous when the first one's ``stop`` equals the
        # second one's ``start`` (either equal constants or the very same
        # symbolic variable).
        if not (
            (stop_subtensor1 is not None and start_subtensor2 is not None)
            and (stop_subtensor1 == start_subtensor2)
        ):
            continue

        # Check that each step is None or 1. Skipping just this pair (rather
        # than returning early) still lets later pairs in the same `Join` be
        # merged.
        if not (
            _step_is_unit(step_subtensor1) and _step_is_unit(step_subtensor2)
        ):
            continue

        # Check that all other idxs of subtensor are the same
        if all(
            idxs_nonaxis_subtensor1 == idxs_nonaxis_subtensor2
            for i, (idxs_nonaxis_subtensor1, idxs_nonaxis_subtensor2) in enumerate(
                zip(idxs_subtensor1, idxs_subtensor2)
            )
            if i != axis
        ):
            base_tensor = subtensor1.owner.inputs[0]
            new_idxs = list(idxs_subtensor1)
            new_idxs[axis] = slice(start_subtensor1, stop_subtensor2, step_subtensor1)
            # Index with a tuple: indexing with a list of slices is deprecated
            # NumPy-style advanced indexing and is ambiguous.
            merged_subtensors = base_tensor[tuple(new_idxs)]

            new_joined_tensors = [
                *tensors[:subtensor1_idx],
                merged_subtensors,
                *tensors[subtensor1_idx + 2 :],
            ]
            if len(new_joined_tensors) > 1:
                return [concatenate(new_joined_tensors, axis=axis)]
            else:
                return [merged_subtensors]
......@@ -2199,3 +2199,112 @@ def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx):
y_opt = optimize_graph(y, clone=False)
assert not isinstance(y_opt.owner.op, SpecifyShape)
@pytest.mark.parametrize(
    "axis, slices_fn, expected_nodes",
    [
        # Below should be merged
        (0, lambda _: ((slice(None, 5, None),), (slice(5, None, None),)), 1),
        (0, lambda _: ((slice(0, 5, 1),), (slice(5, None, 1),)), 1),
        (
            0,
            lambda _: (
                (slice(0, 2, 1),),
                (slice(2, 4, None),),
                (slice(4, None, 1),),
            ),
            1,
        ),
        (
            0,
            lambda _: (
                (slice(None, 5, None), slice(None, -1, None)),
                (slice(5, None, None), slice(None, -1, None)),
            ),
            2,
        ),
        (
            1,
            lambda step: (
                (slice(2, None, step), slice(None, 2, None)),
                (slice(2, None, step), slice(2, 4, None)),
                (slice(2, None, step), slice(4, 6, None)),
            ),
            3,
        ),
        (
            0,
            lambda stop: (
                (slice(1, stop, None),),
                (slice(stop, 5, None),),
                (slice(5, 7, None),),
            ),
            2,
        ),
        (
            0,
            lambda stop: (
                (slice(1, stop + 1, None),),
                (slice(stop + 1, 5, None),),
                (slice(5, 7, None),),
            ),
            2,
        ),
        # Below NotImplemented: These could be merged, but we would need to evaluate the
        # start and stop values
        (0, lambda _: ((slice(None, 6, 3),), (slice(6, None, 3),)), 3),
        (0, lambda step: ((slice(None, 6, step),), (slice(6, None, step),)), 4),
        # Below should not be merged
        (0, lambda _: ((slice(5, None, None),), (slice(None, 5, None),)), 3),
        (0, lambda _: ((slice(None, 5, None),), (slice(4, None, None),)), 3),
        (1, lambda _: ((slice(None, 5, None),), (slice(5, None, None),)), 3),
        (
            0,
            lambda _: (
                (slice(2, None, None), slice(None, 2, None)),
                (slice(2, None, None), slice(2, 4, None)),
                (slice(2, None, None), slice(4, 6, None)),
            ),
            4,
        ),
        (
            0,
            lambda _: (
                (slice(None, 5, 2), slice(None, -1, None)),
                (slice(5, None, 3), slice(None, -1, None)),
            ),
            3,
        ),
        (
            0,
            lambda _: (
                (slice(None, 5, None), slice(None, -1, None)),
                (slice(5, None, None), slice(1, None, None)),
            ),
            3,
        ),
        (0, lambda stop: ((slice(None, stop, None),), (slice(3, None, None),)), 4),
        (0, lambda _: ((slice(None, 5, 2),), (slice(5, None, 2),)), 3),
    ],
)
def test_local_join_subtensors(axis, slices_fn, expected_nodes):
    """Check `local_join_subtensors` on a `concatenate` of subtensors of `x`.

    Each case asserts (a) the rewritten graph has `expected_nodes` apply
    nodes, and (b) the compiled function matches the equivalent NumPy
    computation on concrete values.
    """
    x = at.dmatrix("x")
    slice_scalar = at.iscalar("slice_scalar")
    slices = slices_fn(slice_scalar)
    # NOTE: avoid naming the loop variable `slice` — it would shadow the
    # builtin used throughout the parametrization above.
    y = at.concatenate([x[idx] for idx in slices], axis=axis)
    f = aesara.function(
        [x, slice_scalar],
        y,
        mode=Mode("py").excluding("fusion"),
        on_unused_input="ignore",
    )
    nodes = f.maker.fgraph.toposort()
    assert len(nodes) == expected_nodes, nodes

    # Compare with NumPy on concrete values (stop_val also feeds the
    # symbolic `slice_scalar` where the case uses it).
    x_val = np.arange(100).reshape(10, 10)
    stop_val = 3
    slices_val = slices_fn(stop_val)
    f_val = np.concatenate([x_val[idx_val] for idx_val in slices_val], axis=axis)
    np.testing.assert_array_equal(f(x_val, stop_val), f_val)
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment