提交 730f790e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert Scan global optimizers to local optimizers

上级 5a3c0195
"""
This module provides optimizations for scan.
The Optimization provided in this file:
local opt: remove_constants_and_unused_inputs_scan,
constant_folding_for_scan2,
scan_merge_inouts
They are wrapped in in2out to create global opt.
global opt: ScanInplaceOptimizer,
PushOutNonSeqScan,
PushOutSeqScan,
PushOutDot1,
ScanMerge,
ScanSaveMem
How they are registered:
optdb: scan_eqopt1 (.1), scan_eqopt2(1.6), scan_inplace(75)
scan_eqopt1 -> scan_seqopt1
scan_seqopt1 -> in2out(remove_constants_and_unused_inputs_scan)(1),
PushOutNonSeqScan(2),
PushOutSeqScan(3), PushOutDot1(4)
scan_eqopt2 -> They are all global optimizer. (in2out convert local to global).
This is important, as the order is important and all global
optimizers run before local optimizers in the order they were
registered. (So don't change the order we register them!)
If we convert to local optimizer, we must convert all of them
to local optimizer. But:
1) can ScanMerge be made local? Can we keep only this one
global?
2) ScanSaveMem assert that we remove all nodes outputs,
we need to keep this.
3) ScanSaveMem supposes that the others ran before it.
I added an assert at one place, but didn't look for
other place.
4) Moving this to a local opt could significantly speed up this opt,
as we pass frequently on all nodes in the graph for no
good reason.
5) We register remove_constant_* many places, as some
opt create them and let this one clean up the mess.
Doing it that way, make things simpler for those already
complex opt.
in2out(constant_folding),
in2out(remove_constants_and_unused_inputs_scan1),
ScanMerge,
in2out(remove_constants_and_unused_inputs_scan2),
in2out(scan_merge_inouts),
ScanSaveMem,
in2out(remove_constants_and_unused_inputs_scan3)
"""
"""This module provides optimizations for the `Scan` `Op`."""
import copy
import dataclasses
import logging
from sys import maxsize
from typing import Dict, List, Tuple
import numpy as np
......@@ -65,6 +16,7 @@ from aesara.compile.function.types import deep_copy_op
from aesara.configdefaults import config
from aesara.graph.basic import (
Constant,
Node,
Variable,
clone_replace,
equal_computations,
......@@ -74,7 +26,7 @@ from aesara.graph.basic import (
)
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import ReplaceValidate
from aesara.graph.fg import InconsistencyError
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB
......@@ -236,38 +188,19 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
return False
class PushOutNonSeqScan(GlobalOptimizer):
r"""Pushing out the variables inside the `Scan` that depend only on non-sequences.
@local_optimizer([Scan])
def push_out_non_seq_scan(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on non-sequences.
This optimizations pushes, out of `Scan`'s inner function and into the outer
function, computation that depends only on non-sequence inputs. Such
computation ends up being done every iteration on the same values so moving
it to the outer function to be executed only once, before the `Scan` `Op`,
reduces the amount of computation that needs to be performed.
TODO: This is a global opt for historical reasons. It should be possible
to change it to a local opt.
"""
def __init__(self):
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
def apply(self, fgraph):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op, Scan)]
for node in nodelist:
self.process_node(fgraph, node)
def process_node(self, fgraph, node):
"""
IMPORTANT NOTE: This function uses set and dictionary data structures.
By default they are not ordered for efficiency reasons. Take care
and make sure of changing them with their Ordered counterparts if you
need to iterate over these variables.
if not isinstance(node.op, Scan):
return False
"""
# this flag tells if there was any change during the last iterations
clean_inputs, clean_outputs = reconstruct_graph(node.op.inputs, node.op.outputs)
......@@ -337,16 +270,9 @@ class PushOutNonSeqScan(GlobalOptimizer):
elif isinstance(x, Constant):
outside_ins.append(x.clone())
else:
raise Exception(
(
"Error in the `scan_pushout_non_seq_"
"operations`. The optimization tries "
"to move some computation from scan "
"which is not allowed to move. Report "
"this on aesara-users list"
),
x,
)
# TODO: Explain why is this an error, and raise an
# appropriate exception type.
raise RuntimeError()
outside_ins = [
x.type.filter_variable(y) for x, y in zip(nd.inputs, outside_ins)
]
......@@ -425,12 +351,9 @@ class PushOutNonSeqScan(GlobalOptimizer):
# Do not call make_node for test_value
nw_node = nwScan(*(node.inputs + nw_outer), return_list=True)[0].owner
fgraph.replace_all_validate_remove(
list(zip(node.outputs, nw_node.outputs)),
remove=[node],
reason="scanOp_pushout_nonseqs_ops",
)
return True
replacements = dict(zip(node.outputs, nw_node.outputs))
replacements["remove"] = [node]
return replacements
elif not to_keep_set:
# Nothing in the inner graph should be kept
replace_with = {}
......@@ -448,11 +371,8 @@ class PushOutNonSeqScan(GlobalOptimizer):
if len(node.outputs) == len(replace_with):
# Every output of the node has a replacement, the Scan
# node can be removed from the graph
fgraph.replace_all_validate_remove(
replace_with.items(),
remove=[node],
reason="scanOp_pushout_nonseqs_ops",
)
replace_with["remove"] = [node]
return replace_with
else:
# The node has some outputs for which no replacement has
# been established. This can occur for outputs that are
......@@ -461,15 +381,14 @@ class PushOutNonSeqScan(GlobalOptimizer):
# passed directly as outputs. The replacements can be
# performed but the Scan node can't be removed at this
# point.
fgraph.replace_all_validate(
replace_with.items(), reason="scanOp_pushout_nonseqs_ops"
)
return replace_with
else:
return False
class PushOutSeqScan(GlobalOptimizer):
@local_optimizer([Scan])
def push_out_seq_scan(fgraph, node):
r"""Push out the variables inside the `Scan` that depend only on constants and sequences.
This optimization resembles `PushOutNonSeqScan` but it tries to push, out of
......@@ -479,30 +398,10 @@ class PushOutSeqScan(GlobalOptimizer):
a single operation on a large tensor rather then perform that same operation
many times on many smaller tensors. In many cases, this optimization can
increase memory usage but, in some specific cases, it can also decrease it.
TODO: This is a global opt for historical reasons. It should be possible
to change it to a local opt.
"""
if not isinstance(node.op, Scan):
return False
def __init__(self):
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
def apply(self, fgraph):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op, Scan)]
for node in nodelist:
self.process_node(fgraph, node)
def process_node(self, fgraph, node):
"""
IMPORTANT NOTE: This function uses set and dictionary data structure.
By default they are not ordered for efficiency reasons. Take care
and make sure of changing them to Ordered versions if you need to
iterate over those variables.
"""
# this flag tells if there was any change during the last iterations
clean_inputs, clean_outputs = reconstruct_graph(node.op.inputs, node.op.outputs)
......@@ -607,10 +506,7 @@ class PushOutSeqScan(GlobalOptimizer):
elif (
nd not in to_remove_set
and isinstance(nd.op, DimShuffle)
and (
nd.inputs[0] in inner_seqs_set
or nd.inputs[0].owner in to_remove_set
)
and (nd.inputs[0] in inner_seqs_set or nd.inputs[0].owner in to_remove_set)
):
to_remove_set.add(nd)
......@@ -687,9 +583,7 @@ class PushOutSeqScan(GlobalOptimizer):
op_ins = nw_inner + clean_inputs
# Reconstruct node
nw_info = dataclasses.replace(
op.info, n_seqs=op.info.n_seqs + len(nw_inner)
)
nw_info = dataclasses.replace(op.info, n_seqs=op.info.n_seqs + len(nw_inner))
nwScan = Scan(
op_ins,
op_outs,
......@@ -709,12 +603,10 @@ class PushOutSeqScan(GlobalOptimizer):
return_list=True,
)[0].owner
fgraph.replace_all_validate_remove(
list(zip(node.outputs, nw_node.outputs)),
remove=[node],
reason="scanOp_pushout_seqs_ops",
)
return True
replacements = dict(zip(node.outputs, nw_node.outputs))
replacements["remove"] = [node]
return replacements
elif not to_keep_set and not op.as_while and not op.outer_mitmot(node.inputs):
# Nothing in the inner graph should be kept
replace_with = {}
......@@ -740,154 +632,21 @@ class PushOutSeqScan(GlobalOptimizer):
# We need to add one extra dimension to the outputs
if replace_with and len(replace_with) == len(node.outputs):
fgraph.replace_all_validate_remove(
list(replace_with.items()),
remove=[node],
reason="scanOp_pushout_seqs_ops",
)
return True
replacements = dict(replace_with.items())
replacements["remove"] = [node]
return replacements
else:
return False
class PushOutScanOutput(GlobalOptimizer):
r"""Push operations performed at the end of the inner graph of `Scan` to outside of `Scan`.
This optimizations attempts to push out some of the computation at the end
of the inner function to the outer function, to be executed after the `Scan`
node. Like `PushOutSeqScan`, this optimization aims to replace many operations
on small tensors by few operations on large tensors. It can also lead to
increased memory usage.
"""
def __init__(self):
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
def apply(self, fgraph):
# Don't perform the optimization on as_while scans. Because these scans
# don't run for a predetermined number of steps, handling them is
# more complicated and this optimization doesn't support it at the
# moment.
nodelist = [
x
for x in fgraph.toposort()
if (isinstance(x.op, Scan) and not x.op.as_while)
]
for node in nodelist:
# Process the node as long as something gets optimized
while node is not None:
node = self.process_node(fgraph, node)
def process_node(self, fgraph, node):
op = node.op
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = ScanArgs(
node.inputs, node.outputs, op.inputs, op.outputs, op.info, op.as_while
)
new_scan_node = None
clients = {}
local_fgraph_topo = io_toposort(
args.inner_inputs, args.inner_outputs, clients=clients
)
for nd in local_fgraph_topo:
if (
isinstance(nd.op, Elemwise)
and isinstance(nd.op.scalar_op, aes.Add)
and nd.out in args.inner_out_sit_sot
and self.inner_sitsot_only_last_step_used(fgraph, nd.out, args)
):
# Ensure that one of the input to the add is the output of
# the add from a previous iteration of the inner function
sitsot_idx = args.inner_out_sit_sot.index(nd.out)
if args.inner_in_sit_sot[sitsot_idx] in nd.inputs:
# Ensure that the other input to the add is a dot product
# between 2 matrices which will become a tensor3 and a
# matrix if pushed outside of the scan. Also make sure
# that the output of the Dot is ONLY used by the 'add'
# otherwise doing a Dot in the outer graph will only
# duplicate computation.
sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx])
# 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
dot_in_idx = 1 - sitsot_in_idx
dot_input = nd.inputs[dot_in_idx]
if (
dot_input.owner is not None
and isinstance(dot_input.owner.op, Dot)
and len(clients[dot_input]) == 1
and dot_input.owner.inputs[0].ndim == 2
and dot_input.owner.inputs[1].ndim == 2
and self.get_outer_ndim(dot_input.owner.inputs[0], args) == 3
and self.get_outer_ndim(dot_input.owner.inputs[1], args) == 3
):
# The optimization can be be applied in this case.
# Move out of scan the two inputs to the Dot and
# perform a dot outside of scan on these two inputs
inner_dot_inputs = nd.inputs[dot_in_idx].owner.inputs
(
outer_dot_inputs,
new_scan_node,
new_scan_args,
) = self.push_out_inner_vars(
fgraph, inner_dot_inputs, node, args
)
# Collapse some of the dimensions of the tensors
# so that they become matrices. This is because a
# dot is usually faster on two large matrices than
# a bunch of small ones
outer_dot_inputs[0] = aet.flatten(
outer_dot_inputs[0].dimshuffle(1, 0, 2), ndim=2
)
shape_input1 = shape(outer_dot_inputs[1])
outer_dot_inputs[1] = outer_dot_inputs[1].reshape(
(shape_input1[0] * shape_input1[1], shape_input1[2])
)
# Perform the dot on the newly obtained matrices and
# add the initial value
outer_dot_output = dot(*outer_dot_inputs)
init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0]
replacement = outer_dot_output + init_value
# Alter the outer graph to use the output of the
# external Dot instead of the output of scan
# Modify the outer graph to add the outer Dot
outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
subtensor_node = fgraph.clients[outer_sitsot][0][0]
outer_sitsot_last_step = subtensor_node.outputs[0]
fgraph.replace_all(
[(outer_sitsot_last_step, replacement)],
reason="scanOp_pushout_output",
)
break
return new_scan_node
def inner_sitsot_only_last_step_used(self, fgraph, var, scan_args):
def inner_sitsot_only_last_step_used(
fgraph: FunctionGraph, var: Variable, scan_args: ScanArgs
) -> bool:
"""
Given a inner nit_sot output of scan, return True iff the outer
nit_sot output has only one client and that client is a Subtensor
Given a inner nit-sot output of `Scan`, return ``True`` iff the outer
nit-sot output has only one client and that client is a `Subtensor`
instance that takes only the last step (last element along the first
axis).
"""
idx = scan_args.inner_out_sit_sot.index(var)
outer_var = scan_args.outer_out_sit_sot[idx]
......@@ -901,23 +660,28 @@ class PushOutScanOutput(GlobalOptimizer):
return False
def get_outer_ndim(self, var, scan_args):
# Given a variable, determine the number of dimension it would have if
# it was pushed out of scan
def get_outer_ndim(var: Variable, scan_args: ScanArgs) -> int:
    """Determine the number of dimensions a variable would have if it were pushed out of a `Scan`.

    Non-sequence inputs and constants keep their dimensionality; every other
    inner variable gains a leading iteration axis when moved to the outer graph.
    """
    if var in scan_args.inner_in_non_seqs or isinstance(var, Constant):
        return var.ndim
    return var.ndim + 1
