提交 1d28ac59 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up and extend local_useless_inc_subtensor

Aside from some basic refactoring, this commit allows the `local_useless_inc_subtensor` rewrite to handle more increment zero cases.
上级 edcbac8e
...@@ -71,6 +71,7 @@ from aesara.tensor.subtensor import ( ...@@ -71,6 +71,7 @@ from aesara.tensor.subtensor import (
get_idx_list, get_idx_list,
get_slice_elements, get_slice_elements,
inc_subtensor, inc_subtensor,
indices_from_subtensor,
) )
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
from aesara.tensor.type_other import NoneTypeT, SliceConstant, SliceType from aesara.tensor.type_other import NoneTypeT, SliceConstant, SliceType
...@@ -785,34 +786,38 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -785,34 +786,38 @@ def local_subtensor_make_vector(fgraph, node):
@register_specialize @register_specialize
@local_optimizer([IncSubtensor]) @local_optimizer([IncSubtensor])
def local_useless_inc_subtensor(fgraph, node): def local_useless_inc_subtensor(fgraph, node):
""" r"""Remove redundant `IncSubtensor`\s.
Remove IncSubtensor, when we overwrite the full inputs with the
new value.
More specifically, ``set_subtensor(x[indices], y)`` is replaced by
``y[indices]`` when ``indices`` are full `slice`\s and ``y``'s shape is
equal to ``x[indices]``, and ``inc_subtensor(x[indices], y)`` is replaced
by ``y[indices]`` when ``x[indices]`` is some array of ``0``\s, ``indices``
are full slices, and the shapes are equal.
""" """
if not isinstance(node.op, IncSubtensor): if not isinstance(node.op, IncSubtensor):
return return
if not hasattr(fgraph, "shape_feature"):
return
x, y, *index_inputs = node.inputs
if node.op.set_instead_of_inc is False: if node.op.set_instead_of_inc is False:
# This is an IncSubtensor, so the init value must be zeros # This is an increment operation, so the array being incremented must
# consist of all zeros in order for the entire operation to be useless
try: try:
c = get_scalar_constant_value(node.inputs[0], only_process_constants=True) c = get_scalar_constant_value(x)
if c != 0: if c != 0:
return return
except NotScalarConstantError: except NotScalarConstantError:
return return
if (
node.inputs[0].ndim != node.inputs[1].ndim
or node.inputs[0].broadcastable != node.inputs[1].broadcastable
):
# FB: I didn't check if this case can happen, but this opt
# don't support it.
return
# We have a SetSubtensor or an IncSubtensor on zeros
# If is this IncSubtensor useful?
# Check that we keep all the original data. idx_cst = indices_from_subtensor(list(index_inputs), node.op.idx_list)
# Put the constant inputs in the slice.
idx_cst = get_idx_list(node.inputs[1:], node.op.idx_list) # Check that all indices are full slices with only reversals and no step
# sizes
# TODO: It seems like there should be a basic `IncSubtensor`
# canonicalization that removes these redundant slices.
if all( if all(
isinstance(e, slice) isinstance(e, slice)
and e.start is None and e.start is None
...@@ -823,20 +828,22 @@ def local_useless_inc_subtensor(fgraph, node): ...@@ -823,20 +828,22 @@ def local_useless_inc_subtensor(fgraph, node):
) )
for e in idx_cst for e in idx_cst
): ):
# IncSubtensor broadcast node.inputs[1] on node.inputs[0]
# based on run time shapes, so we must check they are the same. # `IncSubtensor` broadcasts `x` on `y` based on run-time shapes, so we
if not hasattr(fgraph, "shape_feature"): # must check that they are the same
return if not fgraph.shape_feature.same_shape(x, y):
if not fgraph.shape_feature.same_shape(node.inputs[0], node.inputs[1]):
return return
# There is no reverse, so we don't need a replacement.
# There are no reversals, so we don't need a replacement.
if all(e.step is None for e in node.op.idx_list): if all(e.step is None for e in node.op.idx_list):
# They are the same shape, so we can remove this IncSubtensor # They are exactly the same shapes, so we can remove this `IncSubtensor`
return [node.inputs[1]] return [y]
ret = Subtensor(node.op.idx_list)(*node.inputs[1:])
# Copy over previous output stacktrace new_node = Subtensor(node.op.idx_list).make_node(y, *index_inputs)
copy_stack_trace(node.outputs, ret) new_out = new_node.outputs[0]
return [ret] copy_stack_trace(node.outputs, new_out)
return [new_out]
@register_canonicalize @register_canonicalize
......
...@@ -127,46 +127,81 @@ def test_local_replace_AdvancedSubtensor(indices, is_none): ...@@ -127,46 +127,81 @@ def test_local_replace_AdvancedSubtensor(indices, is_none):
assert np.array_equal(res_val, exp_res_val) assert np.array_equal(res_val, exp_res_val)
def test_local_useless_inc_subtensor(): @pytest.mark.parametrize("s", [slice(None), slice(None, None, -1)])
def test_local_useless_inc_subtensor(s):
x = matrix("x") x = matrix("x")
y = matrix("y") y = matrix("y")
o = set_subtensor(x[:, s], y)
mode = get_default_mode().including("local_useless_inc_subtensor") mode = get_default_mode().including("local_useless_inc_subtensor")
for s in [slice(None), slice(None, None, -1)]:
o = set_subtensor(x[::, s], y) # Test without shape info (i.e. don't apply the opt)
f = function([x, y], o, mode=mode) f = function([x, y], o, mode=mode)
o_shape = set_subtensor(x[::, s], specify_shape(y, x.shape))
f_shape = function([x, y], o_shape, mode=mode) topo = f.maker.fgraph.toposort()
assert len(topo) == 1
# Test with shape info assert isinstance(topo[0].op, IncSubtensor)
topo = f_shape.maker.fgraph.toposort()
assert not any(isinstance(n.op, IncSubtensor) for n in topo) # Test with shape info
out = f_shape([[2, 3]], [[3, 4]]) o_shape = set_subtensor(x[:, s], specify_shape(y, x.shape))
assert (out == np.asarray([[3, 4]])[::, s]).all() f_shape = function([x, y], o_shape, mode=mode)
# Test that without shape info, we don't apply the opt. topo = f_shape.maker.fgraph.toposort()
topo = f.maker.fgraph.toposort() assert not any(isinstance(n.op, IncSubtensor) for n in topo)
assert len(topo) == 1
assert isinstance(topo[0].op, IncSubtensor) out = f_shape([[2, 3]], [[3, 4]])
out = f([[2, 3]], [[3, 4]]) assert np.array_equal(out, np.asarray([[3, 4]])[::, s])
assert (out == np.asarray([[3, 4]])[::, s]).all()
# Test that we don't remove shape error def test_local_useless_inc_subtensor_increment_zeros():
with pytest.raises(ValueError): r"""Make sure we remove `IncSubtensor`\s that are increments on entire zero arrays."""
f([[2, 3]], [[3, 4], [4, 5]]) y = matrix("y")
# Test that we don't remove broadcastability s = aet.zeros((2, 2))[:, :]
out = f([[2, 3], [3, 4]], [[5, 6]]) o_shape = inc_subtensor(s, specify_shape(y, s.shape))
assert (out == np.asarray([[5, 6], [5, 6]])[::, s]).all()
mode = get_default_mode().including("local_useless_inc_subtensor")
# Test that we do not optimize others strides even when sub and y f_shape = function([y], o_shape, mode=mode)
# have same shapes
s = x[::, ::2] topo = f_shape.maker.fgraph.toposort()
assert not any(isinstance(n.op, IncSubtensor) for n in topo)
def test_local_useless_inc_subtensor_no_opt():
r"""Make sure we don't remove `IncSubtensor`\s that involve slices with steps that skip elements and non-zero increments."""
x = matrix("x")
y = matrix("y")
s = x[:, ::2]
o_shape = set_subtensor(s, specify_shape(y, s.shape)) o_shape = set_subtensor(s, specify_shape(y, s.shape))
f_shape = function([x, y], o_shape)
mode = get_default_mode().including("local_useless_inc_subtensor")
f_shape = function([x, y], o_shape, mode=mode)
topo = f_shape.maker.fgraph.toposort() topo = f_shape.maker.fgraph.toposort()
assert any(isinstance(n.op, IncSubtensor) for n in topo) assert any(isinstance(n.op, IncSubtensor) for n in topo)
out = f_shape([[2, 3, 6, 7]], [[8, 9]]) out = f_shape([[2, 3, 6, 7]], [[8, 9]])
assert (out == np.asarray([[8, 3, 9, 7]])).all() assert np.array_equal(out, np.asarray([[8, 3, 9, 7]]))
# This is an increment with a non-constant target array
s = x[:, :]
o_shape = inc_subtensor(s, specify_shape(y, s.shape))
f_shape = function([x, y], o_shape, mode=mode)
topo = f_shape.maker.fgraph.toposort()
assert any(isinstance(n.op, IncSubtensor) for n in topo)
# This is an increment with a non-zero target array
s = aet.ones((2, 2))[:, :]
o_shape = inc_subtensor(s, specify_shape(y, s.shape))
f_shape = function([y], o_shape, mode=mode)
topo = f_shape.maker.fgraph.toposort()
assert any(isinstance(n.op, IncSubtensor) for n in topo)
def test_local_useless_subtensor(): def test_local_useless_subtensor():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论