Commit e588b752 authored by Brandon T. Willard, committed by Brandon T. Willard

Minor refactoring of aesara.compile.function.types

- Added more type annotations
- Cleaned up and removed comments and docstrings
- Replaced `std_fgraph.features` with a keyword argument
- Made `DUPLICATE` an `object` sentinel
- etc.
Parent afae8eb0
@@ -9,6 +9,7 @@ import logging
import time
import warnings
from itertools import chain
from typing import List, Tuple, Type
import numpy as np
@@ -25,7 +26,7 @@ from aesara.graph.basic import (
graph_inputs,
)
from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import PreserveVariableAttributes
from aesara.graph.features import Feature, PreserveVariableAttributes
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import HasInnerGraph
from aesara.graph.utils import InconsistencyError, get_variable_trace_string
@@ -147,18 +148,23 @@ class Supervisor:
raise InconsistencyError(f"Trying to destroy a protected variable: {r}")
def std_fgraph(input_specs, output_specs, accept_inplace=False):
"""
Makes an FunctionGraph corresponding to the input specs and the output
specs. Any SymbolicInput in the input_specs, if its update field
is not None, will add an output to the FunctionGraph corresponding to that
update. The return value is the FunctionGraph as well as a list of
SymbolicOutput instances corresponding to the updates.
def std_fgraph(
input_specs: List[SymbolicInput],
output_specs: List[SymbolicOutput],
accept_inplace: bool = False,
features: List[Type[Feature]] = [PreserveVariableAttributes],
) -> Tuple[FunctionGraph, List[SymbolicOutput]]:
"""Make or set up `FunctionGraph` corresponding to the input specs and the output specs.
Any `SymbolicInput` in the `input_specs`, if its `update` field is not
``None``, will add an output corresponding to that update to the
`FunctionGraph`. The return value is the `FunctionGraph` as well as a list
of `SymbolicOutput` instances corresponding to the updates.
If accept_inplace is False, the graph will be checked for inplace
If `accept_inplace` is ``False``, the graph will be checked for in-place
operations and an exception will be raised if it has any. If
accept_inplace is True, a DestroyHandler will be added to the FunctionGraph
if there are any inplace operations.
`accept_inplace` is ``True``, a `DestroyHandler` will be added to the
`FunctionGraph` if there are any in-place operations.
The returned FunctionGraph is a clone of the graph between the provided
inputs and outputs.
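The hunk above replaces the old `std_fgraph.features` function attribute with a `features` keyword argument. A minimal sketch of how a caller might now pass a custom feature, assuming the usual `In`/`Out` wrappers from `aesara.compile.io` (the `NoOpFeature` class is hypothetical, for illustration only):

```python
import aesara.tensor as at
from aesara.compile.io import In, Out
from aesara.compile.function.types import std_fgraph
from aesara.graph.features import Feature

class NoOpFeature(Feature):
    """Hypothetical Feature subclass; a real one would register callbacks."""

    def on_attach(self, fgraph):
        pass

x = at.scalar("x")
y = 2 * x

# Features are now passed per call instead of mutating std_fgraph.features:
fgraph, update_outputs = std_fgraph([In(x)], [Out(y)], features=[NoOpFeature])
```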
@@ -166,8 +172,8 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
"""
orig_inputs = [spec.variable for spec in input_specs]
# Extract the updates and the mapping between update outputs and
# the updated inputs.
# Extract the updates and the mapping between update outputs and the
# updated inputs
updates = []
update_mapping = {}
out_idx = len(output_specs)
@@ -202,14 +208,11 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
)
# If named nodes are replaced, keep the name
for feature in std_fgraph.features:
for feature in features:
fgraph.attach_feature(feature())
return fgraph, list(map(SymbolicOutput, updates))
std_fgraph.features = [PreserveVariableAttributes]
class AliasedMemoryError(Exception):
"""
Memory is aliased that should not be.
@@ -217,8 +220,8 @@ class AliasedMemoryError(Exception):
"""
# unique id object used as a placeholder for duplicate entries
DUPLICATE = ["DUPLICATE"]
# A sentinel for duplicate entries
DUPLICATE = object()
class Function:
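A sentinel built from `object()` matches only itself under identity, whereas the old list sentinel was mutable and compared equal to any other `["DUPLICATE"]` list. A small standalone sketch of the difference (the names here are illustrative, not the module's usage sites):

```python
OLD_DUPLICATE = ["DUPLICATE"]
NEW_DUPLICATE = object()

# The list sentinel gives false positives under equality comparison:
assert OLD_DUPLICATE == ["DUPLICATE"]   # any equal list matches

# The object sentinel is only ever itself:
assert NEW_DUPLICATE is NEW_DUPLICATE
assert NEW_DUPLICATE != object()        # a fresh object never compares equal
```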
@@ -523,15 +526,14 @@ class Function:
self._value = ValueAttribute()
self._container = ContainerAttribute()
# Compute self.n_returned_outputs.
# This is used only when fn.need_update_inputs is False
# because we're using one of the VM objects and it is
# putting updates back into the input containers all by itself.
assert len(self.maker.expanded_inputs) == len(self.input_storage)
self.n_returned_outputs = len(self.output_storage)
for input in self.maker.expanded_inputs:
if input.update is not None:
self.n_returned_outputs -= 1
# This is used only when `fn.need_update_inputs` is `False`, because
# we're using one of the VM objects and it is putting updates back into
# the input containers all by itself.
self.n_returned_outputs = len(self.output_storage) - sum(
inp.update is not None for inp in self.maker.expanded_inputs
)
for node in self.maker.fgraph.apply_nodes:
if isinstance(node.op, HasInnerGraph):
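For context, `n_returned_outputs` excludes outputs that back an `update`, because those values are written into the input containers rather than handed back to the caller. A hedged illustration through the public `aesara.function` interface:

```python
import aesara
import aesara.tensor as at

state = aesara.shared(0, name="state")
x = at.iscalar("x")

# One explicit output plus one update; internally the fgraph gains a second
# output for the update, but only the explicit one is returned.
f = aesara.function([x], [x + 1], updates=[(state, state + x)])

print(f(3))               # [array(4)] -- only the explicit output
print(state.get_value())  # 3 -- the update landed in the shared container
```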
@@ -836,8 +838,8 @@ class Function:
# Set positional arguments
i = 0
for arg in args:
# TODO: provide a Param option for skipping the filter if we
# really want speed.
# TODO: provide an option for skipping the filter if we really
# want speed.
s = self.input_storage[i]
# see these emails for a discussion about None as input
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
@@ -1132,7 +1134,7 @@ def _pickle_Function(f):
):
if np.may_share_memory(d_i, d_j):
if f.pickle_aliased_memory_strategy == "warn":
_logger.warning(
warnings.warn(
"aliased relationship between "
f"Function arguments {d_i}, {d_j} "
"will not be preserved by "
@@ -1167,21 +1169,20 @@ copyreg.pickle(Function, _pickle_Function)
def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
"""
Insert deepcopy in the fgraph to break aliasing of outputs
"""
# This loop was inserted to remove aliasing between outputs when
# they all evaluate to the same value. Originally it was OK for
# outputs to be aliased, but some of the outputs can be shared
# variables, and is not good for shared variables to be
# aliased. It might be possible to optimize this by making sure
# there is no aliasing only between shared variables.
"""Insert deepcopy in the fgraph to break aliasing of outputs.
This loop was inserted to remove aliasing between outputs when they all
evaluate to the same value. Originally it was OK for outputs to be aliased,
but some of the outputs can be shared variables, and it is not good for
shared variables to be aliased. It might be possible to optimize this by
preventing aliasing only between shared variables.
# If some outputs are constant, we add deep copy to respect the
# memory contract
If some outputs are constant, we add deep copy to respect the memory
contract.
# We don't insert deep copy when the output.borrow is True for all
# concerned outputs.
We don't insert deep copy when :attr:`SymbolicOutput.borrow` is ``True``
for all concerned outputs.
"""
assert len(wrapped_inputs) == len(fgraph.inputs)
assert len(wrapped_outputs) == len(fgraph.outputs)
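A small self-contained illustration of the aliasing problem this docstring describes, using NumPy directly rather than the fgraph machinery:

```python
import copy
import numpy as np

buf = np.zeros(3)
out1, out2 = buf, buf            # two "outputs" aliasing one buffer
assert np.may_share_memory(out1, out2)

out1_copy = copy.deepcopy(out1)  # roughly what a deep-copy op achieves
out1_copy[0] = 1.0
assert buf[0] == 0.0             # the copy no longer aliases the buffer
```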
@@ -1237,7 +1238,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
"output",
i,
view_op(fgraph.outputs[i]),
reason="insert_deepcopy",
reason=reason,
)
break
else:
@@ -1245,7 +1246,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
"output",
i,
deep_copy_op(fgraph.outputs[i]),
reason="insert_deepcopy",
reason=reason,
)
break
elif wrapped_outputs[i].borrow:
@@ -1253,7 +1254,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
"output",
i,
view_op(fgraph.outputs[i]),
reason="insert_deepcopy",
reason=reason,
)
break
else:
@@ -1261,7 +1262,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
"output",
i,
deep_copy_op(fgraph.outputs[i]),
reason="insert_deepcopy",
reason=reason,
)
break
@@ -1453,7 +1454,6 @@ class FunctionMaker:
self.fgraph = fgraph
# Fetch the optimizer and linker
optimizer, linker = mode.optimizer, copy.copy(mode.linker)
if need_opt:
# Why do we add the stack on the node when it gets done in the output var?
@@ -1504,7 +1504,6 @@ class FunctionMaker:
stacklevel=3,
)
# initialize the linker
if not hasattr(linker, "accept"):
raise ValueError(
"'linker' parameter of FunctionMaker should be "
@@ -1678,7 +1677,7 @@ def orig_function(
profile=None,
on_unused_input=None,
output_keys=None,
):
) -> Function:
"""
Return a Function that will calculate the outputs from the inputs.
@@ -1701,35 +1700,21 @@ def orig_function(
profile : None or ProfileStats instance
on_unused_input : {'raise', 'warn', 'ignore', None}
What to do if a variable in the 'inputs' list is not used in the graph.
output_keys :
output_keys
If the outputs were provided to aesara.function as a list, then
output_keys is None. Otherwise, if outputs were provided as a dict,
output_keys is the sorted list of keys from the outputs.
Notes
-----
Currently, the library provides the following mode strings:
- FAST_RUN (default) (optimize without too much time)
- FAST_COMPILE (minimal optimization)
- DebugMode: verify many internal conditions that are normally assumed
(slow)
fgraph
An existing `FunctionGraph` to use instead of constructing a new one
from cloned `outputs`.
"""
# Every element of the input list will be upgraded to an `In` instance if
# necessary, using the rules implemented by the `convert_function_input`
# function.
# Similarly, every element of the output list will be upgraded to an `Out`
# instance if necessary:
t1 = time.time()
mode = aesara.compile.mode.get_mode(mode)
inputs = list(map(convert_function_input, inputs))
if outputs is not None:
if isinstance(outputs, (list, tuple)):
outputs = list(map(FunctionMaker.wrap_out, outputs))
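The `output_keys` parameter reflects how outputs were supplied; a hedged usage sketch of the dict form through the public `aesara.function` wrapper:

```python
import aesara
import aesara.tensor as at

x = at.scalar("x")

# Dict outputs make output_keys the sorted key list, and the compiled
# function returns a dict instead of a list.
f = aesara.function([x], {"double": 2 * x, "square": x ** 2})

print(f(3.0))  # roughly: {'double': array(6.), 'square': array(9.)}
```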
@@ -1738,8 +1723,9 @@ def orig_function(
defaults = [getattr(input, "value", None) for input in inputs]
if isinstance(mode, (list, tuple)): # "mode comparison" semantics
raise Exception("We do not support the passing of multiple modes")
if isinstance(mode, (list, tuple)):
raise ValueError("We do not support the passing of multiple modes")
fn = None
try:
Maker = getattr(mode, "function_maker", FunctionMaker)
......