提交 b7589469 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor VMs and add variable update functionality to all subclasses

All VM implementations now perform variable updates themselves. This leaves some now redundant update code in `Function`, but it also removes some from `Scan.perform`.
上级 0dd73dc4
...@@ -2152,6 +2152,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2152,6 +2152,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
fgraph.name = name fgraph.name = name
self.indices = indices self.indices = indices
self.inputs = inputs self.inputs = inputs
# TODO: Get rid of all this `expanded_inputs` nonsense
self.expanded_inputs = inputs self.expanded_inputs = inputs
self.outputs = outputs self.outputs = outputs
self.unpack_single = unpack_single self.unpack_single = unpack_single
......
...@@ -110,6 +110,9 @@ def fgraph_updated_vars(fgraph, expanded_inputs): ...@@ -110,6 +110,9 @@ def fgraph_updated_vars(fgraph, expanded_inputs):
Reconstruct the full "updates" dictionary, mapping from FunctionGraph input Reconstruct the full "updates" dictionary, mapping from FunctionGraph input
variables to the fgraph outputs that will replace their values. variables to the fgraph outputs that will replace their values.
TODO: Get rid of all this `expanded_inputs` nonsense and use
only `fgraph.update_mapping`.
Returns Returns
------- -------
dict variable -> variable dict variable -> variable
...@@ -555,6 +558,7 @@ class Function: ...@@ -555,6 +558,7 @@ class Function:
self._value = ValueAttribute() self._value = ValueAttribute()
self._container = ContainerAttribute() self._container = ContainerAttribute()
# TODO: Get rid of all this `expanded_inputs` nonsense
assert len(self.maker.expanded_inputs) == len(self.input_storage) assert len(self.maker.expanded_inputs) == len(self.input_storage)
# This is used only when `fn.need_update_inputs` is `False`, because # This is used only when `fn.need_update_inputs` is `False`, because
...@@ -1048,6 +1052,8 @@ class Function: ...@@ -1048,6 +1052,8 @@ class Function:
# WARNING: This circumvents the 'readonly' attribute in x # WARNING: This circumvents the 'readonly' attribute in x
o_container.storage[0] = None o_container.storage[0] = None
# TODO: Get rid of this and `expanded_inputs`, since all the VMs now
# perform the updates themselves
if getattr(self.fn, "need_update_inputs", True): if getattr(self.fn, "need_update_inputs", True):
# Update the inputs that have an update function # Update the inputs that have an update function
for input, storage in reversed( for input, storage in reversed(
...@@ -1565,11 +1571,15 @@ class FunctionMaker: ...@@ -1565,11 +1571,15 @@ class FunctionMaker:
self.linker = linker.accept(fgraph, profile=profile) self.linker = linker.accept(fgraph, profile=profile)
if hasattr(linker, "accept_var_updates"): if hasattr(linker, "accept_var_updates"):
# hacky thing so VMLinker knows about updates # TODO: This is a hack that makes `VMLinker` aware of updates;
# clean this up.
self.linker.accept_var_updates(fgraph_updated_vars(fgraph, inputs)) self.linker.accept_var_updates(fgraph_updated_vars(fgraph, inputs))
fgraph.name = name fgraph.name = name
self.indices = indices self.indices = indices
self.inputs = inputs self.inputs = inputs
# TODO: Get rid of all this `expanded_inputs` nonsense
self.expanded_inputs = inputs self.expanded_inputs = inputs
self.outputs = outputs self.outputs = outputs
self.unpack_single = unpack_single self.unpack_single = unpack_single
......
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.link.utils import get_destroy_dependencies
logger = logging.getLogger("aesara.compile.profiling") logger = logging.getLogger("aesara.compile.profiling")
...@@ -1035,12 +1036,14 @@ class ProfileStats: ...@@ -1035,12 +1036,14 @@ class ProfileStats:
if isinstance(val, Constant): if isinstance(val, Constant):
compute_map[val][0] = 1 compute_map[val][0] = 1
destroy_dependencies = get_destroy_dependencies(fgraph)
# Initial executable_nodes # Initial executable_nodes
executable_nodes = set() executable_nodes = set()
for var in fgraph.inputs: for var in fgraph.inputs:
for c, _ in fgraph.clients[var]: for c, _ in fgraph.clients[var]:
if c != "output": if c != "output":
deps = c.inputs + c.destroy_dependencies deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps): if all(compute_map[v][0] for v in deps):
executable_nodes.add(c) executable_nodes.add(c)
...@@ -1163,10 +1166,12 @@ class ProfileStats: ...@@ -1163,10 +1166,12 @@ class ProfileStats:
# smaller, stop this iteration, move to next node # smaller, stop this iteration, move to next node
done_dict[frozen_set] = max_mem_count done_dict[frozen_set] = max_mem_count
destroy_dependencies = get_destroy_dependencies(fgraph)
for var in node.outputs: for var in node.outputs:
for c, _ in fgraph.clients[var]: for c, _ in fgraph.clients[var]:
if c != "output": if c != "output":
deps = c.inputs + c.destroy_dependencies deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps): if all(compute_map[v][0] for v in deps):
new_exec_nodes.add(c) new_exec_nodes.add(c)
......
...@@ -1186,10 +1186,10 @@ def add_vm_configvars(): ...@@ -1186,10 +1186,10 @@ def add_vm_configvars():
config.add( config.add(
"vm__lazy", "vm__lazy",
"Useful only for the vm linkers. When lazy is None," "Useful only for the VM Linkers. When lazy is None,"
" auto detect if lazy evaluation is needed and use the appropriate" " auto detect if lazy evaluation is needed and use the appropriate"
" version. If lazy is True/False, force the version used between" " version. If the C loop isn't being used and lazy is True, use "
" Loop/LoopGC and Stack.", "the Stack VM; otherwise, use the Loop VM.",
ConfigParam("None", apply=_filter_vm_lazy), ConfigParam("None", apply=_filter_vm_lazy),
in_c_key=False, in_c_key=False,
) )
......
...@@ -4,7 +4,7 @@ import re ...@@ -4,7 +4,7 @@ import re
import sys import sys
import traceback import traceback
import warnings import warnings
from collections import Counter from collections import Counter, defaultdict
from keyword import iskeyword from keyword import iskeyword
from operator import itemgetter from operator import itemgetter
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
...@@ -793,3 +793,22 @@ def {fgraph_name}({", ".join(fgraph_input_names)}): ...@@ -793,3 +793,22 @@ def {fgraph_name}({", ".join(fgraph_input_names)}):
) )
return fgraph_def return fgraph_def
def get_destroy_dependencies(fgraph: FunctionGraph) -> Dict[Apply, List[Variable]]:
    """Construct a ``dict`` mapping nodes to the variables that are implicit
    dependencies induced by `Op.destroy_map` and `Op.view_map`.

    These variable dependencies are in contrast to each node's inputs, which
    are *explicit* dependencies.

    The variables in the values of this ``dict`` would be impossible to
    compute after the corresponding key node is evaluated, because
    ``node.thunk()`` is going to destroy a common input variable needed by
    whatever node owns each variable in its destroy-dependencies list.

    Parameters
    ----------
    fgraph
        The graph whose orderings (see `FunctionGraph.orderings`) induce the
        implicit dependencies.

    Returns
    -------
    A ``defaultdict`` mapping each node to a (possibly empty) list of the
    variables it implicitly depends on; nodes absent from the orderings map
    to empty lists.
    """
    order = fgraph.orderings()
    # `defaultdict(list)` is the idiomatic (and marginally faster) equivalent
    # of `defaultdict(lambda: [])`.
    destroy_dependencies: Dict[Apply, List[Variable]] = defaultdict(list)
    for node in fgraph.apply_nodes:
        for prereq in order.get(node, []):
            destroy_dependencies[node].extend(prereq.outputs)
    return destroy_dependencies
