提交 a22da800 作者: Brandon T. Willard 提交者: Brandon T. Willard

Compile the FunctionGraph object tracked by Scan directly

上级 2fb98571
...@@ -56,9 +56,9 @@ import aesara ...@@ -56,9 +56,9 @@ import aesara
from aesara import tensor as at from aesara import tensor as at
from aesara.compile import SharedVariable from aesara.compile import SharedVariable
from aesara.compile.builders import infer_shape from aesara.compile.builders import infer_shape
from aesara.compile.function import function from aesara.compile.function.pfunc import pfunc
from aesara.compile.io import In, Out from aesara.compile.io import In, Out
from aesara.compile.mode import AddFeatureOptimizer, Mode, get_default_mode, get_mode from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.profiling import register_profiler_printer from aesara.compile.profiling import register_profiler_printer
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
...@@ -76,7 +76,7 @@ from aesara.graph.basic import ( ...@@ -76,7 +76,7 @@ from aesara.graph.basic import (
from aesara.graph.features import NoOutputFromInplace from aesara.graph.features import NoOutputFromInplace
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import HasInnerGraph, Op from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.utils import MissingInputError from aesara.graph.utils import InconsistencyError, MissingInputError
from aesara.link.c.basic import CLinker from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op from aesara.link.utils import raise_with_op
...@@ -778,8 +778,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -778,8 +778,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else: else:
inputs = [] inputs = []
self.fgraph = FunctionGraph(inputs, outputs, clone=False)
self.info = info self.info = info
self.truncate_gradient = truncate_gradient self.truncate_gradient = truncate_gradient
self.name = name self.name = name
...@@ -863,6 +861,36 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -863,6 +861,36 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot self.n_outs = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot
self.n_tap_outs = info.n_mit_mot + info.n_mit_sot self.n_tap_outs = info.n_mit_mot + info.n_mit_sot
(
self.preallocated_mitmot_outs,
self.mitmots_preallocated,
) = self._mitmot_preallocations()
self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs
features = []
if config.scan__allow_output_prealloc:
# This feature will prevent mitsot, sitsot and nitsot outputs from
# being computed inplace (to allow their preallocation).
mitsot_start = info.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = mitsot_start + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
features.append(NoOutputFromInplace(mitsot_start, nitsot_end))
self.fgraph = FunctionGraph(
inputs,
outputs,
clone=False,
features=features,
)
if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
raise InconsistencyError(
"Inner-graphs must not contain in-place operations."
)
# Do the missing inputs check here to have the error early. # Do the missing inputs check here to have the error early.
for var in graph_inputs(self.inner_outputs, self.inner_inputs): for var in graph_inputs(self.inner_outputs, self.inner_inputs):
if var not in self.inner_inputs and not isinstance(var, Constant): if var not in self.inner_inputs and not isinstance(var, Constant):
...@@ -872,14 +900,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -872,14 +900,6 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
) )
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
(
self.preallocated_mitmot_outs,
self.mitmots_preallocated,
) = self._mitmot_preallocations()
self.n_outer_inputs = info.n_outer_inputs
self.n_outer_outputs = info.n_outer_outputs
def _mitmot_preallocations(self): def _mitmot_preallocations(self):
if config.scan__allow_output_prealloc: if config.scan__allow_output_prealloc:
preallocated_mitmot_outs = [] preallocated_mitmot_outs = []
...@@ -1356,6 +1376,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1356,6 +1376,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+ self.info.n_nit_sot + self.info.n_nit_sot
) )
fgraph = self.fgraph.clone()
if config.scan__allow_output_prealloc: if config.scan__allow_output_prealloc:
# Go through the mitmots. Whenever a mitmot has a tap both as an # Go through the mitmots. Whenever a mitmot has a tap both as an
...@@ -1363,15 +1385,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1363,15 +1385,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# output variable becomes an update to be performed on it, possibly # output variable becomes an update to be performed on it, possibly
# inplace at the end of the function's execution. # inplace at the end of the function's execution.
wrapped_inputs = [ wrapped_inputs = [
In(x, borrow=False) for x in self.inner_inputs[: self.n_seqs] In(x, borrow=False) for x in fgraph.inputs[: self.info.n_seqs]
] ]
new_outputs = [x for x in self.inner_outputs] new_outputs = [x for x in fgraph.outputs]
input_idx = self.info.n_seqs input_idx = self.info.n_seqs
for mitmot_idx in range(self.info.n_mit_mot): for mitmot_idx in range(self.info.n_mit_mot):
for inp_tap in self.info.mit_mot_in_slices[mitmot_idx]: for inp_tap in self.info.mit_mot_in_slices[mitmot_idx]:
if inp_tap in self.info.mit_mot_out_slices[mitmot_idx]: if inp_tap in self.info.mit_mot_out_slices[mitmot_idx]:
inp = self.inner_inputs[input_idx] inp = fgraph.inputs[input_idx]
# Figure out the index of the corresponding output # Figure out the index of the corresponding output
output_idx = sum( output_idx = sum(
...@@ -1393,57 +1415,35 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1393,57 +1415,35 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
wrapped_inp = In( wrapped_inp = In(
variable=inp, variable=inp,
value=default_val, value=default_val,
update=self.inner_outputs[output_idx], update=fgraph.outputs[output_idx],
) )
wrapped_inputs.append(wrapped_inp) wrapped_inputs.append(wrapped_inp)
else: else:
# Wrap the corresponding input as usual. Leave the # Wrap the corresponding input as usual. Leave the
# output as-is. # output as-is.
wrapped_inputs.append( wrapped_inputs.append(
In(self.inner_inputs[input_idx], borrow=False) In(fgraph.inputs[input_idx], borrow=False)
) )
input_idx += 1 input_idx += 1
# Wrap the inputs not associated to mitmots and wrap the remaining # Wrap the inputs not associated to mitmots and wrap the remaining
# outputs # outputs
wrapped_inputs += [ wrapped_inputs += [In(x, borrow=False) for x in fgraph.inputs[input_idx:]]
In(x, borrow=False) for x in self.inner_inputs[input_idx:]
]
wrapped_outputs = [Out(x, borrow=True) for x in new_outputs[:slices]] wrapped_outputs = [Out(x, borrow=True) for x in new_outputs[:slices]]
wrapped_outputs += new_outputs[slices:] wrapped_outputs += new_outputs[slices:]
# Remove now useless outputs from the output list (start from the # Remove now useless outputs from the output list and start from
# end to avoid altering the indices of the other outputs to be # the end to avoid altering the indices of the other outputs to be
# deleted. # deleted.
for p in self.preallocated_mitmot_outs[::-1]: for p in self.preallocated_mitmot_outs[::-1]:
fgraph.remove_output(p, reason="scan_prealloc")
del wrapped_outputs[p] del wrapped_outputs[p]
# Add an optimization to the compilation mode to attach a feature
# to the function graph just before the inplace optimizations are
# applied (inplace optimizations start at position 50 so the
# optimization to attach the feature is registered at position 49.9
# so that it runs before them). This feature will prevent mitsot,
# sitsot and nitsot outputs from being computed inplace (to allow
# their preallocation).
mitsot_start = self.info.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = (
mitsot_start
+ self.info.n_mit_sot
+ self.info.n_sit_sot
+ self.info.n_nit_sot
)
feature = NoOutputFromInplace(mitsot_start, nitsot_end)
opt = AddFeatureOptimizer(feature)
compilation_mode = self.mode_instance.register((opt, 49.9))
else: else:
compilation_mode = self.mode_instance wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs]
wrapped_inputs = [In(x, borrow=True) for x in self.inner_inputs] wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs[:slices]]
wrapped_outputs = [ wrapped_outputs += fgraph.outputs[slices:]
Out(x, borrow=False) for x in self.inner_outputs[:slices]
]
wrapped_outputs += self.inner_outputs[slices:]
profile = None profile = None
if config.profile or ( if config.profile or (
...@@ -1456,13 +1456,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -1456,13 +1456,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
elif self.profile: elif self.profile:
profile = self.profile profile = self.profile
self._fn = function( self._fn = pfunc(
wrapped_inputs, wrapped_inputs,
wrapped_outputs, wrapped_outputs,
mode=compilation_mode, mode=self.mode_instance,
name=self.name, accept_inplace=False,
profile=profile, profile=profile,
on_unused_input="ignore", on_unused_input="ignore",
fgraph=fgraph,
) )
return self._fn return self._fn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论