提交 92cf4887 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add an option to improve the speed of AdvancedIncSubtensor

This improvement is due to the use of `x[indices] += y` instead of `np.add.at(x, indices, y)`, but comes at the cost of not handling duplicate `indices` values.
上级 7a61628f
...@@ -1258,33 +1258,55 @@ def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False): ...@@ -1258,33 +1258,55 @@ def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False):
def inc_subtensor( def inc_subtensor(
x, y, inplace=False, set_instead_of_inc=False, tolerate_inplace_aliasing=False x,
y,
inplace=False,
set_instead_of_inc=False,
tolerate_inplace_aliasing=False,
ignore_duplicates=False,
): ):
""" """Update the value of an indexed array by a given amount.
Return x with the given subtensor incremented by y.
This is equivalent to ``x[indices] += y`` or ``np.add.at(x, indices, y)``,
depending on the value of `ignore_duplicates`.
Parameters Parameters
---------- ----------
x x
The symbolic result of a Subtensor operation. The symbolic result of a Subtensor operation.
y y
The amount by which to increment the subtensor in question. The amount by which to increment the array.
inplace inplace
Don't use. Aesara will do it when possible. Don't use. Aesara will do in-place operations itself, when possible.
set_instead_of_inc set_instead_of_inc
If True, do a set_subtensor instead. If True, do a set_subtensor instead.
tolerate_inplace_aliasing: tolerate_inplace_aliasing:
Allow x and y to be views of a single underlying array even while Allow `x` and `y` to be views of a single underlying array even while
working inplace. For correct results, x and y must not be overlapping working in-place. For correct results, `x` and `y` must not be overlapping
views; if they overlap, the result of this Op will generally be views; if they overlap, the result of this `Op` will generally be
incorrect. This value has no effect if inplace=False. incorrect. This value has no effect if ``inplace=False``.
ignore_duplicates
This determines whether ``x[indices] += y`` is used or
``np.add.at(x, indices, y)``. When the special duplicates handling of
``np.add.at`` isn't required, setting this option to ``True``
(i.e. using ``x[indices] += y``) can result in faster compiled
graphs.
Examples Examples
-------- --------
To replicate the numpy expression "r[10:] += 5", type To replicate the expression ``r[10:] += 5``:
>>> r = ivector() .. code-block:: python
>>> new_r = inc_subtensor(r[10:], 5)
r = ivector()
new_r = inc_subtensor(r[10:], 5)
To replicate the expression ``r[[0, 1, 0]] += 5``:
.. code-block:: python
r = ivector()
new_r = inc_subtensor(r[[0, 1, 0]], 5, ignore_duplicates=True)
""" """
# First of all, y cannot have a higher dimension than x, # First of all, y cannot have a higher dimension than x,
...@@ -1329,12 +1351,23 @@ def inc_subtensor( ...@@ -1329,12 +1351,23 @@ def inc_subtensor(
elif isinstance(x.owner.op, AdvancedSubtensor1): elif isinstance(x.owner.op, AdvancedSubtensor1):
real_x = x.owner.inputs[0] real_x = x.owner.inputs[0]
ilist = x.owner.inputs[1] ilist = x.owner.inputs[1]
the_op = AdvancedIncSubtensor1(inplace, set_instead_of_inc=set_instead_of_inc) if ignore_duplicates:
the_op = AdvancedIncSubtensor(
inplace, set_instead_of_inc=set_instead_of_inc, ignore_duplicates=True
)
else:
the_op = AdvancedIncSubtensor1(
inplace, set_instead_of_inc=set_instead_of_inc
)
return the_op(real_x, y, ilist) return the_op(real_x, y, ilist)
elif isinstance(x.owner.op, AdvancedSubtensor): elif isinstance(x.owner.op, AdvancedSubtensor):
real_x = x.owner.inputs[0] real_x = x.owner.inputs[0]
ilist = x.owner.inputs[1:] ilist = x.owner.inputs[1:]
the_op = AdvancedIncSubtensor(inplace, set_instead_of_inc=set_instead_of_inc) the_op = AdvancedIncSubtensor(
inplace,
set_instead_of_inc=set_instead_of_inc,
ignore_duplicates=ignore_duplicates,
)
return the_op(real_x, y, *ilist) return the_op(real_x, y, *ilist)
elif isinstance(x.owner.op, DimShuffle): elif isinstance(x.owner.op, DimShuffle):
inner_x = x.owner.inputs[0] inner_x = x.owner.inputs[0]
...@@ -1366,6 +1399,7 @@ def inc_subtensor( ...@@ -1366,6 +1399,7 @@ def inc_subtensor(
inplace=inplace, inplace=inplace,
set_instead_of_inc=set_instead_of_inc, set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing, tolerate_inplace_aliasing=tolerate_inplace_aliasing,
ignore_duplicates=ignore_duplicates,
) )
# The broadcastable pattern of inner_x may not be the same as # The broadcastable pattern of inner_x may not be the same as
# the one of x, so we have to build a new dimshuffle here, # the one of x, so we have to build a new dimshuffle here,
...@@ -1398,6 +1432,7 @@ def inc_subtensor( ...@@ -1398,6 +1432,7 @@ def inc_subtensor(
inplace=inplace, inplace=inplace,
set_instead_of_inc=set_instead_of_inc, set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing, tolerate_inplace_aliasing=tolerate_inplace_aliasing,
ignore_duplicates=ignore_duplicates,
) )
return inner_incsubtensor return inner_incsubtensor
else: else:
...@@ -2610,13 +2645,16 @@ advanced_subtensor = AdvancedSubtensor() ...@@ -2610,13 +2645,16 @@ advanced_subtensor = AdvancedSubtensor()
class AdvancedIncSubtensor(Op): class AdvancedIncSubtensor(Op):
"""Increments a subtensor using advanced indexing.""" """Increments a subtensor using advanced indexing."""
__props__ = ("inplace", "set_instead_of_inc") __props__ = ("inplace", "set_instead_of_inc", "ignore_duplicates")
def __init__(self, inplace=False, set_instead_of_inc=False): def __init__(
self, inplace=False, set_instead_of_inc=False, ignore_duplicates=False
):
self.set_instead_of_inc = set_instead_of_inc self.set_instead_of_inc = set_instead_of_inc
self.inplace = inplace self.inplace = inplace
if inplace: if inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
self.ignore_duplicates = ignore_duplicates
def __str__(self): def __str__(self):
return "{}{{{}, {}}}".format( return "{}{{{}, {}}}".format(
...@@ -2642,18 +2680,22 @@ class AdvancedIncSubtensor(Op): ...@@ -2642,18 +2680,22 @@ class AdvancedIncSubtensor(Op):
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
check_advanced_indexing_dimensions(inputs[0], inputs[2:]) x, y, *indices = inputs
check_advanced_indexing_dimensions(x, indices)
(out,) = out_ (out,) = out_
if not self.inplace: if not self.inplace:
out[0] = inputs[0].copy() out[0] = x.copy()
else: else:
out[0] = inputs[0] out[0] = x
if self.set_instead_of_inc: if self.set_instead_of_inc:
out[0][tuple(inputs[2:])] = inputs[1] out[0][tuple(indices)] = y
elif self.ignore_duplicates:
out[0][tuple(indices)] += y
else: else:
np.add.at(out[0], tuple(inputs[2:]), inputs[1]) np.add.at(out[0], tuple(indices), y)
def infer_shape(self, fgraph, node, ishapes): def infer_shape(self, fgraph, node, ishapes):
return [ishapes[0]] return [ishapes[0]]
...@@ -2699,6 +2741,10 @@ class AdvancedIncSubtensor(Op): ...@@ -2699,6 +2741,10 @@ class AdvancedIncSubtensor(Op):
advanced_inc_subtensor = AdvancedIncSubtensor() advanced_inc_subtensor = AdvancedIncSubtensor()
advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True) advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
advanced_inc_subtensor_nodup = AdvancedIncSubtensor(ignore_duplicates=True)
advanced_set_subtensor_nodup = AdvancedIncSubtensor(
set_instead_of_inc=True, ignore_duplicates=True
)
def take(a, indices, axis=None, mode="raise"): def take(a, indices, axis=None, mode="raise"):
......
...@@ -239,7 +239,8 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): ...@@ -239,7 +239,8 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
This is only done when there's a single vector index. This is only done when there's a single vector index.
""" """
if not isinstance(node.op, AdvancedIncSubtensor): if not isinstance(node.op, AdvancedIncSubtensor) or node.op.ignore_duplicates:
# `AdvancedIncSubtensor1` does not ignore duplicate index values
return return
res = node.inputs[0] res = node.inputs[0]
...@@ -249,15 +250,17 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): ...@@ -249,15 +250,17 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
axis = get_advsubtensor_axis(indices) axis = get_advsubtensor_axis(indices)
if axis is None or indices[axis].dtype == "bool": if axis is None or indices[axis].dtype == "bool":
# Booleans aren't handled # Booleans aren't currently handled by `AdvancedIncSubtensor1`
return return
new_subtensor = transform_take(res, indices[axis], axis) new_subtensor = transform_take(res, indices[axis], axis)
set_instead_of_inc = node.op.set_instead_of_inc
inplace = node.op.inplace
new_res = inc_subtensor( new_res = inc_subtensor(
new_subtensor, val, inplace=inplace, set_instead_of_inc=set_instead_of_inc new_subtensor,
val,
inplace=node.op.inplace,
set_instead_of_inc=node.op.set_instead_of_inc,
ignore_duplicates=False,
) )
copy_stack_trace(node.outputs[0], new_res) copy_stack_trace(node.outputs[0], new_res)
return [new_res] return [new_res]
...@@ -1290,7 +1293,9 @@ compile.optdb.register( ...@@ -1290,7 +1293,9 @@ compile.optdb.register(
def local_inplace_AdvancedIncSubtensor(fgraph, node): def local_inplace_AdvancedIncSubtensor(fgraph, node):
if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace: if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
new_op = type(node.op)( new_op = type(node.op)(
inplace=True, set_instead_of_inc=node.op.set_instead_of_inc inplace=True,
set_instead_of_inc=node.op.set_instead_of_inc,
ignore_duplicates=node.op.ignore_duplicates,
) )
new_node = new_op(*node.inputs) new_node = new_op(*node.inputs)
copy_stack_trace(node.outputs, new_node) copy_stack_trace(node.outputs, new_node)
......
...@@ -1617,9 +1617,6 @@ class TestIncSubtensor: ...@@ -1617,9 +1617,6 @@ class TestIncSubtensor:
class TestIncSubtensor1: class TestIncSubtensor1:
# test inc_subtensor
# also tests set_subtensor
def setup_method(self): def setup_method(self):
self.rng = np.random.default_rng(seed=utt.fetch_seed()) self.rng = np.random.default_rng(seed=utt.fetch_seed())
...@@ -1650,6 +1647,13 @@ class TestIncSubtensor1: ...@@ -1650,6 +1647,13 @@ class TestIncSubtensor1:
aval = f([0.4, 0.9, 0.1], [1, 2]) aval = f([0.4, 0.9, 0.1], [1, 2])
assert np.allclose(aval, [0.4, 0.9, 0.1]) assert np.allclose(aval, [0.4, 0.9, 0.1])
@pytest.mark.parametrize("ignore_duplicates", [True, False])
def test_inc_subtensor_AdvancedSubtensor1(self, ignore_duplicates):
x = AdvancedSubtensor1()(self.v, self.adv1q)
a = inc_subtensor(x, self.v[self.adv1q], ignore_duplicates=ignore_duplicates)
assert isinstance(a.owner.op, (AdvancedIncSubtensor1, AdvancedIncSubtensor))
assert getattr(a.owner.op, "ignore_duplicates", False) == ignore_duplicates
def test_1d_inc_adv_selection(self): def test_1d_inc_adv_selection(self):
a = inc_subtensor(self.v[self.adv1q], self.v[self.adv1q]) a = inc_subtensor(self.v[self.adv1q], self.v[self.adv1q])
...@@ -1886,83 +1890,142 @@ class TestAdvancedSubtensor: ...@@ -1886,83 +1890,142 @@ class TestAdvancedSubtensor:
rval = ft4v[:, :, ix2v, None, :] rval = ft4v[:, :, ix2v, None, :]
utt.assert_allclose(rval, aval) utt.assert_allclose(rval, aval)
def test_inc_adv_subtensor_w_2vec(self): @pytest.mark.parametrize(
"ignore_duplicates",
[
True,
False,
],
)
def test_inc_adv_subtensor_w_2vec(self, ignore_duplicates):
subt = self.m[self.ix1, self.ix12] subt = self.m[self.ix1, self.ix12]
a = inc_subtensor(subt, subt) a = inc_subtensor(subt, subt, ignore_duplicates=ignore_duplicates)
typ = TensorType(self.m.type.dtype, self.ix2.type.broadcastable) typ = TensorType(self.m.type.dtype, self.ix2.type.broadcastable)
assert a.type == typ, (a.type, typ) assert a.type == typ
f = aesara.function( f = aesara.function(
[self.m, self.ix1, self.ix12], a, allow_input_downcast=True, mode=self.mode [self.m, self.ix1, self.ix12], a, allow_input_downcast=True, mode=self.mode
) )
aval = f([[0.4, 0.9, 0.1], [5, 6, 7], [0.5, 0.3, 0.15]], [1, 2, 1], [0, 1, 0])
assert np.allclose(
aval, [[0.4, 0.9, 0.1], [5 * 3, 6, 7], [0.5, 0.3 * 2, 0.15]]
), aval
def test_inc_adv_subtensor_with_broadcasting(self): m_val = [[0.4, 0.9, 0.1], [5, 6, 7], [0.5, 0.3, 0.15]]
ix1_val = [1, 2, 1]
ix2_val = [0, 1, 0]
exp_aval = np.array(m_val)
if ignore_duplicates:
exp_aval[ix1_val, ix2_val] += exp_aval[ix1_val, ix2_val]
else:
np.add.at(exp_aval, (ix1_val, ix2_val), exp_aval[ix1_val, ix2_val])
aval = f(m_val, ix1_val, ix2_val)
assert np.allclose(aval, exp_aval)
@pytest.mark.parametrize(
"ignore_duplicates",
[
True,
False,
],
)
def test_inc_adv_subtensor_with_broadcasting(self, ignore_duplicates):
inc = dscalar() inc = dscalar()
a = inc_subtensor(self.m[self.ix1, self.ix12], inc) a = inc_subtensor(
self.m[self.ix1, self.ix12], inc, ignore_duplicates=ignore_duplicates
)
g_inc = aesara.grad(a.sum(), inc) g_inc = aesara.grad(a.sum(), inc)
assert a.type == self.m.type, (a.type, self.m.type) assert a.type == self.m.type
f = aesara.function( f = aesara.function(
[self.m, self.ix1, self.ix12, inc], [self.m, self.ix1, self.ix12, inc],
[a, g_inc], [a, g_inc],
allow_input_downcast=True, allow_input_downcast=True,
mode=self.mode, mode=self.mode,
) )
aval, gval = f(
[[0.4, 0.9, 0.1], [5, 6, 7], [0.5, 0.3, 0.15]], [1, 2, 1], [0, 1, 0], 2.1
)
assert np.allclose(
aval, [[0.4, 0.9, 0.1], [5 + 2.1 * 2, 6, 7], [0.5, 0.3 + 2.1, 0.15]]
), aval
assert np.allclose(gval, 3.0), gval
def test_inc_adv_subtensor1_with_broadcasting(self): m_val = [[0.4, 0.9, 0.1], [5, 6, 7], [0.5, 0.3, 0.15]]
ix1_val = [1, 2, 1]
ix2_val = [0, 1, 0]
inc_val = 2.1
exp_aval = np.array(m_val)
if ignore_duplicates:
exp_aval[ix1_val, ix2_val] += inc_val
else:
np.add.at(exp_aval, (ix1_val, ix2_val), inc_val)
aval, gval = f(m_val, ix1_val, ix2_val, inc_val)
assert np.allclose(aval, exp_aval)
assert np.allclose(gval, 3.0)
@pytest.mark.parametrize(
"ignore_duplicates",
[
True,
False,
],
)
def test_inc_adv_subtensor1_with_broadcasting(self, ignore_duplicates):
inc = dscalar() inc = dscalar()
a = inc_subtensor(self.m[self.ix1], inc) a = inc_subtensor(self.m[self.ix1], inc, ignore_duplicates=ignore_duplicates)
g_inc = aesara.grad(a.sum(), inc) g_inc = aesara.grad(a.sum(), inc)
assert a.type == self.m.type, (a.type, self.m.type) assert a.type == self.m.type
f = aesara.function( f = aesara.function(
[self.m, self.ix1, inc], [self.m, self.ix1, inc],
[a, g_inc], [a, g_inc],
allow_input_downcast=True, allow_input_downcast=True,
mode=self.mode, mode=self.mode,
) )
aval, gval = f([[0.4, 0.9, 0.1], [5, 6, 7], [0.5, 0.3, 0.15]], [0, 1, 0], 2.1)
assert np.allclose(
aval,
[
[0.4 + 2.1 * 2, 0.9 + 2.1 * 2, 0.1 + 2.1 * 2],
[5 + 2.1, 6 + 2.1, 7 + 2.1],
[0.5, 0.3, 0.15],
],
), aval
assert np.allclose(gval, 9.0), gval
def test_inc_adv_subtensor_with_index_broadcasting(self): m_val = [[0.4, 0.9, 0.1], [5, 6, 7], [0.5, 0.3, 0.15]]
a = inc_subtensor(self.m[self.ix1, self.ix2], 2.1) ix1_val = [0, 1, 0]
inc_val = 2.1
exp_aval = np.array(m_val).copy()
if ignore_duplicates:
exp_aval[ix1_val] += inc_val
else:
np.add.at(exp_aval, ix1_val, inc_val)
aval, gval = f(m_val, ix1_val, inc_val)
assert np.allclose(aval, exp_aval)
assert np.allclose(gval, 9.0)
@pytest.mark.parametrize(
"ignore_duplicates",
[
True,
False,
],
)
def test_inc_adv_subtensor_with_index_broadcasting(self, ignore_duplicates):
a = inc_subtensor(
self.m[self.ix1, self.ix2], 2.1, ignore_duplicates=ignore_duplicates
)
assert a.type == self.m.type
assert a.type == self.m.type, (a.type, self.m.type)
f = aesara.function( f = aesara.function(
[self.m, self.ix1, self.ix2], a, allow_input_downcast=True, mode=self.mode [self.m, self.ix1, self.ix2], a, allow_input_downcast=True, mode=self.mode
) )
aval = f(
[[0.4, 0.9, 0.1], [5, 6, 7], [0.5, 0.3, 0.15]], m_val = [[0.4, 0.9, 0.1], [5, 6, 7], [0.5, 0.3, 0.15]]
[0, 2, 0], ix1_val = [0, 2, 0]
[[0, 1, 0], [2, 2, 2]], ix2_val = [[0, 1, 0], [2, 2, 2]]
)
assert np.allclose( inc_val = 2.1
aval, exp_aval = np.array(m_val)
[ if ignore_duplicates:
[0.4 + 2 * 2.1, 0.9, 0.1 + 2 * 2.1], exp_aval[ix1_val, ix2_val] += inc_val
[5, 6, 7], else:
[0.5, 0.3 + 2.1, 0.15 + 2.1], np.add.at(exp_aval, (ix1_val, ix2_val), inc_val)
],
), aval aval = f(m_val, ix1_val, ix2_val)
assert np.allclose(aval, exp_aval)
def test_2d_3d_tensors(self): def test_2d_3d_tensors(self):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论