def push_out_inner_vars(self, fgraph, inner_vars, old_scan_node, old_scan_args):
def push_out_inner_vars(
fgraph: FunctionGraph,
inner_vars: List[Variable],
old_scan_node: Node,
old_scan_args: ScanArgs,
) -> Tuple[List[Variable], ScanArgs, Dict[Variable, Variable]]:
outer_vars = [None] * len(inner_vars)
new_scan_node = old_scan_node
new_scan_args = old_scan_args
replacements = {}
# For the inner_vars that already exist in the outer graph,
# simply obtain a reference to them
......@@ -942,14 +706,12 @@ class PushOutScanOutput(GlobalOptimizer):
# For the inner_vars that don't already exist in the outer graph, add
# them as new nitsot outputs to the scan node.
idx_add_as_nitsots = [
i for i in range(len(outer_vars)) if outer_vars[i] is None
]
idx_add_as_nitsots = [i for i in range(len(outer_vars)) if outer_vars[i] is None]
add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots]
if len(add_as_nitsots) > 0:
new_scan_node = self.add_nitsot_outputs(
new_scan_node, replacements = add_nitsot_outputs(
fgraph, old_scan_node, old_scan_args, add_as_nitsots
)
......@@ -966,20 +728,22 @@ class PushOutScanOutput(GlobalOptimizer):
for i in range(len(new_outs)):
outer_vars[idx_add_as_nitsots[i]] = new_outs[i]
return outer_vars, new_scan_node, new_scan_args
return outer_vars, new_scan_args, replacements
def add_nitsot_outputs(
self, fgraph, old_scan_node, old_scan_args, new_outputs_inner
):
def add_nitsot_outputs(
fgraph: FunctionGraph,
old_scan_node: Node,
old_scan_args: ScanArgs,
new_outputs_inner,
) -> Tuple[Node, Dict[Variable, Variable]]:
nb_new_outs = len(new_outputs_inner)
# Create the initial values for the new nitsot outputs
# (the initial value is the nb of steps to store. For a nistot,
# it should be the number of steps performed by scan)
new_nitsots_initial_value = [
old_scan_node.inputs[0] for i in range(nb_new_outs)
]
new_nitsots_initial_value = [old_scan_node.inputs[0] for i in range(nb_new_outs)]
# Create the `ScanArgs` corresponding to the new `Scan` `Op` to create
new_scan_args = copy.copy(old_scan_args)
......@@ -1002,9 +766,7 @@ class PushOutScanOutput(GlobalOptimizer):
)
# Create the Apply node for the scan op
new_scan_node = new_scan_op(*new_scan_args.outer_inputs, return_list=True)[
0
].owner
new_scan_node = new_scan_op(*new_scan_args.outer_inputs, return_list=True)[0].owner
# Modify the outer graph to make sure the outputs of the new scan are
# used instead of the outputs of the old scan
......@@ -1017,13 +779,123 @@ class PushOutScanOutput(GlobalOptimizer):
+ new_scan_node.outputs[new_node_new_outputs_idx + nb_new_outs :]
)
# TODO FIXME:
# replacements = dict(zip(old_scan_node.outputs, new_node_old_outputs))
# replacements["remove"] = [old_scan_node]
# return new_scan_node, replacements
fgraph.replace_all_validate_remove(
list(zip(old_scan_node.outputs, new_node_old_outputs)),
remove=[old_scan_node],
reason="scanOp_pushout_output",
reason="scan_pushout_output",
)
return new_scan_node, {}
@local_optimizer([Scan])
def push_out_add_scan(fgraph, node):
    r"""Push `Add` operations performed at the end of the inner graph to the outside.

    Like `push_out_seq_scan`, this optimization aims to replace many operations
    on small tensors by few operations on large tensors. It can also lead to
    increased memory usage.
    """
    # Don't perform the optimization on `as_while` `Scan`s. Because these
    # `Scan`s don't run for a predetermined number of steps, handling them is
    # more complicated and this optimization doesn't support it at the moment.
    if not (isinstance(node.op, Scan) and not node.op.as_while):
        return False

    op = node.op

    # Use `ScanArgs` to parse the inputs and outputs of scan for ease of use.
    args = ScanArgs(
        node.inputs, node.outputs, op.inputs, op.outputs, op.info, op.as_while
    )

    clients = {}
    local_fgraph_topo = io_toposort(
        args.inner_inputs, args.inner_outputs, clients=clients
    )

    for nd in local_fgraph_topo:
        if (
            isinstance(nd.op, Elemwise)
            and isinstance(nd.op.scalar_op, aes.Add)
            and nd.out in args.inner_out_sit_sot
            and inner_sitsot_only_last_step_used(fgraph, nd.out, args)
        ):
            # Ensure that one of the inputs to the add is the output of
            # the add from a previous iteration of the inner function
            sitsot_idx = args.inner_out_sit_sot.index(nd.out)
            if args.inner_in_sit_sot[sitsot_idx] in nd.inputs:
                # Ensure that the other input to the add is a dot product
                # between 2 matrices which will become a tensor3 and a
                # matrix if pushed outside of the scan. Also make sure
                # that the output of the Dot is ONLY used by the 'add'
                # otherwise doing a Dot in the outer graph will only
                # duplicate computation.

                sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx])

                # 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0
                dot_in_idx = 1 - sitsot_in_idx

                dot_input = nd.inputs[dot_in_idx]

                if (
                    dot_input.owner is not None
                    and isinstance(dot_input.owner.op, Dot)
                    and len(clients[dot_input]) == 1
                    and dot_input.owner.inputs[0].ndim == 2
                    and dot_input.owner.inputs[1].ndim == 2
                    and get_outer_ndim(dot_input.owner.inputs[0], args) == 3
                    and get_outer_ndim(dot_input.owner.inputs[1], args) == 3
                ):
                    # The optimization can be applied in this case.

                    # Move out of scan the two inputs to the Dot and
                    # perform a dot outside of scan on these two inputs
                    inner_dot_inputs = nd.inputs[dot_in_idx].owner.inputs
                    (
                        outer_dot_inputs,
                        new_scan_args,
                        replacements,
                    ) = push_out_inner_vars(fgraph, inner_dot_inputs, node, args)

                    # Collapse some of the dimensions of the tensors
                    # so that they become matrices. This is because a
                    # dot is usually faster on two large matrices than
                    # a bunch of small ones
                    outer_dot_inputs[0] = aet.flatten(
                        outer_dot_inputs[0].dimshuffle(1, 0, 2), ndim=2
                    )

                    shape_input1 = shape(outer_dot_inputs[1])
                    outer_dot_inputs[1] = outer_dot_inputs[1].reshape(
                        (shape_input1[0] * shape_input1[1], shape_input1[2])
                    )

                    # Perform the dot on the newly obtained matrices and
                    # add the initial value
                    outer_dot_output = dot(*outer_dot_inputs)

                    init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0]
                    replacement = outer_dot_output + init_value

                    # Alter the outer graph to use the output of the
                    # external Dot instead of the output of scan.
                    # `inner_sitsot_only_last_step_used` guaranteed the outer
                    # sit-sot has a single `Subtensor` client taking the last
                    # step, so replacing that client's output is sufficient.
                    outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx]
                    subtensor_node = fgraph.clients[outer_sitsot][0][0]

                    outer_sitsot_last_step = subtensor_node.outputs[0]

                    replacements[outer_sitsot_last_step] = replacement
                    return replacements

    return False
