提交 74fb5433 作者: Brandon T. Willard 提交者: Brandon T. Willard

Prevent unnecessary Scan inplace rewrites

上级 a3dc0a72
...@@ -4,7 +4,7 @@ import copy ...@@ -4,7 +4,7 @@ import copy
import dataclasses import dataclasses
from itertools import chain from itertools import chain
from sys import maxsize from sys import maxsize
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple, cast
import numpy as np import numpy as np
...@@ -928,32 +928,32 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -928,32 +928,32 @@ class ScanInplaceOptimizer(GlobalOptimizer):
""" """
def __init__(self, typeInfer=None): alloc_ops = (Alloc, AllocEmpty)
super().__init__() """
self.typeInfer = typeInfer Classes that represent operations that allocate new memory and that the
optimization should duplicate so it can operate inplace on them.
"""
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate()) fgraph.attach_feature(ReplaceValidate())
fgraph.attach_feature(DestroyHandler()) fgraph.attach_feature(DestroyHandler())
def attempt_scan_inplace(self, fgraph, node, output_indices, alloc_ops): def attempt_scan_inplace(
self, fgraph: FunctionGraph, node: Apply, output_indices: List[int]
) -> Optional[Apply]:
"""Attempt to replace a `Scan` node by one which computes the specified outputs inplace. """Attempt to replace a `Scan` node by one which computes the specified outputs inplace.
Parameters Parameters
---------- ----------
fgraph : FunctionGraph fgraph
Function graph in which to attempt the replacement Function graph in which to attempt the replacement
node : Apply node node
Scan node to replace by an inplace version Scan node to replace by an inplace version
output_indices : list of integers output_indices
Indices of the outputs to attempt to compute inplace Indices of the outputs to attempt to compute inplace
alloc_ops : list of Op classes
Classes that represent operation that allocate new memory and
that the optimization should duplicate so it can operate inplace
on them.
""" """
op = node.op op: Scan = cast(Scan, node.op)
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[: 1 + op.info.n_seqs] ls_begin = node.inputs[: 1 + op.info.n_seqs]
...@@ -964,14 +964,14 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -964,14 +964,14 @@ class ScanInplaceOptimizer(GlobalOptimizer):
ls_end += op.outer_nitsot(node.inputs) ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs) ls_end += op.outer_non_seqs(node.inputs)
# In `ls`, duplicate any input which has more then one client and is # In `ls`, duplicate any input which has more than one client and is
# the output of an eligible allocation op # the output of an eligible allocation op
for i in range(len(ls)): for i in range(len(ls)):
inp = ls[i] inp = ls[i]
if ( if (
len(fgraph.clients[inp]) > 1 len(fgraph.clients[inp]) > 1
and inp.owner and inp.owner
and isinstance(inp.owner.op, alloc_ops) and isinstance(inp.owner.op, self.alloc_ops)
): ):
new_lsi = inp.owner.op.make_node(*inp.owner.inputs) new_lsi = inp.owner.op.make_node(*inp.owner.inputs)
...@@ -991,23 +991,8 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -991,23 +991,8 @@ class ScanInplaceOptimizer(GlobalOptimizer):
ls[idx] = deep_copy_op(ls[idx]) ls[idx] = deep_copy_op(ls[idx])
inputs = ls_begin + ls + ls_end inputs = ls_begin + ls + ls_end
if self.typeInfer is None:
typeConstructor = None
else:
typeConstructor = self.typeInfer(node)
new_op = Scan( new_op = op.clone()
op.inner_inputs,
op.inner_outputs,
op.info,
mode=op.mode,
typeConstructor=typeConstructor,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
destroy_map = op.destroy_map.copy() destroy_map = op.destroy_map.copy()
for out_idx in output_indices: for out_idx in output_indices:
...@@ -1016,9 +1001,16 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1016,9 +1001,16 @@ class ScanInplaceOptimizer(GlobalOptimizer):
new_op.destroy_map = destroy_map new_op.destroy_map = destroy_map
# Do not call make_node for test_value # Do not call make_node for test_value
new_outs = new_op(*inputs, return_list=True) new_outs: List[Variable] = new_op(*inputs, return_list=True)
try: try:
fgraph.replace_all_validate_remove( # TODO FIXME: We need to stop using this approach (i.e. attempt
# in-place replacements and wait for downstream failures to revert
# the changes). It prevents us from making smart, clear
# rewrites and it adds a lot of unnecessary overhead that
# involves dealing with inconsistent graphs.
# This whole rewrite should be a simple local rewrite, but, because
# of this awful approach, it can't be.
fgraph.replace_all_validate_remove( # type: ignore
list(zip(node.outputs, new_outs)), list(zip(node.outputs, new_outs)),
remove=[node], remove=[node],
reason="scan_make_inplace", reason="scan_make_inplace",
...@@ -1026,20 +1018,19 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1026,20 +1018,19 @@ class ScanInplaceOptimizer(GlobalOptimizer):
return new_outs[0].owner return new_outs[0].owner
except InconsistencyError: except InconsistencyError:
# Failed moving output to be computed inplace # Failed moving output to be computed inplace
return node return None
def apply(self, fgraph): def apply(self, fgraph):
alloc_ops = (Alloc, AllocEmpty) for scan_idx, original_node in enumerate(reversed(fgraph.toposort())):
nodes = fgraph.toposort()[::-1]
scan_nodes = [x for x in nodes if (isinstance(x.op, Scan))] if not isinstance(original_node.op, Scan):
for scan_idx in range(len(scan_nodes)): continue
# First attempt to make the Scan compute inplace every recurrent # First attempt to make the Scan compute inplace every recurrent
# output that seems like it could be computed inplace. If that # output that seems like it could be computed inplace. If that
# fails, go through these outputs individually, trying each of # fails, go through these outputs individually, trying each of
# them. # them.
original_node = scan_nodes[scan_idx]
op = original_node.op op = original_node.op
n_outs = op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot n_outs = op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot
...@@ -1053,7 +1044,7 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1053,7 +1044,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# If the input is from an eligible allocation node, attempt to # If the input is from an eligible allocation node, attempt to
# be inplace on it, even if other nodes are modifying it # be inplace on it, even if other nodes are modifying it
# inplace. # inplace.
if inp.owner and isinstance(inp.owner.op, alloc_ops): if inp.owner and isinstance(inp.owner.op, self.alloc_ops):
out_indices.append(out_idx) out_indices.append(out_idx)
continue continue
...@@ -1079,16 +1070,21 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1079,16 +1070,21 @@ class ScanInplaceOptimizer(GlobalOptimizer):
if not input_used_inplace: if not input_used_inplace:
out_indices.append(out_idx) out_indices.append(out_idx)
node = self.attempt_scan_inplace( if len(out_indices) == 0:
fgraph, scan_nodes[scan_idx], out_indices, alloc_ops continue
)
new_node = self.attempt_scan_inplace(fgraph, original_node, out_indices)
if node is original_node: if new_node is None:
# Making the scan compute all plausible recurrent outputs # Making the scan compute all plausible recurrent outputs
# inplace has failed. Attempt all plausible recurrent output # inplace has failed. Attempt all plausible recurrent outputs
# individually. # individually.
new_node = original_node
for pos in out_indices: for pos in out_indices:
node = self.attempt_scan_inplace(fgraph, node, [pos], alloc_ops) new_node = (
self.attempt_scan_inplace(fgraph, new_node, [pos]) or new_node
)
def select_min(x, y): def select_min(x, y):
...@@ -2367,7 +2363,7 @@ optdb.register( ...@@ -2367,7 +2363,7 @@ optdb.register(
) )
optdb.register( optdb.register(
"scan_make_inplace", "scan_make_inplace",
ScanInplaceOptimizer(typeInfer=None), ScanInplaceOptimizer(),
"fast_run", "fast_run",
"inplace", "inplace",
"scan", "scan",
......
...@@ -9,9 +9,10 @@ from aesara.compile.io import In ...@@ -9,9 +9,10 @@ from aesara.compile.io import In
from aesara.compile.mode import get_default_mode from aesara.compile.mode import get_default_mode
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import grad, jacobian from aesara.gradient import grad, jacobian
from aesara.graph.basic import clone_replace from aesara.graph.basic import clone_replace, equal_computations
from aesara.graph.fg import FunctionGraph
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.scan.opt import ScanMerge from aesara.scan.opt import ScanInplaceOptimizer, ScanMerge
from aesara.scan.utils import until from aesara.scan.utils import until
from aesara.tensor.blas import Dot22 from aesara.tensor.blas import Dot22
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
...@@ -912,6 +913,49 @@ class TestScanMerge: ...@@ -912,6 +913,49 @@ class TestScanMerge:
class TestScanInplaceOptimizer: class TestScanInplaceOptimizer:
mode = get_default_mode().including("scan_make_inplace", "inplace") mode = get_default_mode().including("scan_make_inplace", "inplace")
def test_no_inplace(self):
"""Make sure the rewrite doesn't make unnecessary replacements."""
x = at.vector("x")
scan_out, _ = aesara.scan(
lambda x: (x + 1) / 2 + 1,
sequences=[x],
)
fgraph = FunctionGraph(
outputs=[scan_out], clone=True, copy_inputs=False, copy_orphans=False
)
_ = ScanInplaceOptimizer().apply(fgraph)
fgraph_op = fgraph.outputs[0].owner.inputs[0].owner.op
assert not fgraph_op.destroy_map
assert equal_computations([scan_out], fgraph.outputs)
def test_inplace_basic(self):
scan_out, _ = aesara.scan(
lambda x: x + 1,
outputs_info=[at.zeros(1)],
n_steps=3,
)
fgraph = FunctionGraph(
outputs=[scan_out], clone=True, copy_inputs=False, copy_orphans=False
)
assert equal_computations([scan_out], fgraph.outputs)
_ = ScanInplaceOptimizer().apply(fgraph)
# The graphs shouldn't change; only the `Op.destroy_map`s
assert equal_computations([scan_out], fgraph.outputs)
fgraph_op = fgraph.outputs[0].owner.inputs[0].owner.op
assert fgraph_op.destroy_map == {0: [1]}
assert not scan_out.owner.inputs[0].owner.op.destroy_map
@utt.assertFailure_fast @utt.assertFailure_fast
def test_simple_rnn(self): def test_simple_rnn(self):
"""Simple RNN; compute inplace version 1.""" """Simple RNN; compute inplace version 1."""
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论