提交 e51be01c authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add FunctionGraph input support to the Function compilation pipeline

These changes allow one to pass an `fgraph` argument to all key functions in the `Function` compilation pipeline. The given `FunctionGraph` will be directly compiled without cloning. Unlike the previous `FunctionMaker.__init__`, this one's `fgraph` argument *will* be optimized according to the given `mode` unless the keyword argument `no_fgraph_prep` is `True`.
上级 90e794c0
......@@ -433,7 +433,9 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
"""
equivalence_tracker = _VariableEquivalenceTracker()
fgraph, updates = std_fgraph(input_specs, output_specs, accept_inplace)
fgraph, updates = std_fgraph(
input_specs, output_specs, accept_inplace, force_clone=True
)
fgraph.attach_feature(equivalence_tracker)
return fgraph, updates, equivalence_tracker
......@@ -2006,6 +2008,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
fgraph=None, # If present the optimized graph. we ignore it.
output_keys=None,
name=None,
no_fgraph_prep=False,
):
self.mode = mode
self.profile = profile
......
......@@ -4,6 +4,8 @@ Provide a simple user friendly API.
"""
import logging
from copy import copy
from typing import Optional
from aesara.compile.function.types import Function, UnusedInputError, orig_function
from aesara.compile.io import In, Out
......@@ -11,6 +13,7 @@ from aesara.compile.profiling import ProfileStats
from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.fg import FunctionGraph
_logger = logging.getLogger("aesara.compile.function.pfunc")
......@@ -279,6 +282,7 @@ def pfunc(
profile=None,
on_unused_input=None,
output_keys=None,
fgraph: Optional[FunctionGraph] = None,
) -> Function:
"""
Function-constructor for graphs with shared variables.
......@@ -324,6 +328,9 @@ def pfunc(
be available via self.profile.
on_unused_input : {'raise', 'warn','ignore', None}
What to do if a variable in the 'inputs' list is not used in the graph.
fgraph
An existing `FunctionGraph` from which to construct the returned
`Function`. When this is non-``None``, nothing is cloned.
Returns
-------
......@@ -358,6 +365,7 @@ def pfunc(
no_default_updates,
rebuild_strict,
allow_input_downcast,
fgraph=fgraph,
)
return orig_function(
......@@ -369,6 +377,7 @@ def pfunc(
profile=profile,
on_unused_input=on_unused_input,
output_keys=output_keys,
fgraph=fgraph,
)
......@@ -381,6 +390,7 @@ def construct_pfunc_ins_and_outs(
no_default_updates=False,
rebuild_strict=True,
allow_input_downcast=None,
fgraph: Optional[FunctionGraph] = None,
):
"""Construct inputs and outputs for `pfunc`.
......@@ -398,6 +408,9 @@ def construct_pfunc_ins_and_outs(
Then it clones the outputs and the update expressions. This
rebuilds a computation graph from the inputs and the `givens`.
When `fgraph` is non-``None``, nothing is cloned and the given `fgraph` is
simply prepared for direct use.
"""
if updates is None:
updates = []
......@@ -459,68 +472,93 @@ def construct_pfunc_ins_and_outs(
"aesara.clone_replace(f(x), replace={x: g(x)}))`."
)
# Extend the outputs with the updates on input variables so they are also
# cloned
additional_outputs = [i.update for i in inputs if i.update]
if outputs is None:
out_list = []
else:
if isinstance(outputs, (list, tuple)):
out_list = list(outputs)
else:
out_list = [outputs]
extended_outputs = out_list + additional_outputs
output_vars = rebuild_collect_shared(
extended_outputs,
in_variables,
replace=givens,
updates=updates,
rebuild_strict=rebuild_strict,
copy_inputs_over=True,
no_default_updates=no_default_updates,
)
# extracting the arguments
input_variables, cloned_extended_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff
if not fgraph:
# Recover only the clones of the original outputs
if outputs is None:
cloned_outputs = []
else:
if isinstance(outputs, (list, tuple)):
cloned_outputs = cloned_extended_outputs[: len(outputs)]
# Extend the outputs with the updates on input variables so they are
# also cloned
additional_outputs = [i.update for i in inputs if i.update]
if outputs is None:
out_list = []
else:
cloned_outputs = cloned_extended_outputs[0]
for i, iv in zip(inputs, input_variables):
i.variable = iv
# If needed, replace the input's update by its cloned equivalent
if i.update:
i.update = clone_d[i.update]
for sv in shared_inputs:
# pass value of None
# value will be stored in the resulting functions' defaults
# list but since the value of shared variables never needs to
# be refed, it is not needed
if sv in update_d:
si = In(
variable=sv,
value=sv.container,
mutable=True,
borrow=True,
update=update_d[sv],
shared=True,
)
if isinstance(outputs, (list, tuple)):
out_list = list(outputs)
else:
out_list = [outputs]
extended_outputs = out_list + additional_outputs
output_vars = rebuild_collect_shared(
extended_outputs,
in_variables,
replace=givens,
updates=updates,
rebuild_strict=rebuild_strict,
copy_inputs_over=True,
no_default_updates=no_default_updates,
)
input_variables, cloned_extended_outputs, other_stuff = output_vars
clone_d, update_d, update_expr, shared_inputs = other_stuff
# Recover only the clones of the original outputs
if outputs is None:
new_outputs = []
else:
si = In(
variable=sv, value=sv.container, mutable=False, borrow=True, shared=True
)
inputs.append(si)
if isinstance(outputs, (list, tuple)):
new_outputs = cloned_extended_outputs[: len(outputs)]
else:
new_outputs = cloned_extended_outputs[0]
new_inputs = []
for i, iv in zip(inputs, input_variables):
new_i = copy(i)
new_i.variable = iv
# If needed, replace the input's update by its cloned equivalent
if i.update:
new_i.update = clone_d[i.update]
new_inputs.append(new_i)
for sv in shared_inputs:
if sv in update_d:
si = In(
variable=sv,
value=sv.container,
mutable=True,
borrow=True,
update=update_d[sv],
shared=True,
)
else:
si = In(
variable=sv,
value=sv.container,
mutable=False,
borrow=True,
shared=True,
)
new_inputs.append(si)
else:
assert len(fgraph.inputs) == len(inputs)
assert len(fgraph.outputs) == len(outputs)
for fg_inp, inp in zip(fgraph.inputs, inputs):
if fg_inp != getattr(inp, "variable", inp):
raise ValueError(
f"`fgraph`'s input does not match the provided input: {fg_inp}, {inp}"
)
for fg_out, out in zip(fgraph.outputs, outputs):
if fg_out != getattr(out, "variable", out):
raise ValueError(
f"`fgraph`'s output does not match the provided output: {fg_out}, {out}"
)
new_inputs = inputs
new_outputs = outputs
return inputs, cloned_outputs
return new_inputs, new_outputs
def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论