Commit 1d28ac59 authored by Brandon T. Willard, committed by Brandon T. Willard

Clean up and extend local_useless_inc_subtensor

Aside from some basic refactoring, this commit allows the `local_useless_inc_subtensor` rewrite to handle more increment-zero cases.
Parent commit: edcbac8e
......@@ -71,6 +71,7 @@ from aesara.tensor.subtensor import (
get_idx_list,
get_slice_elements,
inc_subtensor,
indices_from_subtensor,
)
from aesara.tensor.type import TensorType
from aesara.tensor.type_other import NoneTypeT, SliceConstant, SliceType
......@@ -785,34 +786,38 @@ def local_subtensor_make_vector(fgraph, node):
@register_specialize
@local_optimizer([IncSubtensor])
def local_useless_inc_subtensor(fgraph, node):
    r"""Remove redundant `IncSubtensor`\s.

    More specifically, ``set_subtensor(x[indices], y)`` is replaced by
    ``y[indices]`` when ``indices`` are full `slice`\s and ``y``'s shape is
    equal to ``x[indices]``, and ``inc_subtensor(x[indices], y)`` is replaced
    by ``y[indices]`` when ``x[indices]`` is some array of ``0``\s, ``indices``
    are full slices, and the shapes are equal.
    """
    if not isinstance(node.op, IncSubtensor):
        return

    # Static-shape equality of `x` and `y` is checked through the shape
    # feature below, so without it the rewrite cannot apply
    if not hasattr(fgraph, "shape_feature"):
        return

    x, y, *index_inputs = node.inputs

    if node.op.set_instead_of_inc is False:
        # This is an increment operation, so the array being incremented must
        # consist of all zeros in order for the entire operation to be useless
        try:
            c = get_scalar_constant_value(x)
            if c != 0:
                return
        except NotScalarConstantError:
            return

    idx_cst = indices_from_subtensor(list(index_inputs), node.op.idx_list)

    # Check that all indices are full slices with only reversals and no step
    # sizes
    # TODO: It seems like there should be a basic `IncSubtensor`
    # canonicalization that removes these redundant slices.
    if all(
        isinstance(e, slice)
        and e.start is None
        # NOTE(review): the middle of this condition was hidden by a diff-hunk
        # gap; reconstructed as "stop absent, step absent or -1" to match the
        # comment above ("full slices with only reversals and no step sizes")
        # — confirm against the upstream source.
        and e.stop is None
        and (e.step is None or e.step == -1)
        for e in idx_cst
    ):
        # `IncSubtensor` broadcasts `x` on `y` based on run-time shapes, so we
        # must check that they are the same
        if not fgraph.shape_feature.same_shape(x, y):
            return

        # There are no reversals, so we don't need a replacement.
        if all(e.step is None for e in node.op.idx_list):
            # They are exactly the same shapes, so we can remove this
            # `IncSubtensor`
            return [y]

        new_node = Subtensor(node.op.idx_list).make_node(y, *index_inputs)
        new_out = new_node.outputs[0]

        # Copy over the previous output's stack trace
        copy_stack_trace(node.outputs, new_out)

        return [new_out]
@register_canonicalize
......
......@@ -127,46 +127,81 @@ def test_local_replace_AdvancedSubtensor(indices, is_none):
assert np.array_equal(res_val, exp_res_val)
@pytest.mark.parametrize("s", [slice(None), slice(None, None, -1)])
def test_local_useless_inc_subtensor(s):
    """Test that `set_subtensor` on full slices is removed when shapes are known equal.

    Parametrized over a plain full slice and a reversing full slice.
    """
    x = matrix("x")
    y = matrix("y")

    o = set_subtensor(x[:, s], y)

    mode = get_default_mode().including("local_useless_inc_subtensor")

    # Test without shape info (i.e. don't apply the opt)
    f = function([x, y], o, mode=mode)
    topo = f.maker.fgraph.toposort()
    assert len(topo) == 1
    assert isinstance(topo[0].op, IncSubtensor)

    # Test with shape info: `specify_shape` lets the shape feature prove that
    # `y` has the same shape as `x`, so the `IncSubtensor` can be removed
    o_shape = set_subtensor(x[:, s], specify_shape(y, x.shape))
    f_shape = function([x, y], o_shape, mode=mode)

    topo = f_shape.maker.fgraph.toposort()
    assert not any(isinstance(n.op, IncSubtensor) for n in topo)

    out = f_shape([[2, 3]], [[3, 4]])
    assert np.array_equal(out, np.asarray([[3, 4]])[::, s])
def test_local_useless_inc_subtensor_increment_zeros():
    r"""Make sure we remove `IncSubtensor`\s that are increments on entire zero arrays."""
    y = matrix("y")

    zeros_view = aet.zeros((2, 2))[:, :]
    incremented = inc_subtensor(zeros_view, specify_shape(y, zeros_view.shape))

    opt_mode = get_default_mode().including("local_useless_inc_subtensor")
    fn = function([y], incremented, mode=opt_mode)

    # The increment onto all-zeros is a no-op, so no `IncSubtensor` should
    # survive compilation
    compiled_nodes = fn.maker.fgraph.toposort()
    assert not any(isinstance(n.op, IncSubtensor) for n in compiled_nodes)
def test_local_useless_inc_subtensor_no_opt():
    r"""Make sure we don't remove `IncSubtensor`\s that involve slices with steps that skip elements and non-zero increments."""
    x = matrix("x")
    y = matrix("y")

    # A stepped slice does not cover all of `x`, so the set operation is not
    # redundant and must be kept
    s = x[:, ::2]
    o_shape = set_subtensor(s, specify_shape(y, s.shape))

    mode = get_default_mode().including("local_useless_inc_subtensor")
    f_shape = function([x, y], o_shape, mode=mode)

    topo = f_shape.maker.fgraph.toposort()
    assert any(isinstance(n.op, IncSubtensor) for n in topo)

    out = f_shape([[2, 3, 6, 7]], [[8, 9]])
    assert np.array_equal(out, np.asarray([[8, 3, 9, 7]]))

    # This is an increment with a non-constant target array
    s = x[:, :]
    o_shape = inc_subtensor(s, specify_shape(y, s.shape))

    f_shape = function([x, y], o_shape, mode=mode)

    topo = f_shape.maker.fgraph.toposort()
    assert any(isinstance(n.op, IncSubtensor) for n in topo)

    # This is an increment with a non-zero target array
    s = aet.ones((2, 2))[:, :]
    o_shape = inc_subtensor(s, specify_shape(y, s.shape))

    f_shape = function([y], o_shape, mode=mode)

    topo = f_shape.maker.fgraph.toposort()
    assert any(isinstance(n.op, IncSubtensor) for n in topo)
def test_local_useless_subtensor():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论