提交 a22da800 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Compile the FunctionGraph object tracked by Scan directly

上级 2fb98571
......@@ -56,9 +56,9 @@ import aesara
from aesara import tensor as at
from aesara.compile import SharedVariable
from aesara.compile.builders import infer_shape
from aesara.compile.function import function
from aesara.compile.function.pfunc import pfunc
from aesara.compile.io import In, Out
from aesara.compile.mode import AddFeatureOptimizer, Mode, get_default_mode, get_mode
from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.profiling import register_profiler_printer
from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
......@@ -76,7 +76,7 @@ from aesara.graph.basic import (
from aesara.graph.features import NoOutputFromInplace
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.utils import MissingInputError
from aesara.graph.utils import InconsistencyError, MissingInputError
from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op
......@@ -778,8 +778,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else:
inputs = []
self.fgraph = FunctionGraph(inputs, outputs, clone=False)
self.info = info
self.truncate_gradient = truncate_gradient
self.name = name
......@@ -863,6 +861,36 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
self.n_tap_outs = info.n_mit_mot + info.n_mit_sot
(
self.preallocated_mitmot_outs,
self.mitmots_preallocated,
) = self._mitmot_preallocations()
self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs
features = []
if config.scan__allow_output_prealloc:
# This feature will prevent mitsot, sitsot and nitsot outputs from
# being computed inplace (to allow their preallocation).
mitsot_start = info.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = mitsot_start + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
features.append(NoOutputFromInplace(mitsot_start, nitsot_end))
self.fgraph = FunctionGraph(
inputs,
outputs,
clone=False,
features=features,
)
if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
raise InconsistencyError(
"Inner-graphs must not contain in-place operations."
)
# Do the missing inputs check here to have the error early.
for var in graph_inputs(self.inner_outputs, self.inner_inputs):
if var not in self.inner_inputs and not isinstance(var, Constant):
......@@ -872,14 +900,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
self._hash_inner_graph = hash(self._cmodule_key)
(
self.preallocated_mitmot_outs,
self.mitmots_preallocated,
) = self._mitmot_preallocations()
self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs
def _mitmot_preallocations(self):
if config.scan__allow_output_prealloc:
preallocated_mitmot_outs = []
......@@ -1356,6 +1376,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ self.info.n_nit_sot
)
fgraph = self.fgraph.clone()
if config.scan__allow_output_prealloc:
# Go through the mitmots. Whenever a mitmot has a tap both as an
......@@ -1363,15 +1385,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# output variable becomes an update to be performed on it, possibly
# inplace at the end of the functions's execution.
wrapped_inputs = [
In(x, borrow=False) for x in self.inner_inputs[: self.n_seqs]
In(x, borrow=False) for x in fgraph.inputs[: self.info.n_seqs]
]
new_outputs = [x for x in self.inner_outputs]
new_outputs = [x for x in fgraph.outputs]
input_idx = self.info.n_seqs
for mitmot_idx in range(self.info.n_mit_mot):
for inp_tap in self.info.mit_mot_in_slices[mitmot_idx]:
if inp_tap in self.info.mit_mot_out_slices[mitmot_idx]:
inp = self.inner_inputs[input_idx]
inp = fgraph.inputs[input_idx]
# Figure out the index of the corresponding output
output_idx = sum(
......@@ -1393,57 +1415,35 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
wrapped_inp = In(
variable=inp,
value=default_val,
update=self.inner_outputs[output_idx],
update=fgraph.outputs[output_idx],
)
wrapped_inputs.append(wrapped_inp)
else:
# Wrap the corresponding input as usual. Leave the
# output as-is.
wrapped_inputs.append(
In(self.inner_inputs[input_idx], borrow=False)
In(fgraph.inputs[input_idx], borrow=False)
)
input_idx += 1
# Wrap the inputs not associated to mitmots and wrap the remaining
# outputs
wrapped_inputs += [
In(x, borrow=False) for x in self.inner_inputs[input_idx:]
]
wrapped_inputs += [In(x, borrow=False) for x in fgraph.inputs[input_idx:]]
wrapped_outputs = [Out(x, borrow=True) for x in new_outputs[:slices]]
wrapped_outputs += new_outputs[slices:]
# Remove now useless outputs from the output list (start from the
# end to avoid altering the indices of the other outputs to be
# Remove now useless outputs from the output list and start from
# the end to avoid altering the indices of the other outputs to be
# deleted.
for p in self.preallocated_mitmot_outs[::-1]:
fgraph.remove_output(p, reason="scan_prealloc")
del wrapped_outputs[p]
# Add an optimization to the compilation mode to attach a feature
# to the function graph just before the inplace optimizations are
# applied (inplace optimizations start at position 50 so the
# optimization to attach the feature is registered at position 49.9
# so that it runs before them). This feature will prevent mitsot,
# sitsot and nitsot outputs from being computed inplace (to allow
# their preallocation).
mitsot_start = self.info.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = (
mitsot_start
+ self.info.n_mit_sot
+ self.info.n_sit_sot
+ self.info.n_nit_sot
)
feature = NoOutputFromInplace(mitsot_start, nitsot_end)
opt = AddFeatureOptimizer(feature)
compilation_mode = self.mode_instance.register((opt, 49.9))
else:
compilation_mode = self.mode_instance
wrapped_inputs = [In(x, borrow=True) for x in self.inner_inputs]
wrapped_outputs = [
Out(x, borrow=False) for x in self.inner_outputs[:slices]
]
wrapped_outputs += self.inner_outputs[slices:]
wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs]
wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs[:slices]]
wrapped_outputs += fgraph.outputs[slices:]
profile = None
if config.profile or (
......@@ -1456,13 +1456,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
elif self.profile:
profile = self.profile
self._fn = function(
self._fn = pfunc(
wrapped_inputs,
wrapped_outputs,
mode=compilation_mode,
name=self.name,
mode=self.mode_instance,
accept_inplace=False,
profile=profile,
on_unused_input="ignore",
fgraph=fgraph,
)
return self._fn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论