Commit 5fbaecc3 authored by Brandon T. Willard, committed by Brandon T. Willard

Allow AdvancedIncSubtensor to be in-placed

Parent ebea5e22
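In short: `AdvancedIncSubtensor` (increment or set of a tensor through advanced, i.e. integer-array, indexing) previously rejected `inplace=True` with a `NotImplementedError`. This commit makes the flag functional by declaring a `destroy_map` on the Op and registering a `local_inplace_AdvancedIncSubtensor` rewrite, matching what `IncSubtensor` and `AdvancedIncSubtensor1` already had. A minimal NumPy-only sketch of the semantics involved (my illustration, not code from this commit; `np.add.at` is also what the Op's Python `perform` relies on, since plain `out[idx] += y` would collapse duplicate indices):

import numpy as np

def adv_inc(x, y, indices, inplace=False, set_instead_of_inc=False):
    # In-place: update the caller's buffer directly; otherwise work on a copy.
    out = x if inplace else x.copy()
    if set_instead_of_inc:
        out[indices] = y
    else:
        np.add.at(out, indices, y)  # accumulates over duplicate indices
    return out

x = np.zeros(4)
adv_inc(x, 1.0, ([0, 0, 3],), inplace=True)
print(x)  # [2. 0. 0. 1.] -- the caller's array was modified in place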
@@ -3290,10 +3290,6 @@ compile.optdb.register(
 @local_optimizer([IncSubtensor], inplace=True)
 def local_inplace_setsubtensor(fgraph, node):
-    """
-    Also work for GpuIncSubtensor.
-
-    """
     if isinstance(node.op, IncSubtensor) and not node.op.inplace:
         dta = node.op.destroyhandler_tolerate_aliased
         new_op = node.op.__class__(
@@ -3322,36 +3318,51 @@ compile.optdb.register(
     60,
     "fast_run",
     "inplace",
-)  # DEBUG
+)


 @local_optimizer([AdvancedIncSubtensor1], inplace=True)
-def local_inplace_incsubtensor1(fgraph, node):
-    """
-    Also work for GpuAdvancedIncSubtensor1.
-
-    """
+def local_inplace_AdvancedIncSubtensor1(fgraph, node):
     if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
         new_op = node.op.clone_inplace()
         new_node = new_op(*node.inputs)
-
-        # Copy stacktrace from original outputs to new outputs.
-        # This is sensible, because the new operation is the
-        # same as the old one, but now with different attributes.
         copy_stack_trace(node.outputs, new_node)
         return [new_node]
     return False


 compile.optdb.register(
-    "local_inplace_incsubtensor1",
+    "local_inplace_AdvancedIncSubtensor1",
     TopoOptimizer(
-        local_inplace_incsubtensor1, failure_callback=TopoOptimizer.warn_inplace
+        local_inplace_AdvancedIncSubtensor1, failure_callback=TopoOptimizer.warn_inplace
     ),
     60,
     "fast_run",
     "inplace",
-)  # DEBUG
+)
+
+
+@local_optimizer([AdvancedIncSubtensor], inplace=True)
+def local_inplace_AdvancedIncSubtensor(fgraph, node):
+    if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
+        new_op = type(node.op)(
+            inplace=True, set_instead_of_inc=node.op.set_instead_of_inc
+        )
+        new_node = new_op(*node.inputs)
+        copy_stack_trace(node.outputs, new_node)
+        return [new_node]
+    return False
+
+
+compile.optdb.register(
+    "local_inplace_AdvancedIncSubtensor",
+    TopoOptimizer(
+        local_inplace_AdvancedIncSubtensor, failure_callback=TopoOptimizer.warn_inplace
+    ),
+    60,
+    "fast_run",
+    "inplace",
+)
 # Register old name
......
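With the rewrite registered under the `fast_run` and `inplace` tags, a compiled graph that increments an intermediate result through advanced indexing should now end up with the in-place Op. A hedged usage sketch (mine, mirroring the new tests below; import paths follow Aesara's layout around this commit):

import aesara
import aesara.tensor as aet
from aesara.compile.mode import get_default_mode

x = aet.matrix("x")
rows = aet.ivector("rows")
cols = aet.ivector("cols")
# (2 * x) is a temporary, so the rewrite can safely let the Op overwrite it.
y = aet.inc_subtensor((2 * x)[rows, cols], 1.0)

f = aesara.function(
    [x, rows, cols], y, mode=get_default_mode().including("inplace")
)
assert f.maker.fgraph.outputs[0].owner.op.inplace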
@@ -2605,13 +2605,10 @@ class AdvancedIncSubtensor(Op):
     __props__ = ("inplace", "set_instead_of_inc")

     def __init__(self, inplace=False, set_instead_of_inc=False):
-        self.inplace = inplace
         self.set_instead_of_inc = set_instead_of_inc
-        # The assert is needed as in the pass the first argument was
-        # something else that was not used.
-        assert isinstance(inplace, bool)
-        if self.inplace:
-            raise NotImplementedError("In place computation is not" " implemented")
+        self.inplace = inplace
+        if inplace:
+            self.destroy_map = {0: [0]}

     def __str__(self):
         return "{}{{{}, {}}}".format(
@@ -2636,8 +2633,6 @@ class AdvancedIncSubtensor(Op):
     )

     def perform(self, node, inputs, out_):
-        # TODO: 1. opt to make this in place 2. generalize as described in
-        # AdvancedSubtensor's perform TODO
         check_advanced_indexing_dimensions(inputs[0], inputs[2:])
......
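The substantive change is that `inplace=True` now installs `destroy_map = {0: [0]}` instead of raising: the map tells Aesara's destroy handler that output 0 is computed by overwriting input 0, so the scheduler orders every other reader of that input before this node, which is what keeps the rewrite above safe. A toy Op sketching the same contract (my illustration, assuming the `aesara.graph` module layout of this period):

import aesara.tensor as aet
from aesara.graph.basic import Apply
from aesara.graph.op import Op


class AddOne(Op):
    __props__ = ("inplace",)

    def __init__(self, inplace=False):
        self.inplace = inplace
        if inplace:
            # Output 0 reuses/overwrites the storage of input 0.
            self.destroy_map = {0: [0]}

    def make_node(self, x):
        x = aet.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        (x,) = inputs
        if not self.inplace:
            x = x.copy()  # the non-destructive variant works on a copy
        x += 1
        output_storage[0][0] = x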
@@ -1673,6 +1673,24 @@ class TestSubtensorIncSubtensor:
     def setup_class(cls):
         cls.mode = get_default_mode().including("local_subtensor_inc_subtensor")

+    @pytest.mark.parametrize(
+        "val, indices, optype",
+        [
+            (vector(), (iscalar(),), IncSubtensor),
+            (vector(), (ivector(),), AdvancedIncSubtensor1),
+            (vector(), (ivector(), ivector()), AdvancedIncSubtensor),
+        ],
+    )
+    def test_inplace(self, val, indices, optype):
+        x = matrix("x")
+        y = set_subtensor((2 * x)[indices], val, inplace=False)
+        assert isinstance(y.owner.op, optype)
+        assert y.owner.op.inplace is False
+        f = function(
+            [x, val] + list(indices), y, mode=get_default_mode().including("inplace")
+        )
+        assert f.maker.fgraph.outputs[0].owner.op.inplace is True
+
     def test_basic(self):
         # basic test
         x = matrix("x")
......
@@ -1732,7 +1732,14 @@ class TestAdvancedSubtensor:
         self.ix2 = lmatrix()
         self.ixr = lrow()

-    def test_advinc_subtensor(self):
+    @pytest.mark.parametrize(
+        "inplace",
+        [
+            True,
+            False,
+        ],
+    )
+    def test_advinc_subtensor(self, inplace):
         x_shp = (20, 15, 10, 5)

         def check(idx, y_val, x_val, true):
@@ -1741,8 +1748,12 @@ class TestAdvancedSubtensor:
                 dtype="float32", broadcastable=(False,) * len(y_val.shape), name="y"
             )
             sym_idx = [aet.as_tensor_variable(ix) for ix in idx]
-            expr = advanced_inc_subtensor(x, y, *sym_idx)
-            f = aesara.function([y], expr, mode=self.mode)
+            expr = AdvancedIncSubtensor(inplace=inplace)(x, y, *sym_idx)
+            f = aesara.function(
+                [y], expr, mode=self.mode.excluding("inplace"), accept_inplace=inplace
+            )
+            fgraph = f.maker.fgraph
+            assert fgraph.outputs[0].owner.op.inplace == inplace
             rval = f(y_val)
             assert np.allclose(rval, true)
@@ -1759,11 +1770,14 @@ class TestAdvancedSubtensor:
             x_val = np.arange(np.prod(x_shp), dtype="float32").reshape(x_shp) + 1
             y_val = np.arange(np.prod(y_shp), dtype="float32").reshape(y_shp) + 1
             rep = x_val.copy()

             try:
                 rep[idx] += y_val
             except ValueError:
                 continue

             check(idx, y_val, x_val, rep)

         x_val = np.arange(np.prod(x_shp), dtype="float32").reshape(x_shp) + 1
         y_val = np.array(1).astype(np.float32)
         rep = x_val.copy()
......