提交 74fb5433 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent unnecessary Scan inplace rewrites

上级 a3dc0a72
......@@ -4,7 +4,7 @@ import copy
import dataclasses
from itertools import chain
from sys import maxsize
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, cast
import numpy as np
......@@ -928,32 +928,32 @@ class ScanInplaceOptimizer(GlobalOptimizer):
"""
def __init__(self, typeInfer=None):
super().__init__()
self.typeInfer = typeInfer
alloc_ops = (Alloc, AllocEmpty)
"""
Classes that represent operation that allocate new memory and that the
optimization should duplicate so it can operate inplace on them.
"""
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
fgraph.attach_feature(DestroyHandler())
def attempt_scan_inplace(self, fgraph, node, output_indices, alloc_ops):
def attempt_scan_inplace(
self, fgraph: FunctionGraph, node: Apply, output_indices: List[int]
) -> Optional[Apply]:
"""Attempt to replace a `Scan` node by one which computes the specified outputs inplace.
Parameters
----------
fgraph : FunctionGraph
fgraph
Function graph in which to attempt the replacement
node : Apply node
node
Scan node to replace by an inplace version
output_indices : list of integers
output_indices
Indices of the outputs to attempt to compute inplace
alloc_ops : list of Op classes
Classes that represent operation that allocate new memory and
that the optimization should duplicate so it can operate inplace
on them.
"""
op = node.op
op: Scan = cast(Scan, node.op)
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[: 1 + op.info.n_seqs]
......@@ -964,14 +964,14 @@ class ScanInplaceOptimizer(GlobalOptimizer):
ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs)
# In `ls`, duplicate any input which has more then one client and is
# In `ls`, duplicate any input which has more than one client and is
# the output of an eligible allocation op
for i in range(len(ls)):
inp = ls[i]
if (
len(fgraph.clients[inp]) > 1
and inp.owner
and isinstance(inp.owner.op, alloc_ops)
and isinstance(inp.owner.op, self.alloc_ops)
):
new_lsi = inp.owner.op.make_node(*inp.owner.inputs)
......@@ -991,23 +991,8 @@ class ScanInplaceOptimizer(GlobalOptimizer):
ls[idx] = deep_copy_op(ls[idx])
inputs = ls_begin + ls + ls_end
if self.typeInfer is None:
typeConstructor = None
else:
typeConstructor = self.typeInfer(node)
new_op = Scan(
op.inner_inputs,
op.inner_outputs,
op.info,
mode=op.mode,
typeConstructor=typeConstructor,
profile=op.profile,
truncate_gradient=op.truncate_gradient,
# TODO: This seems questionable
name=op.name,
allow_gc=op.allow_gc,
)
new_op = op.clone()
destroy_map = op.destroy_map.copy()
for out_idx in output_indices:
......@@ -1016,9 +1001,16 @@ class ScanInplaceOptimizer(GlobalOptimizer):
new_op.destroy_map = destroy_map
# Do not call make_node for test_value
new_outs = new_op(*inputs, return_list=True)
new_outs: List[Variable] = new_op(*inputs, return_list=True)
try:
fgraph.replace_all_validate_remove(
# TODO FIXME: We need to stop using this approach (i.e. attempt
# in-place replacements and wait for downstream failures to revert
# the changes). It prevents us from making smart, clear
# rewrites and it adds a lot of unnecessary overhead that
# involves dealing with inconsistent graphs.
# This whole rewrite should be a simple local rewrite, but, because
# of this awful approach, it can't be.
fgraph.replace_all_validate_remove( # type: ignore
list(zip(node.outputs, new_outs)),
remove=[node],
reason="scan_make_inplace",
......@@ -1026,20 +1018,19 @@ class ScanInplaceOptimizer(GlobalOptimizer):
return new_outs[0].owner
except InconsistencyError:
# Failed moving output to be computed inplace
return node
return None
def apply(self, fgraph):
alloc_ops = (Alloc, AllocEmpty)
nodes = fgraph.toposort()[::-1]
scan_nodes = [x for x in nodes if (isinstance(x.op, Scan))]
for scan_idx in range(len(scan_nodes)):
for scan_idx, original_node in enumerate(reversed(fgraph.toposort())):
if not isinstance(original_node.op, Scan):
continue
# First attempt to make the Scan compute inplace every recurrent
# output that seems like it could be computed inplace. If that
# fails, go through these outputs individually, trying each of
# them.
original_node = scan_nodes[scan_idx]
op = original_node.op
n_outs = op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot
......@@ -1053,7 +1044,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# If the input is from an eligible allocation node, attempt to
# be inplace on it, even if other nodes are modifying it
# inplace.
if inp.owner and isinstance(inp.owner.op, alloc_ops):
if inp.owner and isinstance(inp.owner.op, self.alloc_ops):
out_indices.append(out_idx)
continue
......@@ -1079,16 +1070,21 @@ class ScanInplaceOptimizer(GlobalOptimizer):
if not input_used_inplace:
out_indices.append(out_idx)
node = self.attempt_scan_inplace(
fgraph, scan_nodes[scan_idx], out_indices, alloc_ops
)
if len(out_indices) == 0:
continue
new_node = self.attempt_scan_inplace(fgraph, original_node, out_indices)
if node is original_node:
if new_node is None:
# Making the scan compute all plausible recurrent outputs
# inplace has failed. Attempt all plausible recurrent output
# inplace has failed. Attempt all plausible recurrent outputs
# individually.
new_node = original_node
for pos in out_indices:
node = self.attempt_scan_inplace(fgraph, node, [pos], alloc_ops)
new_node = (
self.attempt_scan_inplace(fgraph, new_node, [pos]) or new_node
)
def select_min(x, y):
......@@ -2367,7 +2363,7 @@ optdb.register(
)
optdb.register(
"scan_make_inplace",
ScanInplaceOptimizer(typeInfer=None),
ScanInplaceOptimizer(),
"fast_run",
"inplace",
"scan",
......
......@@ -9,9 +9,10 @@ from aesara.compile.io import In
from aesara.compile.mode import get_default_mode
from aesara.configdefaults import config
from aesara.gradient import grad, jacobian
from aesara.graph.basic import clone_replace
from aesara.graph.basic import clone_replace, equal_computations
from aesara.graph.fg import FunctionGraph
from aesara.scan.op import Scan
from aesara.scan.opt import ScanMerge
from aesara.scan.opt import ScanInplaceOptimizer, ScanMerge
from aesara.scan.utils import until
from aesara.tensor.blas import Dot22
from aesara.tensor.elemwise import Elemwise
......@@ -912,6 +913,49 @@ class TestScanMerge:
class TestScanInplaceOptimizer:
mode = get_default_mode().including("scan_make_inplace", "inplace")
def test_no_inplace(self):
"""Make sure the rewrite doesn't make unnecessary replacements."""
x = at.vector("x")
scan_out, _ = aesara.scan(
lambda x: (x + 1) / 2 + 1,
sequences=[x],
)
fgraph = FunctionGraph(
outputs=[scan_out], clone=True, copy_inputs=False, copy_orphans=False
)
_ = ScanInplaceOptimizer().apply(fgraph)
fgraph_op = fgraph.outputs[0].owner.inputs[0].owner.op
assert not fgraph_op.destroy_map
assert equal_computations([scan_out], fgraph.outputs)
def test_inplace_basic(self):
scan_out, _ = aesara.scan(
lambda x: x + 1,
outputs_info=[at.zeros(1)],
n_steps=3,
)
fgraph = FunctionGraph(
outputs=[scan_out], clone=True, copy_inputs=False, copy_orphans=False
)
assert equal_computations([scan_out], fgraph.outputs)
_ = ScanInplaceOptimizer().apply(fgraph)
# The graphs shouldn't change; only the `Op.destroy_map`s
assert equal_computations([scan_out], fgraph.outputs)
fgraph_op = fgraph.outputs[0].owner.inputs[0].owner.op
assert fgraph_op.destroy_map == {0: [1]}
assert not scan_out.owner.inputs[0].owner.op.destroy_map
@utt.assertFailure_fast
def test_simple_rnn(self):
"""Simple RNN; compute inplace version 1."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论