...@@ -5,33 +5,77 @@ A VM is not actually different from a Linker, we just decided ...@@ -5,33 +5,77 @@ A VM is not actually different from a Linker, we just decided
VM was a better name at some point. VM was a better name at some point.
""" """
import logging
import platform import platform
import sys import sys
import time import time
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from itertools import zip_longest
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Dict,
List,
Optional,
Sequence,
Tuple,
)
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
from aesara.link.basic import Container, LocalLinker from aesara.link.basic import Container, LocalLinker
from aesara.link.c.exceptions import MissingGXX from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import gc_helper, map_storage, raise_with_op from aesara.link.utils import (
gc_helper,
get_destroy_dependencies,
logger = logging.getLogger(__name__) map_storage,
raise_with_op,
)
if TYPE_CHECKING:
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import (
BasicThunkType,
ComputeMapType,
StorageCellType,
StorageMapType,
)
def calculate_reallocate_info(
order: Sequence[Apply],
fgraph: "FunctionGraph",
storage_map: "StorageMapType",
compute_map_re: "ComputeMapType",
dependencies: Dict[Variable, List[Variable]],
) -> Dict[Variable, List[Variable]]:
"""Finds pairs of computed variables that can share a storage cell.
This apparently reduces memory allocations, but its scope is very limited
(e.g. only scalars, only used by the Python VMs without lazy computations).
Parameters
----------
order
List of nodes in compute order.
fgraph
The `FunctionGraph`.
storage_map
Map from variables to their storage cells.
compute_map_re
Reallocation map. TODO
dependencies
Map from variables to the variables that depend on them.
def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, dependencies):
"""
WRITEME : explain the parameters
""" """
reallocated_info = {} reallocated_info = {}
viewed_by = {} viewed_by: Dict[Variable, List[Variable]] = {}
for var in fgraph.variables: for var in fgraph.variables:
viewed_by[var] = [] viewed_by[var] = []
view_of = {} view_of: Dict[Variable, Variable] = {}
pre_allocated = set() pre_allocated = set()
allocated = set() allocated = set()
...@@ -43,20 +87,20 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -43,20 +87,20 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
idx_o = 0 idx_o = 0
for out in node.outputs: for out in node.outputs:
for var in node.outputs: for var in node.outputs:
compute_map_re[var][0] = 1 compute_map_re[var][0] = True
ins = None ins = None
if dmap and idx_o in dmap: if dmap and idx_o in dmap:
idx_v = dmap[idx_o] idx_v = dmap[idx_o]
assert len(idx_v) == 1, ( assert (
"Here we only support the possibility" " to destroy one input" len(idx_v) == 1
) ), "Here we only support the possibility to destroy one input"
ins = node.inputs[idx_v[0]] ins = node.inputs[idx_v[0]]
if vmap and idx_o in vmap: if vmap and idx_o in vmap:
assert ins is None assert ins is None
idx_v = vmap[idx_o] idx_v = vmap[idx_o]
assert len(idx_v) == 1, ( assert (
"Here we only support the possibility" " to view one input" len(idx_v) == 1
) ), "Here we only support the possibility to view one input"
ins = node.inputs[idx_v[0]] ins = node.inputs[idx_v[0]]
if ins is not None: if ins is not None:
assert isinstance(ins, Variable) assert isinstance(ins, Variable)
...@@ -68,24 +112,24 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -68,24 +112,24 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for ins in node.inputs: for ins in node.inputs:
assert not (ins in view_of and viewed_by[ins]) assert not (ins in view_of and viewed_by[ins])
if ( if (
getattr(ins, "ndim", None) == 0 getattr(ins.type, "ndim", None) == 0
and not storage_map[ins][0] and not storage_map[ins][0]
and ins not in fgraph.outputs and ins not in fgraph.outputs
and ins.owner and ins.owner
and all(compute_map_re[v][0] for v in dependencies.get(ins, [])) and all(compute_map_re[v][0] for v in dependencies.get(ins, []))
and ins not in allocated and ins not in allocated
): ):
# Constant Memory cannot be changed # Constant memory cannot be changed
# Constant and shared variables' storage_map value is not empty # Constant and shared variables' storage_map value is not empty
reuse_out = None reuse_out = None
if ins not in view_of and not viewed_by.get(ins, []): if ins not in view_of and not viewed_by.get(ins, []):
# where gc # where gc
for i in range(idx + 1, len(order)): for i in range(idx + 1, len(order)):
if reuse_out is not None: if reuse_out is not None:
break break # type: ignore
for out in order[i].outputs: for out in order[i].outputs:
if ( if (
getattr(out, "ndim", None) == 0 getattr(out.type, "ndim", None) == 0
and out not in pre_allocated and out not in pre_allocated
and out.type.in_same_class(ins.type) and out.type.in_same_class(ins.type)
): ):
...@@ -108,7 +152,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -108,7 +152,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
break break
for out in order[i].outputs: for out in order[i].outputs:
if ( if (
getattr(out, "ndim", None) == 0 getattr(out.type, "ndim", None) == 0
and out not in pre_allocated and out not in pre_allocated
and (out.type.in_same_class(ins.type)) and (out.type.in_same_class(ins.type))
): ):
...@@ -136,17 +180,6 @@ class VM(ABC): ...@@ -136,17 +180,6 @@ class VM(ABC):
advantage of lazy computation, although they still produce the correct advantage of lazy computation, although they still produce the correct
output for lazy nodes. output for lazy nodes.
Parameters
----------
fgraph : FunctionGraph
The `FunctionGraph` associated with `nodes` and `thunks`.
nodes
A list of nodes in toposort order.
thunks
A list of thunks to execute those nodes, in toposort order.
pre_call_clear
A list of containers to empty at the beginning of each call.
Attributes Attributes
---------- ----------
call_counts call_counts
...@@ -165,7 +198,27 @@ class VM(ABC): ...@@ -165,7 +198,27 @@ class VM(ABC):
""" """
def __init__(self, fgraph, nodes, thunks, pre_call_clear): need_update_inputs = True
def __init__(
self,
fgraph: "FunctionGraph",
nodes: List[Apply],
thunks: List["BasicThunkType"],
pre_call_clear: List["StorageCellType"],
):
r"""
Parameters
----------
fgraph
The `FunctionGraph` associated with `nodes` and `thunks`.
nodes
A list of nodes in toposort order.
thunks
A list of thunks to execute those nodes, in toposort order.
pre_call_clear
A list of containers to empty at the beginning of each call.
"""
if len(nodes) != len(thunks): if len(nodes) != len(thunks):
raise ValueError("`nodes` and `thunks` must be the same length") raise ValueError("`nodes` and `thunks` must be the same length")
...@@ -178,11 +231,6 @@ class VM(ABC): ...@@ -178,11 +231,6 @@ class VM(ABC):
self.call_times = [0] * len(nodes) self.call_times = [0] * len(nodes)
self.time_thunks = False self.time_thunks = False
# This variable (self.need_update_inputs) is overshadowed by
# CLazyLinker in CVM which has an attribute of the same name that
# defaults to 0 (aka False).
self.need_update_inputs = True
@abstractmethod @abstractmethod
def __call__(self): def __call__(self):
r"""Run the virtual machine. r"""Run the virtual machine.
...@@ -232,53 +280,109 @@ class VM(ABC): ...@@ -232,53 +280,109 @@ class VM(ABC):
self.call_counts[i] = 0 self.call_counts[i] = 0
class Loop(VM): class UpdatingVM(VM):
""" """A `VM` that performs updates on its graph's inputs."""
Unconditional start-to-finish program execution in Python.
No garbage collection is allowed on intermediate results.
""" need_update_inputs = False
allow_gc = False def __init__(
self,
fgraph,
nodes,
thunks,
pre_call_clear,
storage_map: "StorageMapType",
input_storage: List["StorageCellType"],
output_storage: List["StorageCellType"],
update_vars: Dict[Variable, Variable],
):
r"""
Parameters
----------
storage_map
A ``dict`` mapping `Variable`\s to single-element lists where a
computed value for each `Variable` may be found.
input_storage
Storage cells for each input.
output_storage
Storage cells for each output.
update_vars
A ``dict`` from input to output variables that specify
output-to-input in-place storage updates that occur after
evaluation of the entire graph (i.e. all the thunks).
"""
super().__init__(fgraph, nodes, thunks, pre_call_clear)
def __call__(self): self.storage_map = storage_map
if self.time_thunks: self.input_storage = input_storage
for cont in self.pre_call_clear: self.output_storage = output_storage
cont[0] = None self.inp_storage_and_out_idx = tuple(
try: (inp_storage, self.fgraph.outputs.index(update_vars[inp]))
for i, (thunk, node) in enumerate(zip(self.thunks, self.nodes)): for inp, inp_storage in zip(self.fgraph.inputs, self.input_storage)
t0 = time.time() if inp in update_vars
thunk() )
t1 = time.time()
self.call_counts[i] += 1
self.call_times[i] += t1 - t0
except Exception:
raise_with_op(self.fgraph, node, thunk)
else:
for cont in self.pre_call_clear:
cont[0] = None
try:
for thunk, node in zip(self.thunks, self.nodes):
thunk()
except Exception:
raise_with_op(self.fgraph, node, thunk)
def perform_updates(self) -> List[Any]:
"""Perform the output-to-input updates and return the output values."""
class LoopGC(VM): # The outputs need to be collected *before* the updates that follow
outputs = [cell[0] for cell in self.output_storage]
for inp_storage, out_idx in self.inp_storage_and_out_idx:
inp_storage[0] = outputs[out_idx]
return outputs
class Loop(UpdatingVM):
"""Unconditional start-to-finish program execution in Python. """Unconditional start-to-finish program execution in Python.
Garbage collection is possible on intermediate results. Garbage collection is possible on intermediate results when the
`post_thunk_clear` constructor argument is non-``None``.
""" """
def __init__(self, fgraph, nodes, thunks, pre_call_clear, post_thunk_clear): def __init__(
super().__init__(fgraph, nodes, thunks, pre_call_clear) self,
self.post_thunk_clear = post_thunk_clear fgraph,
# Some other part of Aesara query that information nodes,
self.allow_gc = True thunks,
if not (len(nodes) == len(thunks) == len(post_thunk_clear)): pre_call_clear,
raise ValueError( storage_map,
"`nodes`, `thunks` and `post_thunk_clear` are not the same lengths" input_storage,
) output_storage,
update_vars,
post_thunk_clear: Optional[List["StorageCellType"]] = None,
):
r"""
Parameters
----------
post_thunk_clear
A list of storage cells for each thunk that should be cleared after
each thunk is evaluated. This is the "garbage collection"
functionality.
"""
super().__init__(
fgraph,
nodes,
thunks,
pre_call_clear,
storage_map,
input_storage,
output_storage,
update_vars,
)
if post_thunk_clear is not None:
if not (len(nodes) == len(thunks) == len(post_thunk_clear)):
raise ValueError(
"`nodes`, `thunks` and `post_thunk_clear` are not the same lengths"
)
# Some other part of Aesara use this information
self.allow_gc = True
self.post_thunk_clear = post_thunk_clear
else:
self.allow_gc = False
self.post_thunk_clear = []
def __call__(self): def __call__(self):
if self.time_thunks: if self.time_thunks:
...@@ -286,8 +390,8 @@ class LoopGC(VM): ...@@ -286,8 +390,8 @@ class LoopGC(VM):
cont[0] = None cont[0] = None
try: try:
i = 0 i = 0
for thunk, node, old_storage in zip( for thunk, node, old_storage in zip_longest(
self.thunks, self.nodes, self.post_thunk_clear self.thunks, self.nodes, self.post_thunk_clear, fillvalue=()
): ):
t0 = time.time() t0 = time.time()
thunk() thunk()
...@@ -303,8 +407,8 @@ class LoopGC(VM): ...@@ -303,8 +407,8 @@ class LoopGC(VM):
for cont in self.pre_call_clear: for cont in self.pre_call_clear:
cont[0] = None cont[0] = None
try: try:
for thunk, node, old_storage in zip( for thunk, node, old_storage in zip_longest(
self.thunks, self.nodes, self.post_thunk_clear self.thunks, self.nodes, self.post_thunk_clear, fillvalue=()
): ):
thunk() thunk()
for old_s in old_storage: for old_s in old_storage:
...@@ -312,8 +416,10 @@ class LoopGC(VM): ...@@ -312,8 +416,10 @@ class LoopGC(VM):
except Exception: except Exception:
raise_with_op(self.fgraph, node, thunk) raise_with_op(self.fgraph, node, thunk)
return self.perform_updates()
class Stack(VM): class Stack(UpdatingVM):
"""Finish-to-start evaluation order of thunks. """Finish-to-start evaluation order of thunks.
This supports lazy evaluation of subtrees and partial computations of This supports lazy evaluation of subtrees and partial computations of
...@@ -348,53 +454,53 @@ class Stack(VM): ...@@ -348,53 +454,53 @@ class Stack(VM):
thunks, thunks,
pre_call_clear, pre_call_clear,
storage_map, storage_map,
compute_map, input_storage,
allow_gc, output_storage,
n_updates, update_vars,
dependencies=None, compute_map: "ComputeMapType",
allow_gc: bool,
dependencies: Optional[Dict[Variable, List[Variable]]] = None,
callback=None, callback=None,
callback_input=None, callback_input=None,
): ):
super().__init__(fgraph, nodes, thunks, pre_call_clear) r"""
Parameters
----------
allow_gc
Determines whether or not garbage collection is performed.
dependencies
TODO
callback
TODO
callback_input
TODO
"""
super().__init__(
fgraph,
nodes,
thunks,
pre_call_clear,
storage_map,
input_storage,
output_storage,
update_vars,
)
self.update_vars = update_vars
self.compute_map = compute_map
self.allow_gc = allow_gc self.allow_gc = allow_gc
self.message = "" self.message = ""
self.base_apply_stack = [o.owner for o in fgraph.outputs if o.owner] self.base_apply_stack = [o.owner for o in fgraph.outputs if o.owner]
self.outputs = fgraph.outputs self.outputs = fgraph.outputs
self.storage_map = storage_map self.variable_shape: Dict[Variable, Any] = {} # Variable -> shape
self.variable_shape = {} # Variable -> shape self.variable_strides: Dict[Variable, Any] = {} # Variable -> strides
self.variable_strides = {} # Variable -> strides self.variable_offset: Dict[Variable, Any] = {} # Variable -> offset
self.variable_offset = {} # Variable -> offset node_idx = {node: i for i, node in enumerate(self.nodes)}
self.compute_map = compute_map self.node_idx = node_idx
self.node_idx = node_idx = {}
self.callback = callback self.callback = callback
self.callback_input = callback_input self.callback_input = callback_input
self.n_updates = n_updates self.destroy_dependencies = get_destroy_dependencies(fgraph)
ords = fgraph.orderings()
for i, node in enumerate(self.nodes):
node_idx[node] = i
# XXX: inconsistent style - why modify node here rather
# than track destroy_dependencies with dictionary like
# storage_map?
#
# destroy_dependencies
# --------------------
# The destroy_dependencies is a list of variables that are implicit
# dependencies induced by destroy_map and view_map (compared to
# node.inputs which are *explicit* dependencies). The variables in
# destroy_dependencies would be impossible to compute after the
# current `node` runs, because node.thunk() is going to destroy a
# common input variable needed by whatever node owns each variable
# in destroy_depenencies.
node.destroy_dependencies = []
if node in ords:
for prereq in ords[node]:
node.destroy_dependencies += prereq.outputs
self.dependencies = dependencies self.dependencies = dependencies
if self.allow_gc and self.dependencies is None: if self.allow_gc and self.dependencies is None:
...@@ -444,10 +550,14 @@ class Stack(VM): ...@@ -444,10 +550,14 @@ class Stack(VM):
# apply_stack contains nodes # apply_stack contains nodes
if output_subset is not None: if output_subset is not None:
first_updated = len(self.outputs) - self.n_updates # Add the outputs that are needed for the in-place updates of the
output_subset = output_subset + list( # inputs in `self.update_vars`
range(first_updated, len(self.outputs)) output_subset = list(output_subset)
) for inp, out in self.update_vars.items():
out_idx = self.fgraph.outputs.index(out)
if out_idx not in output_subset:
output_subset.append(out_idx)
apply_stack = [ apply_stack = [
self.outputs[i].owner for i in output_subset if self.outputs[i].owner self.outputs[i].owner for i in output_subset if self.outputs[i].owner
] ]
...@@ -489,7 +599,7 @@ class Stack(VM): ...@@ -489,7 +599,7 @@ class Stack(VM):
current_apply = apply_stack.pop() current_apply = apply_stack.pop()
current_inputs = current_apply.inputs current_inputs = current_apply.inputs
current_outputs = current_apply.outputs current_outputs = current_apply.outputs
current_deps = current_inputs + current_apply.destroy_dependencies current_deps = current_inputs + self.destroy_dependencies[current_apply]
computed_ins = all(compute_map[v][0] for v in current_deps) computed_ins = all(compute_map[v][0] for v in current_deps)
computed_outs = all(compute_map[v][0] for v in current_outputs) computed_outs = all(compute_map[v][0] for v in current_outputs)
...@@ -671,6 +781,8 @@ class Stack(VM): ...@@ -671,6 +781,8 @@ class Stack(VM):
self.node_cleared_order.append(final_index) self.node_cleared_order.append(final_index)
return self.perform_updates()
class VMLinker(LocalLinker): class VMLinker(LocalLinker):
"""Class that satisfies the `Linker` interface by acting as a `VM` factory. """Class that satisfies the `Linker` interface by acting as a `VM` factory.
...@@ -697,7 +809,7 @@ class VMLinker(LocalLinker): ...@@ -697,7 +809,7 @@ class VMLinker(LocalLinker):
Aesara flag ``vm__lazy`` value. Then if we have a ``None`` (default) we Aesara flag ``vm__lazy`` value. Then if we have a ``None`` (default) we
auto detect if lazy evaluation is needed and use the appropriate auto detect if lazy evaluation is needed and use the appropriate
version. If `lazy` is ``True`` or ``False``, we force the version used version. If `lazy` is ``True`` or ``False``, we force the version used
between `Loop`/`LoopGC` and `Stack`. between `Loop` and `Stack`.
c_thunks c_thunks
If ``None`` or ``True``, don't change the default. If ``False``, don't If ``None`` or ``True``, don't change the default. If ``False``, don't
compile C code for the thunks. compile C code for the thunks.
...@@ -768,7 +880,7 @@ class VMLinker(LocalLinker): ...@@ -768,7 +880,7 @@ class VMLinker(LocalLinker):
TODO: change the logic to remove the reference at the end TODO: change the logic to remove the reference at the end
of the call instead of the start. This will request all VM of the call instead of the start. This will request all VM
implementation (Loop, LoopGC, Stack, CVM).__call__ to implementation (Loop, Stack, CVM).__call__ to
return the user outputs as Function.__call__ won't be able return the user outputs as Function.__call__ won't be able
to find them anymore. to find them anymore.
...@@ -804,9 +916,11 @@ class VMLinker(LocalLinker): ...@@ -804,9 +916,11 @@ class VMLinker(LocalLinker):
"""Records in the `Linker` which variables have update expressions. """Records in the `Linker` which variables have update expressions.
It does not imply that the `Linker` will actually implement these updates It does not imply that the `Linker` will actually implement these updates
(see `need_update_inputs`). This mechanism is admittedly confusing, and (see `VM.need_update_inputs`). This mechanism is admittedly confusing, and
it could use some cleaning up. The base `Linker` object should probably it could use some cleaning up. The base `Linker` object should probably
go away completely. go away completely.
TODO: Remove this after refactoring the `VM`/`Linker` interfaces.
""" """
self.updated_vars = updated_vars self.updated_vars = updated_vars
...@@ -853,6 +967,40 @@ class VMLinker(LocalLinker): ...@@ -853,6 +967,40 @@ class VMLinker(LocalLinker):
dependencies[k] += ls dependencies[k] += ls
return dependencies return dependencies
def reduce_storage_allocations(
self, storage_map: "StorageMapType", order: Sequence[Apply]
) -> Tuple[Variable, ...]:
"""Reuse storage cells in a storage map.
`storage_map` is updated in-place.
When this feature is used, `storage_map` will no longer have a
one-to-one mapping with the original variables, because--for
example--some outputs may share storage with intermediate values.
Returns
-------
A tuple of the variables that were reallocated.
"""
# Collect Reallocation Info
compute_map_re: DefaultDict[Variable, List[bool]] = defaultdict(lambda: [False])
for var in self.fgraph.inputs:
compute_map_re[var][0] = True
if getattr(self.fgraph.profile, "dependencies", None):
dependencies = self.fgraph.profile.dependencies
else:
dependencies = self.compute_gc_dependencies(storage_map)
reallocated_info: Dict[Variable, List[Variable]] = calculate_reallocate_info(
order, self.fgraph, storage_map, compute_map_re, dependencies
)
for pair in reallocated_info.values():
storage_map[pair[1]] = storage_map[pair[0]]
return tuple(reallocated_info.keys())
def make_vm( def make_vm(
self, self,
nodes, nodes,
...@@ -883,12 +1031,12 @@ class VMLinker(LocalLinker): ...@@ -883,12 +1031,12 @@ class VMLinker(LocalLinker):
if self.use_cloop and ( if self.use_cloop and (
self.callback is not None or self.callback_input is not None self.callback is not None or self.callback_input is not None
): ):
logger.warning("CVM does not support callback, using Stack VM.") warnings.warn("CVM does not support callback, using Stack VM.")
if self.use_cloop and config.profile_memory: if self.use_cloop and config.profile_memory:
warnings.warn("CVM does not support memory profile, using Stack VM.") warnings.warn("CVM does not support memory profiling, using Stack VM.")
if not self.use_cloop and self.allow_partial_eval: if not self.use_cloop and self.allow_partial_eval:
warnings.warn( warnings.warn(
"LoopGC does not support partial evaluation, " "using Stack VM." "Loop VM does not support partial evaluation, using Stack VM."
) )
# Needed for allow_gc=True, profiling and storage_map reuse # Needed for allow_gc=True, profiling and storage_map reuse
deps = self.compute_gc_dependencies(storage_map) deps = self.compute_gc_dependencies(storage_map)
...@@ -898,9 +1046,11 @@ class VMLinker(LocalLinker): ...@@ -898,9 +1046,11 @@ class VMLinker(LocalLinker):
thunks, thunks,
pre_call_clear, pre_call_clear,
storage_map, storage_map,
input_storage,
output_storage,
updated_vars,
compute_map, compute_map,
self.allow_gc, self.allow_gc,
len(updated_vars),
dependencies=deps, dependencies=deps,
callback=self.callback, callback=self.callback,
callback_input=self.callback_input, callback_input=self.callback_input,
...@@ -1021,7 +1171,7 @@ class VMLinker(LocalLinker): ...@@ -1021,7 +1171,7 @@ class VMLinker(LocalLinker):
if platform.python_implementation() == "CPython" and c0 != sys.getrefcount( if platform.python_implementation() == "CPython" and c0 != sys.getrefcount(
node_n_inputs node_n_inputs
): ):
logger.warning( warnings.warn(
"Detected reference count inconsistency after CVM construction" "Detected reference count inconsistency after CVM construction"
) )
else: else:
...@@ -1032,21 +1182,17 @@ class VMLinker(LocalLinker): ...@@ -1032,21 +1182,17 @@ class VMLinker(LocalLinker):
lazy = any(th.lazy for th in thunks) lazy = any(th.lazy for th in thunks)
if not lazy: if not lazy:
# there is no conditional in the graph # there is no conditional in the graph
if self.allow_gc: vm = Loop(
vm = LoopGC( self.fgraph,
self.fgraph, nodes,
nodes, thunks,
thunks, pre_call_clear,
pre_call_clear, storage_map,
post_thunk_clear, input_storage,
) output_storage,
else: updated_vars,
vm = Loop( post_thunk_clear if self.allow_gc else None,
self.fgraph, )
nodes,
thunks,
pre_call_clear,
)
else: else:
# Needed when allow_gc=True and profiling # Needed when allow_gc=True and profiling
deps = self.compute_gc_dependencies(storage_map) deps = self.compute_gc_dependencies(storage_map)
...@@ -1056,9 +1202,11 @@ class VMLinker(LocalLinker): ...@@ -1056,9 +1202,11 @@ class VMLinker(LocalLinker):
thunks, thunks,
pre_call_clear, pre_call_clear,
storage_map, storage_map,
input_storage,
output_storage,
updated_vars,
compute_map, compute_map,
self.allow_gc, self.allow_gc,
len(updated_vars),
dependencies=deps, dependencies=deps,
) )
return vm return vm
...@@ -1082,19 +1230,6 @@ class VMLinker(LocalLinker): ...@@ -1082,19 +1230,6 @@ class VMLinker(LocalLinker):
thunks = [] thunks = []
# Collect Reallocation Info
compute_map_re = defaultdict(lambda: [0])
for var in fgraph.inputs:
compute_map_re[var][0] = 1
if getattr(fgraph.profile, "dependencies", None):
dependencies = fgraph.profile.dependencies
else:
dependencies = self.compute_gc_dependencies(storage_map)
reallocated_info = calculate_reallocate_info(
order, fgraph, storage_map, compute_map_re, dependencies
)
t0 = time.time() t0 = time.time()
linker_make_thunk_time = {} linker_make_thunk_time = {}
impl = None impl = None
...@@ -1140,8 +1275,9 @@ class VMLinker(LocalLinker): ...@@ -1140,8 +1275,9 @@ class VMLinker(LocalLinker):
or self.callback or self.callback
or self.callback_input or self.callback_input
): ):
for pair in reallocated_info.values(): reallocated_vars = self.reduce_storage_allocations(storage_map, order)
storage_map[pair[1]] = storage_map[pair[0]] else:
reallocated_vars = ()
computed, last_user = gc_helper(order) computed, last_user = gc_helper(order)
if self.allow_gc: if self.allow_gc:
...@@ -1153,7 +1289,7 @@ class VMLinker(LocalLinker): ...@@ -1153,7 +1289,7 @@ class VMLinker(LocalLinker):
input in computed input in computed
and input not in fgraph.outputs and input not in fgraph.outputs
and node == last_user[input] and node == last_user[input]
and input not in reallocated_info and input not in reallocated_vars
): ):
clear_after_this_thunk.append(storage_map[input]) clear_after_this_thunk.append(storage_map[input])
post_thunk_clear.append(clear_after_this_thunk) post_thunk_clear.append(clear_after_this_thunk)
......
...@@ -1360,6 +1360,9 @@ def pydotprint( ...@@ -1360,6 +1360,9 @@ def pydotprint(
# it, we must copy it. # it, we must copy it.
outputs = list(outputs) outputs = list(outputs)
if isinstance(fct, Function): if isinstance(fct, Function):
# TODO: Get rid of all this `expanded_inputs` nonsense and use
# `fgraph.update_mapping`
function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs) function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs)
for i, fg_ii in reversed(list(function_inputs)): for i, fg_ii in reversed(list(function_inputs)):
if i.update is not None: if i.update is not None:
......
...@@ -15,6 +15,7 @@ from aesara.configdefaults import config ...@@ -15,6 +15,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.opt import OpKeyOptimizer, PatternSub from aesara.graph.opt import OpKeyOptimizer, PatternSub
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from aesara.link.vm import VMLinker
from aesara.tensor.math import dot from aesara.tensor.math import dot
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh from aesara.tensor.math import tanh
...@@ -227,19 +228,36 @@ class TestFunction: ...@@ -227,19 +228,36 @@ class TestFunction:
# got multiple values for keyword argument 'x' # got multiple values for keyword argument 'x'
f(5.0, x=9) f(5.0, x=9)
def test_state_access(self): @pytest.mark.parametrize(
a = scalar() # the a is for 'anonymous' (un-named). "mode",
[
Mode(
linker=VMLinker(allow_gc=True, use_cloop=False, c_thunks=False),
optimizer="fast_compile",
),
Mode(
linker=VMLinker(allow_gc=True, use_cloop=False, c_thunks=False),
optimizer="fast_run",
),
Mode(linker="cvm", optimizer="fast_compile"),
Mode(linker="cvm", optimizer="fast_run"),
],
)
def test_state_access(self, mode):
a = scalar()
x, s = scalars("xs") x, s = scalars("xs")
f = function( f = function(
[x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)], [x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)],
s + a * x, s + a * x,
mode=mode,
) )
assert f[a] == 1.0 assert f[a] == 1.0
assert f[s] == 0.0 assert f[s] == 0.0
assert f(3.0) == 3.0 assert f(3.0) == 3.0
assert f[s] == 3.0
assert f(3.0, a=2.0) == 9.0 # 3.0 + 2*3.0 assert f(3.0, a=2.0) == 9.0 # 3.0 + 2*3.0
assert ( assert (
......
...@@ -15,7 +15,7 @@ from aesara.ifelse import ifelse ...@@ -15,7 +15,7 @@ from aesara.ifelse import ifelse
from aesara.link.c.basic import OpWiseCLinker from aesara.link.c.basic import OpWiseCLinker
from aesara.link.c.exceptions import MissingGXX from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import map_storage from aesara.link.utils import map_storage
from aesara.link.vm import VM, Loop, LoopGC, Stack, VMLinker from aesara.link.vm import VM, Loop, Stack, VMLinker
from aesara.tensor.math import cosh, tanh from aesara.tensor.math import cosh, tanh
from aesara.tensor.type import lscalar, scalar, scalars, vector, vectors from aesara.tensor.type import lscalar, scalar, scalars, vector, vectors
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
...@@ -307,11 +307,6 @@ class RunOnce(Op): ...@@ -307,11 +307,6 @@ class RunOnce(Op):
def test_vm_gc(): def test_vm_gc():
# This already caused a bug in the trunk of Aesara.
#
# The bug was introduced in the trunk on July 5th, 2012 and fixed on
# July 30th.
x = vector() x = vector()
p = RunOnce()(x) p = RunOnce()(x)
mode = Mode(linker=VMLinker(lazy=True)) mode = Mode(linker=VMLinker(lazy=True))
...@@ -445,7 +440,7 @@ def test_VM_exception(): ...@@ -445,7 +440,7 @@ def test_VM_exception():
SomeVM(fg, fg.apply_nodes, [], []) SomeVM(fg, fg.apply_nodes, [], [])
def test_LoopGC_exception(): def test_Loop_exception():
a = scalar() a = scalar()
fg = FunctionGraph(outputs=[SomeOp()(a)]) fg = FunctionGraph(outputs=[SomeOp()(a)])
...@@ -460,9 +455,99 @@ def test_LoopGC_exception(): ...@@ -460,9 +455,99 @@ def test_LoopGC_exception():
for k in storage_map: for k in storage_map:
compute_map[k] = [k.owner is None] compute_map[k] = [k.owner is None]
thunks = [ thunks = [node.op.make_thunk(node, storage_map, compute_map, []) for node in nodes]
node.op.make_thunk(node, storage_map, compute_map, True) for node in nodes
]
with pytest.raises(ValueError, match="`nodes`, `thunks` and `post_thunk_clear`.*"): with pytest.raises(ValueError, match="`nodes`, `thunks` and `post_thunk_clear`.*"):
LoopGC(fg, fg.apply_nodes, thunks, [], []) Loop(
fg,
fg.apply_nodes,
thunks,
[],
storage_map,
input_storage,
output_storage,
{},
[],
)
def test_Loop_updates():
a = scalar("a")
a_plus_1 = a + 1
fg = FunctionGraph(outputs=[a, a_plus_1], clone=False)
nodes = fg.toposort()
input_storage, output_storage, storage_map = map_storage(
fg, nodes, None, None, None
)
compute_map = {}
for k in storage_map:
compute_map[k] = [k.owner is None]
thunks = [node.op.make_thunk(node, storage_map, compute_map, []) for node in nodes]
assert a in storage_map
update_vars = {a: a_plus_1}
loop_vm = Loop(
fg,
fg.apply_nodes,
thunks,
[],
storage_map,
input_storage,
output_storage,
update_vars,
)
storage_map[a][0] = np.array(1.0, dtype=config.floatX)
res = loop_vm()
assert res == [np.array(1.0), np.array(2.0)]
assert storage_map[a][0] == np.array(2.0)
def test_Stack_updates():
a = scalar("a")
a_plus_1 = a + 1
fg = FunctionGraph(outputs=[a, a_plus_1], clone=False)
nodes = fg.toposort()
input_storage, output_storage, storage_map = map_storage(
fg, nodes, None, None, None
)
compute_map = {}
for k in storage_map:
compute_map[k] = [k.owner is None]
thunks = [node.op.make_thunk(node, storage_map, compute_map, []) for node in nodes]
assert a in storage_map
update_vars = {a: a_plus_1}
stack_vm = Stack(
fg,
fg.apply_nodes,
thunks,
[],
storage_map,
input_storage,
output_storage,
update_vars,
compute_map,
False,
)
storage_map[a][0] = np.array(1.0, dtype=config.floatX)
res = stack_vm()
assert res == [np.array(1.0), np.array(2.0)]
assert storage_map[a][0] == np.array(2.0)
...@@ -1482,6 +1482,14 @@ class TestScan: ...@@ -1482,6 +1482,14 @@ class TestScan:
@pytest.mark.slow @pytest.mark.slow
def test_grad_multiple_outs_taps_backwards(self): def test_grad_multiple_outs_taps_backwards(self):
"""
This test is special because when the inner-graph compilation is set to
"fast_compile", and the `Loop` `VM` is used, its inner-graph will
reallocate storage cells, which is a good test for correct, direct
storage use in `Scan.perform`.
TODO: Create a much more direct test.
"""
n = 5 n = 5
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
vW_in2 = asarrayX(rng.uniform(-0.2, 0.2, size=(2,))) vW_in2 = asarrayX(rng.uniform(-0.2, 0.2, size=(2,)))
...@@ -1591,6 +1599,7 @@ class TestScan: ...@@ -1591,6 +1599,7 @@ class TestScan:
) )
def reset_rng_fn(fn, *args): def reset_rng_fn(fn, *args):
# TODO: Get rid of all this `expanded_inputs` nonsense
for idx, arg in enumerate(fn.maker.expanded_inputs): for idx, arg in enumerate(fn.maker.expanded_inputs):
if arg.value and isinstance(arg.value.data, np.random.Generator): if arg.value and isinstance(arg.value.data, np.random.Generator):
obj = fn.maker.expanded_inputs[idx].value obj = fn.maker.expanded_inputs[idx].value
...@@ -1673,6 +1682,7 @@ class TestScan: ...@@ -1673,6 +1682,7 @@ class TestScan:
) )
def reset_rng_fn(fn, *args): def reset_rng_fn(fn, *args):
# TODO: Get rid of all this `expanded_inputs` nonsense
for idx, arg in enumerate(fn.maker.expanded_inputs): for idx, arg in enumerate(fn.maker.expanded_inputs):
if arg.value and isinstance(arg.value.data, np.random.Generator): if arg.value and isinstance(arg.value.data, np.random.Generator):
obj = fn.maker.expanded_inputs[idx].value obj = fn.maker.expanded_inputs[idx].value
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论