提交 e51be01c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add FunctionGraph input support to the Function compilation pipeline

These changes allow one to pass an `fgraph` argument to all key functions in the `Function` compilation pipeline. The given `FunctionGraph` will be directly compiled without cloning. Unlike the previous `FunctionMaker.__init__`, this one's `fgraph` argument *will* be optimized according to the given `mode` unless the keyword argument `no_fgraph_prep` is `True`.
上级 90e794c0
...@@ -433,7 +433,9 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -433,7 +433,9 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
""" """
equivalence_tracker = _VariableEquivalenceTracker() equivalence_tracker = _VariableEquivalenceTracker()
fgraph, updates = std_fgraph(input_specs, output_specs, accept_inplace) fgraph, updates = std_fgraph(
input_specs, output_specs, accept_inplace, force_clone=True
)
fgraph.attach_feature(equivalence_tracker) fgraph.attach_feature(equivalence_tracker)
return fgraph, updates, equivalence_tracker return fgraph, updates, equivalence_tracker
...@@ -2006,6 +2008,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2006,6 +2008,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
fgraph=None, # If present the optimized graph. we ignore it. fgraph=None, # If present the optimized graph. we ignore it.
output_keys=None, output_keys=None,
name=None, name=None,
no_fgraph_prep=False,
): ):
self.mode = mode self.mode = mode
self.profile = profile self.profile = profile
......
...@@ -4,6 +4,8 @@ Provide a simple user friendly API. ...@@ -4,6 +4,8 @@ Provide a simple user friendly API.
""" """
import logging import logging
from copy import copy
from typing import Optional
from aesara.compile.function.types import Function, UnusedInputError, orig_function from aesara.compile.function.types import Function, UnusedInputError, orig_function
from aesara.compile.io import In, Out from aesara.compile.io import In, Out
...@@ -11,6 +13,7 @@ from aesara.compile.profiling import ProfileStats ...@@ -11,6 +13,7 @@ from aesara.compile.profiling import ProfileStats
from aesara.compile.sharedvalue import SharedVariable, shared from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.graph.fg import FunctionGraph
_logger = logging.getLogger("aesara.compile.function.pfunc") _logger = logging.getLogger("aesara.compile.function.pfunc")
...@@ -279,6 +282,7 @@ def pfunc( ...@@ -279,6 +282,7 @@ def pfunc(
profile=None, profile=None,
on_unused_input=None, on_unused_input=None,
output_keys=None, output_keys=None,
fgraph: Optional[FunctionGraph] = None,
) -> Function: ) -> Function:
""" """
Function-constructor for graphs with shared variables. Function-constructor for graphs with shared variables.
...@@ -324,6 +328,9 @@ def pfunc( ...@@ -324,6 +328,9 @@ def pfunc(
be available via self.profile. be available via self.profile.
on_unused_input : {'raise', 'warn','ignore', None} on_unused_input : {'raise', 'warn','ignore', None}
What to do if a variable in the 'inputs' list is not used in the graph. What to do if a variable in the 'inputs' list is not used in the graph.
fgraph
An existing `FunctionGraph` from which to construct the returned
`Function`. When this is non-``None``, nothing is cloned.
Returns Returns
------- -------
...@@ -358,6 +365,7 @@ def pfunc( ...@@ -358,6 +365,7 @@ def pfunc(
no_default_updates, no_default_updates,
rebuild_strict, rebuild_strict,
allow_input_downcast, allow_input_downcast,
fgraph=fgraph,
) )
return orig_function( return orig_function(
...@@ -369,6 +377,7 @@ def pfunc( ...@@ -369,6 +377,7 @@ def pfunc(
profile=profile, profile=profile,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
output_keys=output_keys, output_keys=output_keys,
fgraph=fgraph,
) )
...@@ -381,6 +390,7 @@ def construct_pfunc_ins_and_outs( ...@@ -381,6 +390,7 @@ def construct_pfunc_ins_and_outs(
no_default_updates=False, no_default_updates=False,
rebuild_strict=True, rebuild_strict=True,
allow_input_downcast=None, allow_input_downcast=None,
fgraph: Optional[FunctionGraph] = None,
): ):
"""Construct inputs and outputs for `pfunc`. """Construct inputs and outputs for `pfunc`.
...@@ -398,6 +408,9 @@ def construct_pfunc_ins_and_outs( ...@@ -398,6 +408,9 @@ def construct_pfunc_ins_and_outs(
Then it clones the outputs and the update expressions. This Then it clones the outputs and the update expressions. This
rebuilds a computation graph from the inputs and the `givens`. rebuilds a computation graph from the inputs and the `givens`.
When `fgraph` is non-``None``, nothing is cloned and the given `fgraph` is
simply prepared for direct use.
""" """
if updates is None: if updates is None:
updates = [] updates = []
...@@ -459,68 +472,93 @@ def construct_pfunc_ins_and_outs( ...@@ -459,68 +472,93 @@ def construct_pfunc_ins_and_outs(
"aesara.clone_replace(f(x), replace={x: g(x)}))`." "aesara.clone_replace(f(x), replace={x: g(x)}))`."
) )
# Extend the outputs with the updates on input variables so they are also if not fgraph:
# cloned
additional_outputs = [i.update for i in inputs if i.update]
if outputs is None:
out_list = []
else:
if isinstance(outputs, (list, tuple)):
out_list = list(outputs)
else:
out_list = [outputs]
extended_outputs = out_list + additional_outputs
output_vars = rebuild_collect_shared(
extended_outputs,
in_variables,
replace=givens,
updates=updates,
rebuild_strict=rebuild_strict,
copy_inputs_over=True,
no_default_updates=no_default_updates,
)
# extracting the arguments
input_variables, cloned_extended_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff
# Recover only the clones of the original outputs # Extend the outputs with the updates on input variables so they are
if outputs is None: # also cloned
cloned_outputs = [] additional_outputs = [i.update for i in inputs if i.update]
else: if outputs is None:
if isinstance(outputs, (list, tuple)): out_list = []
cloned_outputs = cloned_extended_outputs[: len(outputs)]
else: else:
cloned_outputs = cloned_extended_outputs[0] if isinstance(outputs, (list, tuple)):
out_list = list(outputs)
for i, iv in zip(inputs, input_variables): else:
i.variable = iv out_list = [outputs]
extended_outputs = out_list + additional_outputs
# If needed, replace the input's update by its cloned equivalent
if i.update: output_vars = rebuild_collect_shared(
i.update = clone_d[i.update] extended_outputs,
in_variables,
for sv in shared_inputs: replace=givens,
# pass value of None updates=updates,
# value will be stored in the resulting functions' defaults rebuild_strict=rebuild_strict,
# list but since the value of shared variables never needs to copy_inputs_over=True,
# be refed, it is not needed no_default_updates=no_default_updates,
if sv in update_d: )
si = In( input_variables, cloned_extended_outputs, other_stuff = output_vars
variable=sv, clone_d, update_d, update_expr, shared_inputs = other_stuff
value=sv.container,
mutable=True, # Recover only the clones of the original outputs
borrow=True, if outputs is None:
update=update_d[sv], new_outputs = []
shared=True,
)
else: else:
si = In( if isinstance(outputs, (list, tuple)):
variable=sv, value=sv.container, mutable=False, borrow=True, shared=True new_outputs = cloned_extended_outputs[: len(outputs)]
) else:
inputs.append(si) new_outputs = cloned_extended_outputs[0]
new_inputs = []
for i, iv in zip(inputs, input_variables):
new_i = copy(i)
new_i.variable = iv
# If needed, replace the input's update by its cloned equivalent
if i.update:
new_i.update = clone_d[i.update]
new_inputs.append(new_i)
for sv in shared_inputs:
if sv in update_d:
si = In(
variable=sv,
value=sv.container,
mutable=True,
borrow=True,
update=update_d[sv],
shared=True,
)
else:
si = In(
variable=sv,
value=sv.container,
mutable=False,
borrow=True,
shared=True,
)
new_inputs.append(si)
else:
assert len(fgraph.inputs) == len(inputs)
assert len(fgraph.outputs) == len(outputs)
for fg_inp, inp in zip(fgraph.inputs, inputs):
if fg_inp != getattr(inp, "variable", inp):
raise ValueError(
f"`fgraph`'s input does not match the provided input: {fg_inp}, {inp}"
)
for fg_out, out in zip(fgraph.outputs, outputs):
if fg_out != getattr(out, "variable", out):
raise ValueError(
f"`fgraph`'s output does not match the provided output: {fg_out}, {out}"
)
new_inputs = inputs
new_outputs = outputs
return inputs, cloned_outputs return new_inputs, new_outputs
def _pfunc_param_to_in(param, strict=False, allow_downcast=None): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
...@@ -9,7 +9,7 @@ import logging ...@@ -9,7 +9,7 @@ import logging
import time import time
import warnings import warnings
from itertools import chain from itertools import chain
from typing import List, Tuple, Type from typing import List, Optional, Tuple, Type
import numpy as np import numpy as np
...@@ -152,7 +152,9 @@ def std_fgraph( ...@@ -152,7 +152,9 @@ def std_fgraph(
input_specs: List[SymbolicInput], input_specs: List[SymbolicInput],
output_specs: List[SymbolicOutput], output_specs: List[SymbolicOutput],
accept_inplace: bool = False, accept_inplace: bool = False,
fgraph: Optional[FunctionGraph] = None,
features: List[Type[Feature]] = [PreserveVariableAttributes], features: List[Type[Feature]] = [PreserveVariableAttributes],
force_clone=False,
) -> Tuple[FunctionGraph, List[SymbolicOutput]]: ) -> Tuple[FunctionGraph, List[SymbolicOutput]]:
"""Make or set up `FunctionGraph` corresponding to the input specs and the output specs. """Make or set up `FunctionGraph` corresponding to the input specs and the output specs.
...@@ -166,26 +168,48 @@ def std_fgraph( ...@@ -166,26 +168,48 @@ def std_fgraph(
`accept_inplace` is ``True``, a `DestroyHandler` will be added to the `accept_inplace` is ``True``, a `DestroyHandler` will be added to the
`FunctionGraph` if there are any in-place operations. `FunctionGraph` if there are any in-place operations.
The returned FunctionGraph is a clone of the graph between the provided If `fgraph` is ``None``, the returned `FunctionGraph` is a clone of the
inputs and outputs. graph between the provided inputs and outputs.
""" """
orig_inputs = [spec.variable for spec in input_specs]
# Extract the updates and the mapping between update outputs and the # Extract the updates and the mapping between update outputs and the
# updated inputs # updated inputs
updates = [] updates = []
update_mapping = {} update_mapping = {}
out_idx = len(output_specs) out_idx = len(output_specs)
for inp_idx in range(len(input_specs)): for idx, input_spec in enumerate(input_specs):
if input_specs[inp_idx].update: if input_spec.update:
updates.append(input_specs[inp_idx].update) updates.append(input_spec.update)
update_mapping[out_idx] = inp_idx update_mapping[out_idx] = idx
out_idx += 1 out_idx += 1
orig_outputs = [spec.variable for spec in output_specs] + updates if fgraph:
if fgraph.update_mapping is None:
fgraph.update_mapping = update_mapping
for update in updates:
fgraph.add_output(update, reason="std_fgraph")
else:
input_vars = []
# If one of the inputs is non-atomic (i.e. has a non-`None` `Variable.owner`),
# then we need to create/clone the graph starting at these inputs.
# The result will be atomic versions of the given inputs connected to
# the same outputs.
# Otherwise, when all the inputs are already atomic, there's no need to
# clone the graph.
clone = force_clone
for spec in input_specs:
input_vars.append(spec.variable)
clone |= spec.variable.owner is not None
fgraph = FunctionGraph(
input_vars,
[spec.variable for spec in output_specs] + updates,
update_mapping=update_mapping,
clone=clone,
)
fgraph = FunctionGraph(orig_inputs, orig_outputs, update_mapping=update_mapping) additional_outputs = list(map(SymbolicOutput, updates))
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if node.op.destroy_map: if node.op.destroy_map:
...@@ -210,7 +234,8 @@ def std_fgraph( ...@@ -210,7 +234,8 @@ def std_fgraph(
# If named nodes are replaced, keep the name # If named nodes are replaced, keep the name
for feature in features: for feature in features:
fgraph.attach_feature(feature()) fgraph.attach_feature(feature())
return fgraph, list(map(SymbolicOutput, updates))
return fgraph, additional_outputs
class AliasedMemoryError(Exception): class AliasedMemoryError(Exception):
...@@ -646,6 +671,7 @@ class Function: ...@@ -646,6 +671,7 @@ class Function:
[memo[o] for o in out_vars], [memo[o] for o in out_vars],
clone=False, clone=False,
) )
fg_cpy.update_mapping = maker.fgraph.update_mapping
# Re initialize Outs and swap update and variable in Ins # Re initialize Outs and swap update and variable in Ins
# By doing this, we can pass FunctionMaker.check_unused_inputs() # By doing this, we can pass FunctionMaker.check_unused_inputs()
...@@ -747,6 +773,7 @@ class Function: ...@@ -747,6 +773,7 @@ class Function:
# can contain inplace. DebugMode check # can contain inplace. DebugMode check
# that. # that.
accept_inplace=True, accept_inplace=True,
no_fgraph_prep=True,
).create(input_storage, storage_map=new_storage_map) ).create(input_storage, storage_map=new_storage_map)
for in_ori, in_cpy, ori, cpy in zip( for in_ori, in_cpy, ori, cpy in zip(
...@@ -1378,6 +1405,63 @@ class FunctionMaker: ...@@ -1378,6 +1405,63 @@ class FunctionMaker:
"Valid values are 'raise', 'warn', and 'ignore'." "Valid values are 'raise', 'warn', and 'ignore'."
) )
@staticmethod
def prepare_fgraph(
inputs, outputs, additional_outputs, fgraph, optimizer, linker, profile
):
try:
start_optimizer = time.time()
optimizer_profile = None
opt_time = None
with config.change_flags(
compute_test_value=config.compute_test_value_opt,
traceback__limit=config.traceback__compile_limit,
):
optimizer_profile = optimizer(fgraph)
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
_logger.debug(f"Optimizing took {opt_time:f} seconds")
# Add deep copy to respect the memory interface
insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
finally:
# If the optimizer got interrupted
if opt_time is None:
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
aesara.compile.profiling.total_graph_opt_time += opt_time
if profile:
if optimizer_profile is None and hasattr(optimizer, "pre_profile"):
optimizer_profile = optimizer.pre_profile
profile.optimizer_time += opt_time
if config.profile_optimizer:
profile.optimizer_profile = (optimizer, optimizer_profile)
elif config.profile_optimizer and profile is not False:
# If False, it means the profiling for that function was
# explicitly disabled
warnings.warn(
(
"config.profile_optimizer requires config.profile to "
" be set to True as well"
),
stacklevel=3,
)
if not hasattr(linker, "accept"):
raise ValueError(
"'linker' parameter of FunctionMaker should be "
f"a Linker with an accept method or one of {list(aesara.compile.mode.predefined_linkers.keys())}"
)
def __init__( def __init__(
self, self,
inputs, inputs,
...@@ -1390,6 +1474,7 @@ class FunctionMaker: ...@@ -1390,6 +1474,7 @@ class FunctionMaker:
fgraph=None, fgraph=None,
output_keys=None, output_keys=None,
name=None, name=None,
no_fgraph_prep=False,
): ):
# Save the provided mode, not the instantiated mode. # Save the provided mode, not the instantiated mode.
# The instantiated mode don't pickle and if we unpickle an Aesara # The instantiated mode don't pickle and if we unpickle an Aesara
...@@ -1433,84 +1518,32 @@ class FunctionMaker: ...@@ -1433,84 +1518,32 @@ class FunctionMaker:
indices = [[input, None, [input]] for input in inputs] indices = [[input, None, [input]] for input in inputs]
if fgraph is None: fgraph, additional_outputs = std_fgraph(
need_opt = True inputs, outputs, accept_inplace, fgraph=fgraph
# make the fgraph (copies the graph, creates NEW INPUT AND )
# OUTPUT VARIABLES)
fgraph, additional_outputs = std_fgraph(inputs, outputs, accept_inplace) if fgraph.profile is None:
fgraph.profile = profile fgraph.profile = profile
else:
# fgraph is already an optimized one
need_opt = False
updates = [spec.update for spec in inputs if spec.update]
additional_outputs = list(map(SymbolicOutput, updates))
self.fgraph = fgraph self.fgraph = fgraph
optimizer, linker = mode.optimizer, copy.copy(mode.linker) optimizer, linker = mode.optimizer, copy.copy(mode.linker)
if need_opt:
# Why we add stack on node when it get done in output var?
try:
start_optimizer = time.time()
# In case there is an error during optimization.
optimizer_profile = None
opt_time = None
with config.change_flags(
compute_test_value=config.compute_test_value_opt,
traceback__limit=config.traceback__compile_limit,
):
optimizer_profile = optimizer(fgraph)
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
_logger.debug(f"Optimizing took {opt_time:f} seconds")
# Add deep copy to respect the memory interface
insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
finally:
# If the optimizer got interrupted
if opt_time is None:
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
aesara.compile.profiling.total_graph_opt_time += opt_time
if profile:
if optimizer_profile is None and hasattr(optimizer, "pre_profile"):
optimizer_profile = optimizer.pre_profile
profile.optimizer_time += opt_time if not no_fgraph_prep:
self.prepare_fgraph(
if config.profile_optimizer: inputs, outputs, additional_outputs, fgraph, optimizer, linker, profile
profile.optimizer_profile = (optimizer, optimizer_profile)
# IF False, if mean the profile for that function was
# explicitly disabled
elif config.profile_optimizer and profile is not False:
warnings.warn(
(
"config.profile_optimizer requires config.profile to "
" be set to True as well"
),
stacklevel=3,
)
if not hasattr(linker, "accept"):
raise ValueError(
"'linker' parameter of FunctionMaker should be "
f"a Linker with an accept method or one of {list(aesara.compile.mode.predefined_linkers.keys())}"
) )
# the 'no_borrow' outputs are the ones for which that we can't # the 'no_borrow' outputs are the ones for which that we can't
# return the internal storage pointer. # return the internal storage pointer.
assert len(fgraph.outputs) == len(outputs + additional_outputs) assert len(fgraph.outputs) == len(outputs + additional_outputs)
no_borrow = [ no_borrow = [
output output
for output, spec in zip(fgraph.outputs, outputs + additional_outputs) for output, spec in zip(fgraph.outputs, outputs + additional_outputs)
if not spec.borrow if not spec.borrow
] ]
if no_borrow: if no_borrow:
self.linker = linker.accept( self.linker = linker.accept(
fgraph, fgraph,
...@@ -1670,6 +1703,7 @@ def orig_function( ...@@ -1670,6 +1703,7 @@ def orig_function(
profile=None, profile=None,
on_unused_input=None, on_unused_input=None,
output_keys=None, output_keys=None,
fgraph: Optional[FunctionGraph] = None,
) -> Function: ) -> Function:
""" """
Return a Function that will calculate the outputs from the inputs. Return a Function that will calculate the outputs from the inputs.
...@@ -1731,6 +1765,7 @@ def orig_function( ...@@ -1731,6 +1765,7 @@ def orig_function(
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
output_keys=output_keys, output_keys=output_keys,
name=name, name=name,
fgraph=fgraph,
) )
with config.change_flags(compute_test_value="off"): with config.change_flags(compute_test_value="off"):
fn = m.create(defaults) fn = m.create(defaults)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论