提交 e588b752 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Minor refactoring of aesara.compile.function.types

- Added more type annotations - Cleaned up and removed comments and docstrings - Replaced `std_fgraph.features` with a keyword argument - Made `DUPLICATE` an `object` sentinel - etc.
上级 afae8eb0
...@@ -9,6 +9,7 @@ import logging ...@@ -9,6 +9,7 @@ import logging
import time import time
import warnings import warnings
from itertools import chain from itertools import chain
from typing import List, Tuple, Type
import numpy as np import numpy as np
...@@ -25,7 +26,7 @@ from aesara.graph.basic import ( ...@@ -25,7 +26,7 @@ from aesara.graph.basic import (
graph_inputs, graph_inputs,
) )
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.features import PreserveVariableAttributes from aesara.graph.features import Feature, PreserveVariableAttributes
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import HasInnerGraph from aesara.graph.op import HasInnerGraph
from aesara.graph.utils import InconsistencyError, get_variable_trace_string from aesara.graph.utils import InconsistencyError, get_variable_trace_string
...@@ -147,18 +148,23 @@ class Supervisor: ...@@ -147,18 +148,23 @@ class Supervisor:
raise InconsistencyError(f"Trying to destroy a protected variable: {r}") raise InconsistencyError(f"Trying to destroy a protected variable: {r}")
def std_fgraph(input_specs, output_specs, accept_inplace=False): def std_fgraph(
""" input_specs: List[SymbolicInput],
Makes an FunctionGraph corresponding to the input specs and the output output_specs: List[SymbolicOutput],
specs. Any SymbolicInput in the input_specs, if its update field accept_inplace: bool = False,
is not None, will add an output to the FunctionGraph corresponding to that features: List[Type[Feature]] = [PreserveVariableAttributes],
update. The return value is the FunctionGraph as well as a list of ) -> Tuple[FunctionGraph, List[SymbolicOutput]]:
SymbolicOutput instances corresponding to the updates. """Make or set up `FunctionGraph` corresponding to the input specs and the output specs.
Any `SymbolicInput` in the `input_specs`, if its `update` field is not
``None``, will add an output corresponding to that update to the
`FunctionGraph`. The return value is the `FunctionGraph` as well as a list
of `SymbolicOutput` instances corresponding to the updates.
If accept_inplace is False, the graph will be checked for inplace If `accept_inplace` is ``False``, the graph will be checked for in-place
operations and an exception will be raised if it has any. If operations and an exception will be raised if it has any. If
accept_inplace is True, a DestroyHandler will be added to the FunctionGraph `accept_inplace` is ``True``, a `DestroyHandler` will be added to the
if there are any inplace operations. `FunctionGraph` if there are any in-place operations.
The returned FunctionGraph is a clone of the graph between the provided The returned FunctionGraph is a clone of the graph between the provided
inputs and outputs. inputs and outputs.
...@@ -166,8 +172,8 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -166,8 +172,8 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
""" """
orig_inputs = [spec.variable for spec in input_specs] orig_inputs = [spec.variable for spec in input_specs]
# Extract the updates and the mapping between update outputs and # Extract the updates and the mapping between update outputs and the
# the updated inputs. # updated inputs
updates = [] updates = []
update_mapping = {} update_mapping = {}
out_idx = len(output_specs) out_idx = len(output_specs)
...@@ -202,14 +208,11 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -202,14 +208,11 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
) )
# If named nodes are replaced, keep the name # If named nodes are replaced, keep the name
for feature in std_fgraph.features: for feature in features:
fgraph.attach_feature(feature()) fgraph.attach_feature(feature())
return fgraph, list(map(SymbolicOutput, updates)) return fgraph, list(map(SymbolicOutput, updates))
std_fgraph.features = [PreserveVariableAttributes]
class AliasedMemoryError(Exception): class AliasedMemoryError(Exception):
""" """
Memory is aliased that should not be. Memory is aliased that should not be.
...@@ -217,8 +220,8 @@ class AliasedMemoryError(Exception): ...@@ -217,8 +220,8 @@ class AliasedMemoryError(Exception):
""" """
# unique id object used as a placeholder for duplicate entries # A sentinel for duplicate entries
DUPLICATE = ["DUPLICATE"] DUPLICATE = object()
class Function: class Function:
...@@ -523,15 +526,14 @@ class Function: ...@@ -523,15 +526,14 @@ class Function:
self._value = ValueAttribute() self._value = ValueAttribute()
self._container = ContainerAttribute() self._container = ContainerAttribute()
# Compute self.n_returned_outputs.
# This is used only when fn.need_update_inputs is False
# because we're using one of the VM objects and it is
# putting updates back into the input containers all by itself.
assert len(self.maker.expanded_inputs) == len(self.input_storage) assert len(self.maker.expanded_inputs) == len(self.input_storage)
self.n_returned_outputs = len(self.output_storage)
for input in self.maker.expanded_inputs: # This is used only when `fn.need_update_inputs` is `False`, because
if input.update is not None: # we're using one of the VM objects and it is putting updates back into
self.n_returned_outputs -= 1 # the input containers all by itself.
self.n_returned_outputs = len(self.output_storage) - sum(
inp.update is not None for inp in self.maker.expanded_inputs
)
for node in self.maker.fgraph.apply_nodes: for node in self.maker.fgraph.apply_nodes:
if isinstance(node.op, HasInnerGraph): if isinstance(node.op, HasInnerGraph):
...@@ -836,8 +838,8 @@ class Function: ...@@ -836,8 +838,8 @@ class Function:
# Set positional arguments # Set positional arguments
i = 0 i = 0
for arg in args: for arg in args:
# TODO: provide a Param option for skipping the filter if we # TODO: provide a option for skipping the filter if we really
# really want speed. # want speed.
s = self.input_storage[i] s = self.input_storage[i]
# see this emails for a discuation about None as input # see this emails for a discuation about None as input
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5 # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
...@@ -1132,7 +1134,7 @@ def _pickle_Function(f): ...@@ -1132,7 +1134,7 @@ def _pickle_Function(f):
): ):
if np.may_share_memory(d_i, d_j): if np.may_share_memory(d_i, d_j):
if f.pickle_aliased_memory_strategy == "warn": if f.pickle_aliased_memory_strategy == "warn":
_logger.warning( warnings.warn(
"aliased relationship between " "aliased relationship between "
f"Function arguments {d_i}, {d_j} " f"Function arguments {d_i}, {d_j} "
"will not be preserved by " "will not be preserved by "
...@@ -1167,21 +1169,20 @@ copyreg.pickle(Function, _pickle_Function) ...@@ -1167,21 +1169,20 @@ copyreg.pickle(Function, _pickle_Function)
def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
""" """Insert deepcopy in the fgraph to break aliasing of outputs.
Insert deepcopy in the fgraph to break aliasing of outputs
""" This loop was inserted to remove aliasing between outputs when they all
# This loop was inserted to remove aliasing between outputs when evaluate to the same value. Originally it was OK for outputs to be aliased,
# they all evaluate to the same value. Originally it was OK for but some of the outputs can be shared variables, and is not good for shared
# outputs to be aliased, but some of the outputs can be shared variables to be aliased. It might be possible to optimize this by making
# variables, and is not good for shared variables to be sure there is no aliasing only between shared variables.
# aliased. It might be possible to optimize this by making sure
# there is no aliasing only between shared variables.
# If some outputs are constant, we add deep copy to respect the If some outputs are constant, we add deep copy to respect the memory
# memory contract contract
# We don't insert deep copy when the output.borrow is True for all We don't insert deep copy when :attr:`SymbolicOutput.borrow` is ``True``
# concerned outputs. for all concerned outputs.
"""
assert len(wrapped_inputs) == len(fgraph.inputs) assert len(wrapped_inputs) == len(fgraph.inputs)
assert len(wrapped_outputs) == len(fgraph.outputs) assert len(wrapped_outputs) == len(fgraph.outputs)
...@@ -1237,7 +1238,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1237,7 +1238,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
"output", "output",
i, i,
view_op(fgraph.outputs[i]), view_op(fgraph.outputs[i]),
reason="insert_deepcopy", reason=reason,
) )
break break
else: else:
...@@ -1245,7 +1246,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1245,7 +1246,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
"output", "output",
i, i,
deep_copy_op(fgraph.outputs[i]), deep_copy_op(fgraph.outputs[i]),
reason="insert_deepcopy", reason=reason,
) )
break break
elif wrapped_outputs[i].borrow: elif wrapped_outputs[i].borrow:
...@@ -1253,7 +1254,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1253,7 +1254,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
"output", "output",
i, i,
view_op(fgraph.outputs[i]), view_op(fgraph.outputs[i]),
reason="insert_deepcopy", reason=reason,
) )
break break
else: else:
...@@ -1261,7 +1262,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1261,7 +1262,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
"output", "output",
i, i,
deep_copy_op(fgraph.outputs[i]), deep_copy_op(fgraph.outputs[i]),
reason="insert_deepcopy", reason=reason,
) )
break break
...@@ -1453,7 +1454,6 @@ class FunctionMaker: ...@@ -1453,7 +1454,6 @@ class FunctionMaker:
self.fgraph = fgraph self.fgraph = fgraph
# Fetch the optimizer and linker
optimizer, linker = mode.optimizer, copy.copy(mode.linker) optimizer, linker = mode.optimizer, copy.copy(mode.linker)
if need_opt: if need_opt:
# Why we add stack on node when it get done in output var? # Why we add stack on node when it get done in output var?
...@@ -1504,7 +1504,6 @@ class FunctionMaker: ...@@ -1504,7 +1504,6 @@ class FunctionMaker:
stacklevel=3, stacklevel=3,
) )
# initialize the linker
if not hasattr(linker, "accept"): if not hasattr(linker, "accept"):
raise ValueError( raise ValueError(
"'linker' parameter of FunctionMaker should be " "'linker' parameter of FunctionMaker should be "
...@@ -1678,7 +1677,7 @@ def orig_function( ...@@ -1678,7 +1677,7 @@ def orig_function(
profile=None, profile=None,
on_unused_input=None, on_unused_input=None,
output_keys=None, output_keys=None,
): ) -> Function:
""" """
Return a Function that will calculate the outputs from the inputs. Return a Function that will calculate the outputs from the inputs.
...@@ -1701,35 +1700,21 @@ def orig_function( ...@@ -1701,35 +1700,21 @@ def orig_function(
profile : None or ProfileStats instance profile : None or ProfileStats instance
on_unused_input : {'raise', 'warn', 'ignore', None} on_unused_input : {'raise', 'warn', 'ignore', None}
What to do if a variable in the 'inputs' list is not used in the graph. What to do if a variable in the 'inputs' list is not used in the graph.
output_keys : output_keys
If the outputs were provided to aesara.function as a list, then If the outputs were provided to aesara.function as a list, then
output_keys is None. Otherwise, if outputs were provided as a dict, output_keys is None. Otherwise, if outputs were provided as a dict,
output_keys is the sorted list of keys from the outputs. output_keys is the sorted list of keys from the outputs.
fgraph
Notes An existing `FunctionGraph` to use instead of constructing a new one
----- from cloned `outputs`.
Currently, the library provides the following mode strings:
- FAST_RUN (default) (optimize without too much time)
- FAST_COMPILE (minimal optimization)
- DebugMode: verify many internal conditions that are normally assumed
(slow)
""" """
# Every element of the input list will be upgraded to an `In` instance if
# necessary, using the rules implemented by the `convert_function_input`
# function.
# Similarly, every element of the output list will be upgraded to an `Out`
# instance if necessary:
t1 = time.time() t1 = time.time()
mode = aesara.compile.mode.get_mode(mode) mode = aesara.compile.mode.get_mode(mode)
inputs = list(map(convert_function_input, inputs)) inputs = list(map(convert_function_input, inputs))
if outputs is not None: if outputs is not None:
if isinstance(outputs, (list, tuple)): if isinstance(outputs, (list, tuple)):
outputs = list(map(FunctionMaker.wrap_out, outputs)) outputs = list(map(FunctionMaker.wrap_out, outputs))
...@@ -1738,8 +1723,9 @@ def orig_function( ...@@ -1738,8 +1723,9 @@ def orig_function(
defaults = [getattr(input, "value", None) for input in inputs] defaults = [getattr(input, "value", None) for input in inputs]
if isinstance(mode, (list, tuple)): # "mode comparison" semantics if isinstance(mode, (list, tuple)):
raise Exception("We do not support the passing of multiple modes") raise ValueError("We do not support the passing of multiple modes")
fn = None fn = None
try: try:
Maker = getattr(mode, "function_maker", FunctionMaker) Maker = getattr(mode, "function_maker", FunctionMaker)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论