提交 e51be01c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add FunctionGraph input support to the Function compilation pipeline

These changes allow one to pass an `fgraph` argument to all key functions in the `Function` compilation pipeline. The given `FunctionGraph` will be compiled directly, without cloning. Unlike the previous behavior of `FunctionMaker.__init__`, a user-provided `fgraph` *will* now be optimized according to the given `mode`, unless the keyword argument `no_fgraph_prep` is `True`.
上级 90e794c0
......@@ -433,7 +433,9 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
"""
equivalence_tracker = _VariableEquivalenceTracker()
fgraph, updates = std_fgraph(input_specs, output_specs, accept_inplace)
fgraph, updates = std_fgraph(
input_specs, output_specs, accept_inplace, force_clone=True
)
fgraph.attach_feature(equivalence_tracker)
return fgraph, updates, equivalence_tracker
......@@ -2006,6 +2008,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
fgraph=None, # If present the optimized graph. we ignore it.
output_keys=None,
name=None,
no_fgraph_prep=False,
):
self.mode = mode
self.profile = profile
......
......@@ -4,6 +4,8 @@ Provide a simple user friendly API.
"""
import logging
from copy import copy
from typing import Optional
from aesara.compile.function.types import Function, UnusedInputError, orig_function
from aesara.compile.io import In, Out
......@@ -11,6 +13,7 @@ from aesara.compile.profiling import ProfileStats
from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.fg import FunctionGraph
_logger = logging.getLogger("aesara.compile.function.pfunc")
......@@ -279,6 +282,7 @@ def pfunc(
profile=None,
on_unused_input=None,
output_keys=None,
fgraph: Optional[FunctionGraph] = None,
) -> Function:
"""
Function-constructor for graphs with shared variables.
......@@ -324,6 +328,9 @@ def pfunc(
be available via self.profile.
on_unused_input : {'raise', 'warn','ignore', None}
What to do if a variable in the 'inputs' list is not used in the graph.
fgraph
An existing `FunctionGraph` from which to construct the returned
`Function`. When this is non-``None``, nothing is cloned.
Returns
-------
......@@ -358,6 +365,7 @@ def pfunc(
no_default_updates,
rebuild_strict,
allow_input_downcast,
fgraph=fgraph,
)
return orig_function(
......@@ -369,6 +377,7 @@ def pfunc(
profile=profile,
on_unused_input=on_unused_input,
output_keys=output_keys,
fgraph=fgraph,
)
......@@ -381,6 +390,7 @@ def construct_pfunc_ins_and_outs(
no_default_updates=False,
rebuild_strict=True,
allow_input_downcast=None,
fgraph: Optional[FunctionGraph] = None,
):
"""Construct inputs and outputs for `pfunc`.
......@@ -398,6 +408,9 @@ def construct_pfunc_ins_and_outs(
Then it clones the outputs and the update expressions. This
rebuilds a computation graph from the inputs and the `givens`.
When `fgraph` is non-``None``, nothing is cloned and the given `fgraph` is
simply prepared for direct use.
"""
if updates is None:
updates = []
......@@ -459,68 +472,93 @@ def construct_pfunc_ins_and_outs(
"aesara.clone_replace(f(x), replace={x: g(x)}))`."
)
# Extend the outputs with the updates on input variables so they are also
# cloned
additional_outputs = [i.update for i in inputs if i.update]
if outputs is None:
out_list = []
else:
if isinstance(outputs, (list, tuple)):
out_list = list(outputs)
else:
out_list = [outputs]
extended_outputs = out_list + additional_outputs
output_vars = rebuild_collect_shared(
extended_outputs,
in_variables,
replace=givens,
updates=updates,
rebuild_strict=rebuild_strict,
copy_inputs_over=True,
no_default_updates=no_default_updates,
)
# extracting the arguments
input_variables, cloned_extended_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff
if not fgraph:
# Recover only the clones of the original outputs
if outputs is None:
cloned_outputs = []
else:
if isinstance(outputs, (list, tuple)):
cloned_outputs = cloned_extended_outputs[: len(outputs)]
# Extend the outputs with the updates on input variables so they are
# also cloned
additional_outputs = [i.update for i in inputs if i.update]
if outputs is None:
out_list = []
else:
cloned_outputs = cloned_extended_outputs[0]
for i, iv in zip(inputs, input_variables):
i.variable = iv
# If needed, replace the input's update by its cloned equivalent
if i.update:
i.update = clone_d[i.update]
for sv in shared_inputs:
# pass value of None
# value will be stored in the resulting functions' defaults
# list but since the value of shared variables never needs to
# be refed, it is not needed
if sv in update_d:
si = In(
variable=sv,
value=sv.container,
mutable=True,
borrow=True,
update=update_d[sv],
shared=True,
)
if isinstance(outputs, (list, tuple)):
out_list = list(outputs)
else:
out_list = [outputs]
extended_outputs = out_list + additional_outputs
output_vars = rebuild_collect_shared(
extended_outputs,
in_variables,
replace=givens,
updates=updates,
rebuild_strict=rebuild_strict,
copy_inputs_over=True,
no_default_updates=no_default_updates,
)
input_variables, cloned_extended_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff
# Recover only the clones of the original outputs
if outputs is None:
new_outputs = []
else:
si = In(
variable=sv, value=sv.container, mutable=False, borrow=True, shared=True
)
inputs.append(si)
if isinstance(outputs, (list, tuple)):
new_outputs = cloned_extended_outputs[: len(outputs)]
else:
new_outputs = cloned_extended_outputs[0]
new_inputs = []
for i, iv in zip(inputs, input_variables):
new_i = copy(i)
new_i.variable = iv
# If needed, replace the input's update by its cloned equivalent
if i.update:
new_i.update = clone_d[i.update]
new_inputs.append(new_i)
for sv in shared_inputs:
if sv in update_d:
si = In(
variable=sv,
value=sv.container,
mutable=True,
borrow=True,
update=update_d[sv],
shared=True,
)
else:
si = In(
variable=sv,
value=sv.container,
mutable=False,
borrow=True,
shared=True,
)
new_inputs.append(si)
else:
assert len(fgraph.inputs) == len(inputs)
assert len(fgraph.outputs) == len(outputs)
for fg_inp, inp in zip(fgraph.inputs, inputs):
if fg_inp != getattr(inp, "variable", inp):
raise ValueError(
f"`fgraph`'s input does not match the provided input: {fg_inp}, {inp}"
)
for fg_out, out in zip(fgraph.outputs, outputs):
if fg_out != getattr(out, "variable", out):
raise ValueError(
f"`fgraph`'s output does not match the provided output: {fg_out}, {out}"
)
new_inputs = inputs
new_outputs = outputs
return inputs, cloned_outputs
return new_inputs, new_outputs
def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
......@@ -9,7 +9,7 @@ import logging
import time
import warnings
from itertools import chain
from typing import List, Tuple, Type
from typing import List, Optional, Tuple, Type
import numpy as np
......@@ -152,7 +152,9 @@ def std_fgraph(
input_specs: List[SymbolicInput],
output_specs: List[SymbolicOutput],
accept_inplace: bool = False,
fgraph: Optional[FunctionGraph] = None,
features: List[Type[Feature]] = [PreserveVariableAttributes],
force_clone=False,
) -> Tuple[FunctionGraph, List[SymbolicOutput]]:
"""Make or set up `FunctionGraph` corresponding to the input specs and the output specs.
......@@ -166,26 +168,48 @@ def std_fgraph(
`accept_inplace` is ``True``, a `DestroyHandler` will be added to the
`FunctionGraph` if there are any in-place operations.
The returned FunctionGraph is a clone of the graph between the provided
inputs and outputs.
If `fgraph` is ``None``, the returned `FunctionGraph` is a clone of the
graph between the provided inputs and outputs.
"""
orig_inputs = [spec.variable for spec in input_specs]
# Extract the updates and the mapping between update outputs and the
# updated inputs
updates = []
update_mapping = {}
out_idx = len(output_specs)
for inp_idx in range(len(input_specs)):
if input_specs[inp_idx].update:
updates.append(input_specs[inp_idx].update)
update_mapping[out_idx] = inp_idx
for idx, input_spec in enumerate(input_specs):
if input_spec.update:
updates.append(input_spec.update)
update_mapping[out_idx] = idx
out_idx += 1
orig_outputs = [spec.variable for spec in output_specs] + updates
if fgraph:
if fgraph.update_mapping is None:
fgraph.update_mapping = update_mapping
for update in updates:
fgraph.add_output(update, reason="std_fgraph")
else:
input_vars = []
# If one of the inputs is non-atomic (i.e. has a non-`None` `Variable.owner`),
# then we need to create/clone the graph starting at these inputs.
# The result will be atomic versions of the given inputs connected to
# the same outputs.
# Otherwise, when all the inputs are already atomic, there's no need to
# clone the graph.
clone = force_clone
for spec in input_specs:
input_vars.append(spec.variable)
clone |= spec.variable.owner is not None
fgraph = FunctionGraph(
input_vars,
[spec.variable for spec in output_specs] + updates,
update_mapping=update_mapping,
clone=clone,
)
fgraph = FunctionGraph(orig_inputs, orig_outputs, update_mapping=update_mapping)
additional_outputs = list(map(SymbolicOutput, updates))
for node in fgraph.apply_nodes:
if node.op.destroy_map:
......@@ -210,7 +234,8 @@ def std_fgraph(
# If named nodes are replaced, keep the name
for feature in features:
fgraph.attach_feature(feature())
return fgraph, list(map(SymbolicOutput, updates))
return fgraph, additional_outputs
class AliasedMemoryError(Exception):
......@@ -646,6 +671,7 @@ class Function:
[memo[o] for o in out_vars],
clone=False,
)
fg_cpy.update_mapping = maker.fgraph.update_mapping
# Re initialize Outs and swap update and variable in Ins
# By doing this, we can pass FunctionMaker.check_unused_inputs()
......@@ -747,6 +773,7 @@ class Function:
# can contain inplace. DebugMode check
# that.
accept_inplace=True,
no_fgraph_prep=True,
).create(input_storage, storage_map=new_storage_map)
for in_ori, in_cpy, ori, cpy in zip(
......@@ -1378,6 +1405,63 @@ class FunctionMaker:
"Valid values are 'raise', 'warn', and 'ignore'."
)
@staticmethod
def prepare_fgraph(
    inputs, outputs, additional_outputs, fgraph, optimizer, linker, profile
):
    """Optimize `fgraph` in place and prepare it for compilation.

    Runs `optimizer` on `fgraph`, records optimization time in the global
    and (optionally) per-function profiling counters, inserts deep copies
    of the outputs so the memory interface is respected, and validates
    that `linker` exposes an ``accept`` method.

    Parameters
    ----------
    inputs, outputs, additional_outputs
        The `SymbolicInput`/`SymbolicOutput` specs corresponding to
        `fgraph`; forwarded to `insert_deepcopy`.
    fgraph
        The `FunctionGraph` to optimize (mutated in place).
    optimizer
        Callable applied to `fgraph`; may return a profile object.
    linker
        Must provide an ``accept`` method, otherwise a `ValueError` is
        raised.
    profile
        A profiling object to accumulate timings into, or a falsy value
        to skip per-function profiling (``False`` means explicitly
        disabled).

    Raises
    ------
    ValueError
        If `linker` has no ``accept`` method.
    """
    try:
        start_optimizer = time.time()

        # Set defaults up front so the `finally` clause can distinguish
        # an interrupted optimization (opt_time still None) from a
        # completed one.
        optimizer_profile = None
        opt_time = None

        # Test values/tracebacks are configured differently during
        # optimization than during user graph construction.
        with config.change_flags(
            compute_test_value=config.compute_test_value_opt,
            traceback__limit=config.traceback__compile_limit,
        ):
            optimizer_profile = optimizer(fgraph)

            end_optimizer = time.time()
            opt_time = end_optimizer - start_optimizer
            _logger.debug(f"Optimizing took {opt_time:f} seconds")

            # Add deep copy to respect the memory interface
            insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
    finally:
        # If the optimizer got interrupted
        if opt_time is None:
            end_optimizer = time.time()
            opt_time = end_optimizer - start_optimizer

        # Total optimization time is accumulated globally even when the
        # optimizer was interrupted.
        aesara.compile.profiling.total_graph_opt_time += opt_time

        if profile:
            if optimizer_profile is None and hasattr(optimizer, "pre_profile"):
                optimizer_profile = optimizer.pre_profile

            profile.optimizer_time += opt_time

            if config.profile_optimizer:
                profile.optimizer_profile = (optimizer, optimizer_profile)
        elif config.profile_optimizer and profile is not False:
            # If False, it means the profiling for that function was
            # explicitly disabled
            warnings.warn(
                (
                    "config.profile_optimizer requires config.profile to "
                    " be set to True as well"
                ),
                stacklevel=3,
            )

    if not hasattr(linker, "accept"):
        raise ValueError(
            "'linker' parameter of FunctionMaker should be "
            f"a Linker with an accept method or one of {list(aesara.compile.mode.predefined_linkers.keys())}"
        )
def __init__(
self,
inputs,
......@@ -1390,6 +1474,7 @@ class FunctionMaker:
fgraph=None,
output_keys=None,
name=None,
no_fgraph_prep=False,
):
# Save the provided mode, not the instantiated mode.
# The instantiated mode don't pickle and if we unpickle an Aesara
......@@ -1433,84 +1518,32 @@ class FunctionMaker:
indices = [[input, None, [input]] for input in inputs]
if fgraph is None:
need_opt = True
# make the fgraph (copies the graph, creates NEW INPUT AND
# OUTPUT VARIABLES)
fgraph, additional_outputs = std_fgraph(inputs, outputs, accept_inplace)
fgraph, additional_outputs = std_fgraph(
inputs, outputs, accept_inplace, fgraph=fgraph
)
if fgraph.profile is None:
fgraph.profile = profile
else:
# fgraph is already an optimized one
need_opt = False
updates = [spec.update for spec in inputs if spec.update]
additional_outputs = list(map(SymbolicOutput, updates))
self.fgraph = fgraph
optimizer, linker = mode.optimizer, copy.copy(mode.linker)
if need_opt:
# Why we add stack on node when it get done in output var?
try:
start_optimizer = time.time()
# In case there is an error during optimization.
optimizer_profile = None
opt_time = None
with config.change_flags(
compute_test_value=config.compute_test_value_opt,
traceback__limit=config.traceback__compile_limit,
):
optimizer_profile = optimizer(fgraph)
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
_logger.debug(f"Optimizing took {opt_time:f} seconds")
# Add deep copy to respect the memory interface
insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
finally:
# If the optimizer got interrupted
if opt_time is None:
end_optimizer = time.time()
opt_time = end_optimizer - start_optimizer
aesara.compile.profiling.total_graph_opt_time += opt_time
if profile:
if optimizer_profile is None and hasattr(optimizer, "pre_profile"):
optimizer_profile = optimizer.pre_profile
profile.optimizer_time += opt_time
if config.profile_optimizer:
profile.optimizer_profile = (optimizer, optimizer_profile)
# IF False, if mean the profile for that function was
# explicitly disabled
elif config.profile_optimizer and profile is not False:
warnings.warn(
(
"config.profile_optimizer requires config.profile to "
" be set to True as well"
),
stacklevel=3,
)
if not hasattr(linker, "accept"):
raise ValueError(
"'linker' parameter of FunctionMaker should be "
f"a Linker with an accept method or one of {list(aesara.compile.mode.predefined_linkers.keys())}"
if not no_fgraph_prep:
self.prepare_fgraph(
inputs, outputs, additional_outputs, fgraph, optimizer, linker, profile
)
# the 'no_borrow' outputs are the ones for which that we can't
# return the internal storage pointer.
assert len(fgraph.outputs) == len(outputs + additional_outputs)
no_borrow = [
output
for output, spec in zip(fgraph.outputs, outputs + additional_outputs)
if not spec.borrow
]
if no_borrow:
self.linker = linker.accept(
fgraph,
......@@ -1670,6 +1703,7 @@ def orig_function(
profile=None,
on_unused_input=None,
output_keys=None,
fgraph: Optional[FunctionGraph] = None,
) -> Function:
"""
Return a Function that will calculate the outputs from the inputs.
......@@ -1731,6 +1765,7 @@ def orig_function(
on_unused_input=on_unused_input,
output_keys=output_keys,
name=name,
fgraph=fgraph,
)
with config.change_flags(compute_test_value="off"):
fn = m.create(defaults)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论