提交 730f790e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert Scan global optimizers to local optimizers

上级 5a3c0195
""" """This module provides optimizations for the `Scan` `Op`."""
This module provides optimizations for scan.
The Optimization provided in this file:
local opt: remove_constants_and_unused_inputs_scan,
constant_folding_for_scan2,
scan_merge_inouts
They are wrapped in in2out to create global opt.
global opt: ScanInplaceOptimizer,
PushOutNonSeqScan,
PushOutSeqScan,
PushOutDot1,
ScanMerge,
ScanSaveMem
How the are registered:
optdb: scan_eqopt1 (.1), scan_eqopt2(1.6), scan_inplace(75)
scan_eqopt1 -> scan_seqopt1
scan_seqopt1 -> in2out(remove_constants_and_unused_inputs_scan)(1),
PushOutNonSeqScan(2),
PushOutSeqScan(3), PushOutDot1(4)
scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
This is important, as the order is important and all global
optimizer run before local optimizer in the order they where
registered. (So don't change the order we register them!)
If we convert to local optimizer, we must convert all of them
to local optimizer. But:
1) can ScanMerge be made local? Can we keep only this one
global?
2) ScanSaveMem assert that we remove all nodes outputs,
we need to keep this.
3) It is ScanSaveMem suppose the the others ran before.
I added an assert at one place, but didn't looked for
other place.
4) Moving this to local opt could speed up significant this opt,
as we pass frequently on all nodes in the graph for no
good reason.
5) We register remove_constant_* many places, as some
opt create them and let this one clean up the mess.
Doing it that way, make things simpler for those already
complex opt.
in2out(constant_folding),
in2out(remove_constants_and_unused_inputs_scan1),
ScanMerge,
in2out(remove_constants_and_unused_inputs_scan2),
in2out(scan_merge_inouts),
ScanSaveMem,
in2out(remove_constants_and_unused_inputs_scan3)
"""
import copy import copy
import dataclasses import dataclasses
import logging import logging
from sys import maxsize from sys import maxsize
from typing import Dict, List, Tuple
import numpy as np import numpy as np
...@@ -65,6 +16,7 @@ from aesara.compile.function.types import deep_copy_op ...@@ -65,6 +16,7 @@ from aesara.compile.function.types import deep_copy_op
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import ( from aesara.graph.basic import (
Constant, Constant,
Node,
Variable, Variable,
clone_replace, clone_replace,
equal_computations, equal_computations,
...@@ -74,7 +26,7 @@ from aesara.graph.basic import ( ...@@ -74,7 +26,7 @@ from aesara.graph.basic import (
) )
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import InconsistencyError from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import compute_test_value from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB from aesara.graph.optdb import EquilibriumDB, SequenceDB
...@@ -236,38 +188,19 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -236,38 +188,19 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
return False return False
class PushOutNonSeqScan(GlobalOptimizer): @local_optimizer([Scan])
r"""Pushing out the variables inside the `Scan` that depend only on non-sequences. def push_out_non_seq_scan(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on non-sequences.
This optimizations pushes, out of `Scan`'s inner function and into the outer This optimizations pushes, out of `Scan`'s inner function and into the outer
function, computation that depends only on non-sequence inputs. Such function, computation that depends only on non-sequence inputs. Such
computation ends up being done every iteration on the same values so moving computation ends up being done every iteration on the same values so moving
it to the outer function to be executed only once, before the `Scan` `Op`, it to the outer function to be executed only once, before the `Scan` `Op`,
reduces the amount of computation that needs to be performed. reduces the amount of computation that needs to be performed.
TODO: This is a global opt for historical reasonons. It should be possible
to change it to a local opt.
"""
def __init__(self):
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
def apply(self, fgraph):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op, Scan)]
for node in nodelist:
self.process_node(fgraph, node)
def process_node(self, fgraph, node):
""" """
IMPORTANT NOTE: This function uses set and dictionary data structures. if not isinstance(node.op, Scan):
By default they are not ordered for efficiency reasons. Take care return False
and make sure of changing them with their Ordered counterparts if you
need to iterate over these variables.
"""
# this flag tells if there was any change during the last iterations # this flag tells if there was any change during the last iterations
clean_inputs, clean_outputs = reconstruct_graph(node.op.inputs, node.op.outputs) clean_inputs, clean_outputs = reconstruct_graph(node.op.inputs, node.op.outputs)
...@@ -337,16 +270,9 @@ class PushOutNonSeqScan(GlobalOptimizer): ...@@ -337,16 +270,9 @@ class PushOutNonSeqScan(GlobalOptimizer):
elif isinstance(x, Constant): elif isinstance(x, Constant):
outside_ins.append(x.clone()) outside_ins.append(x.clone())
else: else:
raise Exception( # TODO: Explain why is this an error, and raise an
( # appropriate exception type.
"Error in the `scan_pushout_non_seq_" raise RuntimeError()
"operations`. The optimization tries "
"to move some computation from scan "
"which is not allowed to move. Report "
"this on aesara-users list"
),
x,
)
outside_ins = [ outside_ins = [
x.type.filter_variable(y) for x, y in zip(nd.inputs, outside_ins) x.type.filter_variable(y) for x, y in zip(nd.inputs, outside_ins)
] ]
...@@ -425,12 +351,9 @@ class PushOutNonSeqScan(GlobalOptimizer): ...@@ -425,12 +351,9 @@ class PushOutNonSeqScan(GlobalOptimizer):
# Do not call make_node for test_value # Do not call make_node for test_value
nw_node = nwScan(*(node.inputs + nw_outer), return_list=True)[0].owner nw_node = nwScan(*(node.inputs + nw_outer), return_list=True)[0].owner
fgraph.replace_all_validate_remove( replacements = dict(zip(node.outputs, nw_node.outputs))
list(zip(node.outputs, nw_node.outputs)), replacements["remove"] = [node]
remove=[node], return replacements
reason="scanOp_pushout_nonseqs_ops",
)
return True
elif not to_keep_set: elif not to_keep_set:
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = {} replace_with = {}
...@@ -448,11 +371,8 @@ class PushOutNonSeqScan(GlobalOptimizer): ...@@ -448,11 +371,8 @@ class PushOutNonSeqScan(GlobalOptimizer):
if len(node.outputs) == len(replace_with): if len(node.outputs) == len(replace_with):
# Every output of the node has a replacement, the Scan # Every output of the node has a replacement, the Scan
# node can be removed from the graph # node can be removed from the graph
fgraph.replace_all_validate_remove( replace_with["remove"] = [node]
replace_with.items(), return replace_with
remove=[node],
reason="scanOp_pushout_nonseqs_ops",
)
else: else:
# The node has some outputs for which no replacement has # The node has some outputs for which no replacement has
# been established. This can occur for outputs that are # been established. This can occur for outputs that are
...@@ -461,15 +381,14 @@ class PushOutNonSeqScan(GlobalOptimizer): ...@@ -461,15 +381,14 @@ class PushOutNonSeqScan(GlobalOptimizer):
# passed directly as outputs. The replacements can be # passed directly as outputs. The replacements can be
# performed but the Scan node can't be removed at this # performed but the Scan node can't be removed at this
# point. # point.
fgraph.replace_all_validate( return replace_with
replace_with.items(), reason="scanOp_pushout_nonseqs_ops"
)
else: else:
return False return False
class PushOutSeqScan(GlobalOptimizer): @local_optimizer([Scan])
def push_out_seq_scan(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on constants and sequences. r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
This optimization resembles `PushOutNonSeqScan` but it tries to push, out of This optimization resembles `PushOutNonSeqScan` but it tries to push, out of
...@@ -479,30 +398,10 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -479,30 +398,10 @@ class PushOutSeqScan(GlobalOptimizer):
a single operation on a large tensor rather then perform that same operation a single operation on a large tensor rather then perform that same operation
many times on many smaller tensors. In many cases, this optimization can many times on many smaller tensors. In many cases, this optimization can
increase memory usage but, in some specific cases, it can also decrease it. increase memory usage but, in some specific cases, it can also decrease it.
TODO: This is a global opt for historical reasonons. It should be possible
to change it to a local opt.
""" """
if not isinstance(node.op, Scan):
return False
def __init__(self):
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
def apply(self, fgraph):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op, Scan)]
for node in nodelist:
self.process_node(fgraph, node)
def process_node(self, fgraph, node):
"""
IMPORTANT NOTE: This function uses set and dictionary data structure.
By default they are not ordered for efficiency reasons. Take care
and make sure of changing them to Ordered versions if you need to
iterate over those variables.
"""
# this flag tells if there was any change during the last iterations # this flag tells if there was any change during the last iterations
clean_inputs, clean_outputs = reconstruct_graph(node.op.inputs, node.op.outputs) clean_inputs, clean_outputs = reconstruct_graph(node.op.inputs, node.op.outputs)
...@@ -607,10 +506,7 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -607,10 +506,7 @@ class PushOutSeqScan(GlobalOptimizer):
elif ( elif (
nd not in to_remove_set nd not in to_remove_set
and isinstance(nd.op, DimShuffle) and isinstance(nd.op, DimShuffle)
and ( and (nd.inputs[0] in inner_seqs_set or nd.inputs[0].owner in to_remove_set)
nd.inputs[0] in inner_seqs_set
or nd.inputs[0].owner in to_remove_set
)
): ):
to_remove_set.add(nd) to_remove_set.add(nd)
...@@ -687,9 +583,7 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -687,9 +583,7 @@ class PushOutSeqScan(GlobalOptimizer):
op_ins = nw_inner + clean_inputs op_ins = nw_inner + clean_inputs
# Reconstruct node # Reconstruct node
nw_info = dataclasses.replace( nw_info = dataclasses.replace(op.info, n_seqs=op.info.n_seqs + len(nw_inner))
op.info, n_seqs=op.info.n_seqs + len(nw_inner)
)
nwScan = Scan( nwScan = Scan(
op_ins, op_ins,
op_outs, op_outs,
...@@ -709,12 +603,10 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -709,12 +603,10 @@ class PushOutSeqScan(GlobalOptimizer):
return_list=True, return_list=True,
)[0].owner )[0].owner
fgraph.replace_all_validate_remove( replacements = dict(zip(node.outputs, nw_node.outputs))
list(zip(node.outputs, nw_node.outputs)), replacements["remove"] = [node]
remove=[node], return replacements
reason="scanOp_pushout_seqs_ops",
)
return True
elif not to_keep_set and not op.as_while and not op.outer_mitmot(node.inputs): elif not to_keep_set and not op.as_while and not op.outer_mitmot(node.inputs):
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = {} replace_with = {}
...@@ -740,154 +632,21 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -740,154 +632,21 @@ class PushOutSeqScan(GlobalOptimizer):
# We need to add one extra dimension to the outputs # We need to add one extra dimension to the outputs
if replace_with and len(replace_with) == len(node.outputs): if replace_with and len(replace_with) == len(node.outputs):
fgraph.replace_all_validate_remove( replacements = dict(replace_with.items())
list(replace_with.items()), replacements["remove"] = [node]
remove=[node], return replacements
reason="scanOp_pushout_seqs_ops",
)
return True
else: else:
return False return False
class PushOutScanOutput(GlobalOptimizer): def inner_sitsot_only_last_step_used(
r"""Push operations performed at the end of the inner graph of `Scan` to outside of `Scan`. fgraph: FunctionGraph, var: Variable, scan_args: ScanArgs
) -> bool:
This optimizations attempts to push out some of the computation at the end
of the inner function to the outer function, to be executed after the `Scan`
node. Like `PushOutSeqScan`, this optimization aims to replace many operations
on small tensors by few operations on large tensors. It can also lead to
increased memory usage.
"""
def __init__(self):
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
def apply(self, fgraph):
# Don't perform the optimization on as_while scans. Because these scans
# don't run for a predetermined number of steps, handling them is
# more complicated and this optimization doesn't support it at the
# moment.
nodelist = [
x
for x in fgraph.toposort()
if (isinstance(x.op, Scan) and not x.op.as_while)
]
for node in nodelist:
# Process the node as long as something gets optimized
while node is not None:
node = self.process_node(fgraph, node)
def process_node(self, fgraph, node):
op = node.op
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = ScanArgs(
node.inputs, node.outputs, op.inputs, op.outputs, op.info, op.as_while
)
new_scan_node = None
clients = {}
local_fgraph_topo = io_toposort(
args.inner_inputs, args.inner_outputs, clients=clients
)
for nd in local_fgraph_topo:
if (
isinstance(nd.op, Elemwise)
and isinstance(nd.op.scalar_op, aes.Add)
and nd.out in args.inner_out_sit_sot
and self.inner_sitsot_only_last_step_used(fgraph, nd.out, args)
):
# Ensure that one of the input to the add is the output of
# the add from a previous iteration of the inner function
sitsot_idx = args.inner_out_sit_sot.index(nd.out)
if args.inner_in_sit_sot[sitsot_idx] in nd.inputs:
# Ensure that the other input to the add is a dot product
# between 2 matrices which will become a tensor3 and a
# matrix if pushed outside of the scan. Also make sure
# that the output of the Dot is ONLY used by the 'add'
# otherwise doing a Dot in the outer graph will only
# duplicate computation.
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx])
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx]
if (
dot_input.owner is not None
and isinstance(dot_input.owner.op, Dot)
and len(clients[dot_input]) == 1
and dot_input.owner.inputs[0].ndim == 2
and dot_input.owner.inputs[1].ndim == 2
and self.get_outer_ndim(dot_input.owner.inputs[0], args) == 3
and self.get_outer_ndim(dot_input.owner.inputs[1], args) == 3
):
# The optimization can be be applied in this case.
# Move out of scan the two inputs to the Dot and
# perform a dot outside of scan on these two inputs
inner_dot_inputs = nd.inputs[dot_in_idx].owner.inputs
(
outer_dot_inputs,
new_scan_node,
new_scan_args,
) = self.push_out_inner_vars(
fgraph, inner_dot_inputs, node, args
)
# Collapse some of the dimensions of the tensors
# so that they become matrices. This is because a
# dot is usually faster on two large matrices than
# a bunch of small ones
outer_dot_inputs[0] = aet.flatten(
outer_dot_inputs[0].dimshuffle(1, 0, 2), ndim=2
)
shape_input1 = shape(outer_dot_inputs[1])
outer_dot_inputs[1] = outer_dot_inputs[1].reshape(
(shape_input1[0] * shape_input1[1], shape_input1[2])
)
# Perform the dot on the newly obtained matrices and
# add the initial value
outer_dot_output = dot(*outer_dot_inputs)
init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0]
replacement = outer_dot_output + init_value
# Alter the outer graph to use the output of the
# external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot
outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
subtensor_node = fgraph.clients[outer_sitsot][0][0]
outer_sitsot_last_step = subtensor_node.outputs[0]
fgraph.replace_all(
[(outer_sitsot_last_step, replacement)],
reason="scanOp_pushout_output",
)
break
return new_scan_node
def inner_sitsot_only_last_step_used(self, fgraph, var, scan_args):
""" """
Given a inner nit_sot output of scan, return True iff the outer Given a inner nit-sot output of `Scan`, return ``True`` iff the outer
nit_sot output has only one client and that client is a Subtensor nit-sot output has only one client and that client is a `Subtensor`
instance that takes only the last step (last element along the first instance that takes only the last step (last element along the first
axis). axis).
""" """
idx = scan_args.inner_out_sit_sot.index(var) idx = scan_args.inner_out_sit_sot.index(var)
outer_var = scan_args.outer_out_sit_sot[idx] outer_var = scan_args.outer_out_sit_sot[idx]
...@@ -901,23 +660,28 @@ class PushOutScanOutput(GlobalOptimizer): ...@@ -901,23 +660,28 @@ class PushOutScanOutput(GlobalOptimizer):
return False return False
def get_outer_ndim(self, var, scan_args):
# Given a variable, determine the number of dimension it would have if def get_outer_ndim(var: Variable, scan_args: ScanArgs) -> int:
# it was pushed out of scan """Determine the number of dimension a variable would have if it was pushed out of a `Scan`."""
if var in scan_args.inner_in_non_seqs or isinstance(var, Constant): if var in scan_args.inner_in_non_seqs or isinstance(var, Constant):
outer_ndim = var.ndim outer_ndim = var.ndim
else: else:
outer_ndim = var.ndim + 1 outer_ndim = var.ndim + 1
return outer_ndim return outer_ndim
def push_out_inner_vars(self, fgraph, inner_vars, old_scan_node, old_scan_args):
def push_out_inner_vars(
fgraph: FunctionGraph,
inner_vars: List[Variable],
old_scan_node: Node,
old_scan_args: ScanArgs,
) -> Tuple[List[Variable], ScanArgs, Dict[Variable, Variable]]:
outer_vars = [None] * len(inner_vars) outer_vars = [None] * len(inner_vars)
new_scan_node = old_scan_node new_scan_node = old_scan_node
new_scan_args = old_scan_args new_scan_args = old_scan_args
replacements = {}
# For the inner_vars that already exist in the outer graph, # For the inner_vars that already exist in the outer graph,
# simply obtain a reference to them # simply obtain a reference to them
...@@ -942,14 +706,12 @@ class PushOutScanOutput(GlobalOptimizer): ...@@ -942,14 +706,12 @@ class PushOutScanOutput(GlobalOptimizer):
# For the inner_vars that don't already exist in the outer graph, add # For the inner_vars that don't already exist in the outer graph, add
# them as new nitsot outputs to the scan node. # them as new nitsot outputs to the scan node.
idx_add_as_nitsots = [ idx_add_as_nitsots = [i for i in range(len(outer_vars)) if outer_vars[i] is None]
i for i in range(len(outer_vars)) if outer_vars[i] is None
]
add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots] add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots]
if len(add_as_nitsots) > 0: if len(add_as_nitsots) > 0:
new_scan_node = self.add_nitsot_outputs( new_scan_node, replacements = add_nitsot_outputs(
fgraph, old_scan_node, old_scan_args, add_as_nitsots fgraph, old_scan_node, old_scan_args, add_as_nitsots
) )
...@@ -966,20 +728,22 @@ class PushOutScanOutput(GlobalOptimizer): ...@@ -966,20 +728,22 @@ class PushOutScanOutput(GlobalOptimizer):
for i in range(len(new_outs)): for i in range(len(new_outs)):
outer_vars[idx_add_as_nitsots[i]] = new_outs[i] outer_vars[idx_add_as_nitsots[i]] = new_outs[i]
return outer_vars, new_scan_node, new_scan_args return outer_vars, new_scan_args, replacements
def add_nitsot_outputs(
self, fgraph, old_scan_node, old_scan_args, new_outputs_inner def add_nitsot_outputs(
): fgraph: FunctionGraph,
old_scan_node: Node,
old_scan_args: ScanArgs,
new_outputs_inner,
) -> Tuple[Node, Dict[Variable, Variable]]:
nb_new_outs = len(new_outputs_inner) nb_new_outs = len(new_outputs_inner)
# Create the initial values for the new nitsot outputs # Create the initial values for the new nitsot outputs
# (the initial value is the nb of steps to store. For a nistot, # (the initial value is the nb of steps to store. For a nistot,
# it should be the number of steps performed by scan) # it should be the number of steps performed by scan)
new_nitsots_initial_value = [ new_nitsots_initial_value = [old_scan_node.inputs[0] for i in range(nb_new_outs)]
old_scan_node.inputs[0] for i in range(nb_new_outs)
]
# Create the `ScanArgs` corresponding to the new `Scan` `Op` to create # Create the `ScanArgs` corresponding to the new `Scan` `Op` to create
new_scan_args = copy.copy(old_scan_args) new_scan_args = copy.copy(old_scan_args)
...@@ -1002,9 +766,7 @@ class PushOutScanOutput(GlobalOptimizer): ...@@ -1002,9 +766,7 @@ class PushOutScanOutput(GlobalOptimizer):
) )
# Create the Apply node for the scan op # Create the Apply node for the scan op
new_scan_node = new_scan_op(*new_scan_args.outer_inputs, return_list=True)[ new_scan_node = new_scan_op(*new_scan_args.outer_inputs, return_list=True)[0].owner
0
].owner
# Modify the outer graph to make sure the outputs of the new scan are # Modify the outer graph to make sure the outputs of the new scan are
# used instead of the outputs of the old scan # used instead of the outputs of the old scan
...@@ -1017,13 +779,123 @@ class PushOutScanOutput(GlobalOptimizer): ...@@ -1017,13 +779,123 @@ class PushOutScanOutput(GlobalOptimizer):
+ new_scan_node.outputs[new_node_new_outputs_idx + nb_new_outs :] + new_scan_node.outputs[new_node_new_outputs_idx + nb_new_outs :]
) )
# TODO FIXME:
# replacements = dict(zip(old_scan_node.outputs, new_node_old_outputs))
# replacements["remove"] = [old_scan_node]
# return new_scan_node, replacements
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
list(zip(old_scan_node.outputs, new_node_old_outputs)), list(zip(old_scan_node.outputs, new_node_old_outputs)),
remove=[old_scan_node], remove=[old_scan_node],
reason="scanOp_pushout_output", reason="scan_pushout_output",
) )
return new_scan_node, {}
@local_optimizer([Scan])
def push_out_add_scan(fgraph, node):
r"""Push `Add` operations performed at the end of the inner graph to the outside.
return new_scan_node Like `push_out_seq_scan`, this optimization aims to replace many operations
on small tensors by few operations on large tensors. It can also lead to
increased memory usage.
"""
# Don't perform the optimization on `as_while` `Scan`s. Because these
# `Scan`s don't run for a predetermined number of steps, handling them is
# more complicated and this optimization doesn't support it at the moment.
if not (isinstance(node.op, Scan) and not node.op.as_while):
return False
op = node.op
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = ScanArgs(
node.inputs, node.outputs, op.inputs, op.outputs, op.info, op.as_while
)
clients = {}
local_fgraph_topo = io_toposort(
args.inner_inputs, args.inner_outputs, clients=clients
)
for nd in local_fgraph_topo:
if (
isinstance(nd.op, Elemwise)
and isinstance(nd.op.scalar_op, aes.Add)
and nd.out in args.inner_out_sit_sot
and inner_sitsot_only_last_step_used(fgraph, nd.out, args)
):
# Ensure that one of the input to the add is the output of
# the add from a previous iteration of the inner function
sitsot_idx = args.inner_out_sit_sot.index(nd.out)
if args.inner_in_sit_sot[sitsot_idx] in nd.inputs:
# Ensure that the other input to the add is a dot product
# between 2 matrices which will become a tensor3 and a
# matrix if pushed outside of the scan. Also make sure
# that the output of the Dot is ONLY used by the 'add'
# otherwise doing a Dot in the outer graph will only
# duplicate computation.
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx])
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx]
if (
dot_input.owner is not None
and isinstance(dot_input.owner.op, Dot)
and len(clients[dot_input]) == 1
and dot_input.owner.inputs[0].ndim == 2
and dot_input.owner.inputs[1].ndim == 2
and get_outer_ndim(dot_input.owner.inputs[0], args) == 3
and get_outer_ndim(dot_input.owner.inputs[1], args) == 3
):
# The optimization can be be applied in this case.
# Move out of scan the two inputs to the Dot and
# perform a dot outside of scan on these two inputs
inner_dot_inputs = nd.inputs[dot_in_idx].owner.inputs
(
outer_dot_inputs,
new_scan_args,
replacements,
) = push_out_inner_vars(fgraph, inner_dot_inputs, node, args)
# Collapse some of the dimensions of the tensors
# so that they become matrices. This is because a
# dot is usually faster on two large matrices than
# a bunch of small ones
outer_dot_inputs[0] = aet.flatten(
outer_dot_inputs[0].dimshuffle(1, 0, 2), ndim=2
)
shape_input1 = shape(outer_dot_inputs[1])
outer_dot_inputs[1] = outer_dot_inputs[1].reshape(
(shape_input1[0] * shape_input1[1], shape_input1[2])
)
# Perform the dot on the newly obtained matrices and
# add the initial value
outer_dot_output = dot(*outer_dot_inputs)
init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0]
replacement = outer_dot_output + init_value
# Alter the outer graph to use the output of the
# external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot
outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
subtensor_node = fgraph.clients[outer_sitsot][0][0]
outer_sitsot_last_step = subtensor_node.outputs[0]
replacements[outer_sitsot_last_step] = replacement
return replacements
return False
class ScanInplaceOptimizer(GlobalOptimizer): class ScanInplaceOptimizer(GlobalOptimizer):
...@@ -1203,7 +1075,31 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1203,7 +1075,31 @@ class ScanInplaceOptimizer(GlobalOptimizer):
node = self.attempt_scan_inplace(fgraph, node, [pos], alloc_ops) node = self.attempt_scan_inplace(fgraph, node, [pos], alloc_ops)
class ScanSaveMem(GlobalOptimizer): def select_min(x, y):
if x is None:
return y
if y is None:
return x
return minimum(x, y)
def select_max(x, y):
if x is None:
return y
if y is None:
return x
return maximum(x, y)
def sanitize(x):
if x is None:
return None
else:
return aet.as_tensor_variable(x)
@local_optimizer([Scan])
def save_mem_new_scan(fgraph, node):
r"""Graph optimizer that reduces scan memory consumption. r"""Graph optimizer that reduces scan memory consumption.
This optimizations attempts to determine if a `Scan` node, during its execution, This optimizations attempts to determine if a `Scan` node, during its execution,
...@@ -1224,35 +1120,8 @@ class ScanSaveMem(GlobalOptimizer): ...@@ -1224,35 +1120,8 @@ class ScanSaveMem(GlobalOptimizer):
be kept in memory. be kept in memory.
""" """
if not isinstance(node.op, Scan):
def __init__(self): return False
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
def process_node(self, fgraph, node):
# helpful functions
def select_min(x, y):
if x is None:
return y
if y is None:
return x
return minimum(x, y)
def select_max(x, y):
if x is None:
return y
if y is None:
return x
return maximum(x, y)
def sanitize(x):
if x is None:
return None
else:
return aet.as_tensor_variable(x)
if hasattr(fgraph, "shape_feature"): if hasattr(fgraph, "shape_feature"):
shape_of = fgraph.shape_feature.shape_of shape_of = fgraph.shape_feature.shape_of
...@@ -1487,17 +1356,12 @@ class ScanSaveMem(GlobalOptimizer): ...@@ -1487,17 +1356,12 @@ class ScanSaveMem(GlobalOptimizer):
first_mitsot_idx = node.op.n_mit_mot first_mitsot_idx = node.op.n_mit_mot
last_sitsot_idx = ( last_sitsot_idx = (
node.op.n_mit_mot node.op.n_mit_mot + node.op.n_mit_sot + node.op.n_sit_sot - 1
+ node.op.n_mit_sot
+ node.op.n_sit_sot
- 1
) )
preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx
if prealloc_outs and preallocable_output: if prealloc_outs and preallocable_output:
pval = select_max( pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1)
nw_steps - start + init_l[i], init_l[i] + 1
)
else: else:
pval = select_max(nw_steps - start + init_l[i], init_l[i]) pval = select_max(nw_steps - start + init_l[i], init_l[i])
...@@ -1544,9 +1408,7 @@ class ScanSaveMem(GlobalOptimizer): ...@@ -1544,9 +1408,7 @@ class ScanSaveMem(GlobalOptimizer):
# TODO: commit change below with Razvan # TODO: commit change below with Razvan
if ( if (
nw_inputs[offset + idx].owner nw_inputs[offset + idx].owner
and isinstance( and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
nw_inputs[offset + idx].owner.op, IncSubtensor
)
and isinstance( and isinstance(
nw_inputs[offset + idx].owner.op.idx_list[0], slice nw_inputs[offset + idx].owner.op.idx_list[0], slice
) )
...@@ -1558,9 +1420,7 @@ class ScanSaveMem(GlobalOptimizer): ...@@ -1558,9 +1420,7 @@ class ScanSaveMem(GlobalOptimizer):
_nw_input = nw_inputs[offset + idx].owner.inputs[1] _nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = aet.as_tensor_variable(val) cval = aet.as_tensor_variable(val)
initl = aet.as_tensor_variable(init_l[i]) initl = aet.as_tensor_variable(init_l[i])
tmp_idx = aet.switch( tmp_idx = aet.switch(cval < initl, cval + initl, cval - initl)
cval < initl, cval + initl, cval - initl
)
nw_input = expand_empty(_nw_input, tmp_idx) nw_input = expand_empty(_nw_input, tmp_idx)
else: else:
tmp = aet.as_tensor_variable(val) tmp = aet.as_tensor_variable(val)
...@@ -1645,7 +1505,7 @@ class ScanSaveMem(GlobalOptimizer): ...@@ -1645,7 +1505,7 @@ class ScanSaveMem(GlobalOptimizer):
# TODO: currently we don't support scan with 0 step. So # TODO: currently we don't support scan with 0 step. So
# don't create one. # don't create one.
if aet.extract_constant(node_ins[0]) == 0: if aet.extract_constant(node_ins[0]) == 0:
return return False
# Do not call make_node for test_value # Do not call make_node for test_value
new_op = Scan( new_op = Scan(
...@@ -1758,19 +1618,18 @@ class ScanSaveMem(GlobalOptimizer): ...@@ -1758,19 +1618,18 @@ class ScanSaveMem(GlobalOptimizer):
] ]
if any(old_scan_is_used): if any(old_scan_is_used):
return False return False
remove = [old.owner for (old, new) in old_new]
replacements = dict(old_new)
# remove = [old.owner for (old, new) in old_new]
# As Fred suggested assert that also the old node is not in # As Fred suggested assert that also the old node is not in
# the Graph as that will make things suboptimal # the Graph as that will make things suboptimal
remove.append(node) # remove.append(node)
fgraph.replace_all_validate_remove( replacements["remove"] = [node]
old_new, remove, reason="scanOp_save_mem"
)
def apply(self, fgraph): return replacements
nodelist = [x for x in fgraph.toposort() if isinstance(x.op, Scan)] return False
for node in nodelist:
self.process_node(fgraph, node)
class ScanMerge(GlobalOptimizer): class ScanMerge(GlobalOptimizer):
...@@ -2271,27 +2130,16 @@ def scan_merge_inouts(fgraph, node): ...@@ -2271,27 +2130,16 @@ def scan_merge_inouts(fgraph, node):
return na.outer_outputs return na.outer_outputs
class PushOutDot1(GlobalOptimizer): @local_optimizer([Scan])
def push_out_dot1_scan(fgraph, node):
r""" r"""
This is another optimization that attempts to detect certain patterns of This is another optimization that attempts to detect certain patterns of
computation in a `Scan` `Op`'s inner function and move this computation to the computation in a `Scan` `Op`'s inner function and move this computation to the
outer graph. outer graph.
""" """
if not isinstance(node.op, Scan):
return False
def __init__(self):
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
def apply(self, fgraph):
nodes = fgraph.toposort()
scan_nodes = [x for x in nodes if (isinstance(x.op, Scan))]
for node in scan_nodes:
self.apply_opt(fgraph, node)
def apply_opt(self, fgraph, node):
# Replace pattern of the form # Replace pattern of the form
# x[t] = x[t-1] + dot(seq[t], value) # x[t] = x[t-1] + dot(seq[t], value)
# with Sequence.reshape((-1, seq.shape[2])) \dot Value # with Sequence.reshape((-1, seq.shape[2])) \dot Value
...@@ -2470,9 +2318,11 @@ class PushOutDot1(GlobalOptimizer): ...@@ -2470,9 +2318,11 @@ class PushOutDot1(GlobalOptimizer):
old = fgraph.clients[node.outputs[pos]][0][0].outputs[0] old = fgraph.clients[node.outputs[pos]][0][0].outputs[0]
old_new.append((old, new_out)) old_new.append((old, new_out))
old_new += list(zip(node.outputs[pos + 1 :], new_outs[pos:])) old_new += list(zip(node.outputs[pos + 1 :], new_outs[pos:]))
fgraph.replace_all_validate_remove( replacements = dict(old_new)
old_new, remove=[node], reason="scan_pushout_dot1" replacements["remove"] = [node]
) return replacements
return False
# I've added an equilibrium because later scan optimization in the sequence # I've added an equilibrium because later scan optimization in the sequence
...@@ -2490,7 +2340,13 @@ optdb.register("scan_eqopt1", scan_eqopt1, 0.05, "fast_run", "scan") ...@@ -2490,7 +2340,13 @@ optdb.register("scan_eqopt1", scan_eqopt1, 0.05, "fast_run", "scan")
# but after stabilize at 1.5. Should we put it before stabilize? # but after stabilize at 1.5. Should we put it before stabilize?
optdb.register("scan_eqopt2", scan_eqopt2, 1.6, "fast_run", "scan") optdb.register("scan_eqopt2", scan_eqopt2, 1.6, "fast_run", "scan")
# ScanSaveMem should execute only once per node. # ScanSaveMem should execute only once per node.
optdb.register("scanOp_save_mem", ScanSaveMem(), 1.61, "fast_run", "scan") optdb.register(
"scanOp_save_mem",
in2out(save_mem_new_scan, ignore_newtrees=True),
1.61,
"fast_run",
"scan",
)
optdb.register( optdb.register(
"scanOp_make_inplace", "scanOp_make_inplace",
ScanInplaceOptimizer(typeInfer=None), ScanInplaceOptimizer(typeInfer=None),
...@@ -2514,22 +2370,41 @@ scan_seqopt1.register( ...@@ -2514,22 +2370,41 @@ scan_seqopt1.register(
scan_seqopt1.register( scan_seqopt1.register(
"scanOp_pushout_nonseqs_ops", PushOutNonSeqScan(), 2, "fast_run", "scan" "scanOp_pushout_nonseqs_ops",
in2out(push_out_non_seq_scan, ignore_newtrees=True),
2,
"fast_run",
"scan",
) )
scan_seqopt1.register( scan_seqopt1.register(
"scanOp_pushout_seqs_ops", PushOutSeqScan(), 3, "fast_run", "scan" "scanOp_pushout_seqs_ops",
in2out(push_out_seq_scan, ignore_newtrees=True),
3,
"fast_run",
"scan",
) )
scan_seqopt1.register( scan_seqopt1.register(
"scan_pushout_dot1", PushOutDot1(), 4, "fast_run", "more_mem", "scan" "scan_pushout_dot1",
in2out(push_out_dot1_scan, ignore_newtrees=True),
4,
"fast_run",
"more_mem",
"scan",
) )
scan_seqopt1.register( scan_seqopt1.register(
"scanOp_pushout_output", PushOutScanOutput(), 5, "fast_run", "more_mem", "scan" "scanOp_pushout_output",
# TODO: Perhaps this should be an `EquilibriumOptimizer`?
in2out(push_out_add_scan, ignore_newtrees=False),
5,
"fast_run",
"more_mem",
"scan",
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论