提交 e51be01c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add FunctionGraph input support to the Function compilation pipeline

These changes allow one to pass an `fgraph` argument to all key functions in the `Function` compilation pipeline. The given `FunctionGraph` will be directly compiled without cloning. Unlike the previous `FunctionMaker.__init__`, this one's `fgraph` argument *will* be optimized according to the given `mode` unless the keyword argument `no_fgraph_prep` is `True`.
上级 90e794c0
...@@ -433,7 +433,9 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -433,7 +433,9 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
""" """
equivalence_tracker = _VariableEquivalenceTracker() equivalence_tracker = _VariableEquivalenceTracker()
fgraph, updates = std_fgraph(input_specs, output_specs, accept_inplace) fgraph, updates = std_fgraph(
input_specs, output_specs, accept_inplace, force_clone=True
)
fgraph.attach_feature(equivalence_tracker) fgraph.attach_feature(equivalence_tracker)
return fgraph, updates, equivalence_tracker return fgraph, updates, equivalence_tracker
...@@ -2006,6 +2008,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2006,6 +2008,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
fgraph=None, # If present the optimized graph. we ignore it. fgraph=None, # If present the optimized graph. we ignore it.
output_keys=None, output_keys=None,
name=None, name=None,
no_fgraph_prep=False,
): ):
self.mode = mode self.mode = mode
self.profile = profile self.profile = profile
......
...@@ -4,6 +4,8 @@ Provide a simple user friendly API. ...@@ -4,6 +4,8 @@ Provide a simple user friendly API.
""" """
import logging import logging
from copy import copy
from typing import Optional
from aesara.compile.function.types import Function, UnusedInputError, orig_function from aesara.compile.function.types import Function, UnusedInputError, orig_function
from aesara.compile.io import In, Out from aesara.compile.io import In, Out
...@@ -11,6 +13,7 @@ from aesara.compile.profiling import ProfileStats ...@@ -11,6 +13,7 @@ from aesara.compile.profiling import ProfileStats
from aesara.compile.sharedvalue import SharedVariable, shared from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.graph.fg import FunctionGraph
_logger = logging.getLogger("aesara.compile.function.pfunc") _logger = logging.getLogger("aesara.compile.function.pfunc")
...@@ -279,6 +282,7 @@ def pfunc( ...@@ -279,6 +282,7 @@ def pfunc(
profile=None, profile=None,
on_unused_input=None, on_unused_input=None,
output_keys=None, output_keys=None,
fgraph: Optional[FunctionGraph] = None,
) -> Function: ) -> Function:
""" """
Function-constructor for graphs with shared variables. Function-constructor for graphs with shared variables.
...@@ -324,6 +328,9 @@ def pfunc( ...@@ -324,6 +328,9 @@ def pfunc(
be available via self.profile. be available via self.profile.
on_unused_input : {'raise', 'warn','ignore', None} on_unused_input : {'raise', 'warn','ignore', None}
What to do if a variable in the 'inputs' list is not used in the graph. What to do if a variable in the 'inputs' list is not used in the graph.
fgraph
An existing `FunctionGraph` from which to construct the returned
`Function`. When this is non-``None``, nothing is cloned.
Returns Returns
------- -------
...@@ -358,6 +365,7 @@ def pfunc( ...@@ -358,6 +365,7 @@ def pfunc(
no_default_updates, no_default_updates,
rebuild_strict, rebuild_strict,
allow_input_downcast, allow_input_downcast,
fgraph=fgraph,
) )
return orig_function( return orig_function(
...@@ -369,6 +377,7 @@ def pfunc( ...@@ -369,6 +377,7 @@ def pfunc(
profile=profile, profile=profile,
on_unused_input=on_unused_input, on_unused_input=on_unused_input,
output_keys=output_keys, output_keys=output_keys,
fgraph=fgraph,
) )
...@@ -381,6 +390,7 @@ def construct_pfunc_ins_and_outs( ...@@ -381,6 +390,7 @@ def construct_pfunc_ins_and_outs(
no_default_updates=False, no_default_updates=False,
rebuild_strict=True, rebuild_strict=True,
allow_input_downcast=None, allow_input_downcast=None,
fgraph: Optional[FunctionGraph] = None,
): ):
"""Construct inputs and outputs for `pfunc`. """Construct inputs and outputs for `pfunc`.
...@@ -398,6 +408,9 @@ def construct_pfunc_ins_and_outs( ...@@ -398,6 +408,9 @@ def construct_pfunc_ins_and_outs(
Then it clones the outputs and the update expressions. This Then it clones the outputs and the update expressions. This
rebuilds a computation graph from the inputs and the `givens`. rebuilds a computation graph from the inputs and the `givens`.
When `fgraph` is non-``None``, nothing is cloned and the given `fgraph` is
simply prepared for direct use.
""" """
if updates is None: if updates is None:
updates = [] updates = []
...@@ -459,68 +472,93 @@ def construct_pfunc_ins_and_outs( ...@@ -459,68 +472,93 @@ def construct_pfunc_ins_and_outs(
"aesara.clone_replace(f(x), replace={x: g(x)}))`." "aesara.clone_replace(f(x), replace={x: g(x)}))`."
) )
# Extend the outputs with the updates on input variables so they are also if not fgraph:
# cloned
additional_outputs = [i.update for i in inputs if i.update]
if outputs is None:
out_list = []
else:
if isinstance(outputs, (list, tuple)):
out_list = list(outputs)
else:
out_list = [outputs]
extended_outputs = out_list + additional_outputs
output_vars = rebuild_collect_shared(
extended_outputs,
in_variables,
replace=givens,
updates=updates,
rebuild_strict=rebuild_strict,
copy_inputs_over=True,
no_default_updates=no_default_updates,
)
# extracting the arguments
input_variables, cloned_extended_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff
# Recover only the clones of the original outputs # Extend the outputs with the updates on input variables so they are
if outputs is None: # also cloned
cloned_outputs = [] additional_outputs = [i.update for i in inputs if i.update]
else: if outputs is None:
if isinstance(outputs, (list, tuple)): out_list = []
cloned_outputs = cloned_extended_outputs[: len(outputs)]
else: else:
cloned_outputs = cloned_extended_outputs[0] if isinstance(outputs, (list, tuple)):
out_list = list(outputs)
for i, iv in zip(inputs, input_variables): else:
i.variable = iv out_list = [outputs]
extended_outputs = out_list + additional_outputs
# If needed, replace the input's update by its cloned equivalent
if i.update: output_vars = rebuild_collect_shared(
i.update = clone_d[i.update] extended_outputs,
in_variables,
for sv in shared_inputs: replace=givens,
# pass value of None updates=updates,
# value will be stored in the resulting functions' defaults rebuild_strict=rebuild_strict,
# list but since the value of shared variables never needs to copy_inputs_over=True,
# be refed, it is not needed no_default_updates=no_default_updates,
if sv in update_d: )
si = In( input_variables, cloned_extended_outputs, other_stuff = output_vars
variable=sv, clone_d, update_d, update_expr, shared_inputs = other_stuff
value=sv.container,
mutable=True, # Recover only the clones of the original outputs
borrow=True, if outputs is None:
update=update_d[sv], new_outputs = []
shared=True,
)
else: else:
si = In( if isinstance(outputs, (list, tuple)):
variable=sv, value=sv.container, mutable=False, borrow=True, shared=True new_outputs = cloned_extended_outputs[: len(outputs)]
) else:
inputs.append(si) new_outputs = cloned_extended_outputs[0]
new_inputs = []
for i, iv in zip(inputs, input_variables):
new_i = copy(i)
new_i.variable = iv
# If needed, replace the input's update by its cloned equivalent
if i.update:
new_i.update = clone_d[i.update]
new_inputs.append(new_i)
for sv in shared_inputs:
if sv in update_d:
si = In(
variable=sv,
value=sv.container,
mutable=True,
borrow=True,
update=update_d[sv],
shared=True,
)
else:
si = In(
variable=sv,
value=sv.container,
mutable=False,
borrow=True,
shared=True,
)
new_inputs.append(si)
else:
assert len(fgraph.inputs) == len(inputs)
assert len(fgraph.outputs) == len(outputs)
for fg_inp, inp in zip(fgraph.inputs, inputs):
if fg_inp != getattr(inp, "variable", inp):
raise ValueError(
f"`fgraph`'s input does not match the provided input: {fg_inp}, {inp}"
)
for fg_out, out in zip(fgraph.outputs, outputs):
if fg_out != getattr(out, "variable", out):
raise ValueError(
f"`fgraph`'s output does not match the provided output: {fg_out}, {out}"
)
new_inputs = inputs
new_outputs = outputs
return inputs, cloned_outputs return new_inputs, new_outputs
def _pfunc_param_to_in(param, strict=False, allow_downcast=None): def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论