class ScanInplaceOptimizer(GlobalOptimizer):
......@@ -1203,7 +1075,31 @@ class ScanInplaceOptimizer(GlobalOptimizer):
node = self.attempt_scan_inplace(fgraph, node, [pos], alloc_ops)
class ScanSaveMem(GlobalOptimizer):
def select_min(x, y):
    """Return the elementwise minimum of `x` and `y`.

    A ``None`` argument means "no bound", so the other argument is
    returned unchanged (possibly ``None`` as well).
    """
    if x is not None and y is not None:
        return minimum(x, y)
    return y if x is None else x
def select_max(x, y):
    """Return the elementwise maximum of `x` and `y`.

    A ``None`` argument means "no bound", so the other argument is
    returned unchanged (possibly ``None`` as well).
    """
    if x is not None and y is not None:
        return maximum(x, y)
    return y if x is None else x
def sanitize(x):
    """Convert `x` to a tensor variable, passing ``None`` through unchanged."""
    return None if x is None else aet.as_tensor_variable(x)
@local_optimizer([Scan])
def save_mem_new_scan(fgraph, node):
r"""Graph optimizer that reduces scan memory consumption.
This optimizations attempts to determine if a `Scan` node, during its execution,
......@@ -1224,35 +1120,8 @@ class ScanSaveMem(GlobalOptimizer):
be kept in memory.
"""
def __init__(self):
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
def process_node(self, fgraph, node):
# helpful functions
def select_min(x, y):
if x is None:
return y
if y is None:
return x
return minimum(x, y)
def select_max(x, y):
if x is None:
return y
if y is None:
return x
return maximum(x, y)
def sanitize(x):
if x is None:
return None
else:
return aet.as_tensor_variable(x)
if not isinstance(node.op, Scan):
return False
if hasattr(fgraph, "shape_feature"):
shape_of = fgraph.shape_feature.shape_of
......@@ -1487,17 +1356,12 @@ class ScanSaveMem(GlobalOptimizer):
first_mitsot_idx = node.op.n_mit_mot
last_sitsot_idx = (
node.op.n_mit_mot
+ node.op.n_mit_sot
+ node.op.n_sit_sot
- 1
node.op.n_mit_mot + node.op.n_mit_sot + node.op.n_sit_sot - 1
)
preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx
if prealloc_outs and preallocable_output:
pval = select_max(
nw_steps - start + init_l[i], init_l[i] + 1
)
pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1)
else:
pval = select_max(nw_steps - start + init_l[i], init_l[i])
......@@ -1544,9 +1408,7 @@ class ScanSaveMem(GlobalOptimizer):
# TODO: commit change below with Razvan
if (
nw_inputs[offset + idx].owner
and isinstance(
nw_inputs[offset + idx].owner.op, IncSubtensor
)
and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
and isinstance(
nw_inputs[offset + idx].owner.op.idx_list[0], slice
)
......@@ -1558,9 +1420,7 @@ class ScanSaveMem(GlobalOptimizer):
_nw_input = nw_inputs[offset + idx].owner.inputs[1]
cval = aet.as_tensor_variable(val)
initl = aet.as_tensor_variable(init_l[i])
tmp_idx = aet.switch(
cval < initl, cval + initl, cval - initl
)
tmp_idx = aet.switch(cval < initl, cval + initl, cval - initl)
nw_input = expand_empty(_nw_input, tmp_idx)
else:
tmp = aet.as_tensor_variable(val)
......@@ -1645,7 +1505,7 @@ class ScanSaveMem(GlobalOptimizer):
# TODO: currently we don't support scan with 0 step. So
# don't create one.
if aet.extract_constant(node_ins[0]) == 0:
return
return False
# Do not call make_node for test_value
new_op = Scan(
......@@ -1758,19 +1618,18 @@ class ScanSaveMem(GlobalOptimizer):
]
if any(old_scan_is_used):
return False
remove = [old.owner for (old, new) in old_new]
replacements = dict(old_new)
# remove = [old.owner for (old, new) in old_new]
# As Fred suggested assert that also the old node is not in
# the Graph as that will make things suboptimal
remove.append(node)
fgraph.replace_all_validate_remove(
old_new, remove, reason="scanOp_save_mem"
)
# remove.append(node)
replacements["remove"] = [node]
def apply(self, fgraph):
return replacements
nodelist = [x for x in fgraph.toposort() if isinstance(x.op, Scan)]
for node in nodelist:
self.process_node(fgraph, node)
return False
class ScanMerge(GlobalOptimizer):
......@@ -2271,27 +2130,16 @@ def scan_merge_inouts(fgraph, node):
return na.outer_outputs
class PushOutDot1(GlobalOptimizer):
@local_optimizer([Scan])
def push_out_dot1_scan(fgraph, node):
r"""
This is another optimization that attempts to detect certain patterns of
computation in a `Scan` `Op`'s inner function and move this computation to the
outer graph.
"""
if not isinstance(node.op, Scan):
return False
def __init__(self):
super().__init__()
def add_requirements(self, fgraph):
fgraph.attach_feature(ReplaceValidate())
def apply(self, fgraph):
nodes = fgraph.toposort()
scan_nodes = [x for x in nodes if (isinstance(x.op, Scan))]
for node in scan_nodes:
self.apply_opt(fgraph, node)
def apply_opt(self, fgraph, node):
# Replace pattern of the form
# x[t] = x[t-1] + dot(seq[t], value)
# with Sequence.reshape((-1, seq.shape[2])) \dot Value
......@@ -2470,9 +2318,11 @@ class PushOutDot1(GlobalOptimizer):
old = fgraph.clients[node.outputs[pos]][0][0].outputs[0]
old_new.append((old, new_out))
old_new += list(zip(node.outputs[pos + 1 :], new_outs[pos:]))
fgraph.replace_all_validate_remove(
old_new, remove=[node], reason="scan_pushout_dot1"
)
replacements = dict(old_new)
replacements["remove"] = [node]
return replacements
return False
# I've added an equilibrium because later scan optimization in the sequence
......@@ -2490,7 +2340,13 @@ optdb.register("scan_eqopt1", scan_eqopt1, 0.05, "fast_run", "scan")
# but after stabilize at 1.5. Should we put it before stabilize?
optdb.register("scan_eqopt2", scan_eqopt2, 1.6, "fast_run", "scan")
# ScanSaveMem should execute only once per node.
optdb.register("scanOp_save_mem", ScanSaveMem(), 1.61, "fast_run", "scan")
optdb.register(
"scanOp_save_mem",
in2out(save_mem_new_scan, ignore_newtrees=True),
1.61,
"fast_run",
"scan",
)
optdb.register(
"scanOp_make_inplace",
ScanInplaceOptimizer(typeInfer=None),
......@@ -2514,22 +2370,41 @@ scan_seqopt1.register(
scan_seqopt1.register(
"scanOp_pushout_nonseqs_ops", PushOutNonSeqScan(), 2, "fast_run", "scan"
"scanOp_pushout_nonseqs_ops",
in2out(push_out_non_seq_scan, ignore_newtrees=True),
2,
"fast_run",
"scan",
)
scan_seqopt1.register(
"scanOp_pushout_seqs_ops", PushOutSeqScan(), 3, "fast_run", "scan"
"scanOp_pushout_seqs_ops",
in2out(push_out_seq_scan, ignore_newtrees=True),
3,
"fast_run",
"scan",
)
scan_seqopt1.register(
"scan_pushout_dot1", PushOutDot1(), 4, "fast_run", "more_mem", "scan"
"scan_pushout_dot1",
in2out(push_out_dot1_scan, ignore_newtrees=True),
4,
"fast_run",
"more_mem",
"scan",
)
scan_seqopt1.register(
"scanOp_pushout_output", PushOutScanOutput(), 5, "fast_run", "more_mem", "scan"
"scanOp_pushout_output",
# TODO: Perhaps this should be an `EquilibriumOptimizer`?
in2out(push_out_add_scan, ignore_newtrees=False),
5,
"fast_run",
"more_mem",
"scan",
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论