提交 b7589469 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor VMs and add variable update functionality to all subclasses

All VM implementations now perform variable updates themselves. This leaves some now redundant update code in `Function`, but it also removes some from `Scan.perform`.
上级 0dd73dc4
......@@ -2152,6 +2152,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
fgraph.name = name
self.indices = indices
self.inputs = inputs
# TODO: Get rid of all this `expanded_inputs` nonsense
self.expanded_inputs = inputs
self.outputs = outputs
self.unpack_single = unpack_single
......
......@@ -110,6 +110,9 @@ def fgraph_updated_vars(fgraph, expanded_inputs):
Reconstruct the full "updates" dictionary, mapping from FunctionGraph input
variables to the fgraph outputs that will replace their values.
TODO: Get rid of all this `expanded_inputs` nonsense and use
only `fgraph.update_mapping`.
Returns
-------
dict variable -> variable
......@@ -555,6 +558,7 @@ class Function:
self._value = ValueAttribute()
self._container = ContainerAttribute()
# TODO: Get rid of all this `expanded_inputs` nonsense
assert len(self.maker.expanded_inputs) == len(self.input_storage)
# This is used only when `fn.need_update_inputs` is `False`, because
......@@ -1048,6 +1052,8 @@ class Function:
# WARNING: This circumvents the 'readonly' attribute in x
o_container.storage[0] = None
# TODO: Get rid of this and `expanded_inputs`, since all the VMs now
# perform the updates themselves
if getattr(self.fn, "need_update_inputs", True):
# Update the inputs that have an update function
for input, storage in reversed(
......@@ -1565,11 +1571,15 @@ class FunctionMaker:
self.linker = linker.accept(fgraph, profile=profile)
if hasattr(linker, "accept_var_updates"):
# hacky thing so VMLinker knows about updates
# TODO: This is a hack that makes `VMLinker` aware of updates;
# clean this up.
self.linker.accept_var_updates(fgraph_updated_vars(fgraph, inputs))
fgraph.name = name
self.indices = indices
self.inputs = inputs
# TODO: Get rid of all this `expanded_inputs` nonsense
self.expanded_inputs = inputs
self.outputs = outputs
self.unpack_single = unpack_single
......
......@@ -22,6 +22,7 @@ import numpy as np
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.link.utils import get_destroy_dependencies
logger = logging.getLogger("aesara.compile.profiling")
......@@ -1035,12 +1036,14 @@ class ProfileStats:
if isinstance(val, Constant):
compute_map[val][0] = 1
destroy_dependencies = get_destroy_dependencies(fgraph)
# Initial executable_nodes
executable_nodes = set()
for var in fgraph.inputs:
for c, _ in fgraph.clients[var]:
if c != "output":
deps = c.inputs + c.destroy_dependencies
deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps):
executable_nodes.add(c)
......@@ -1163,10 +1166,12 @@ class ProfileStats:
# smaller, stop this iteration, move to next node
done_dict[frozen_set] = max_mem_count
destroy_dependencies = get_destroy_dependencies(fgraph)
for var in node.outputs:
for c, _ in fgraph.clients[var]:
if c != "output":
deps = c.inputs + c.destroy_dependencies
deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps):
new_exec_nodes.add(c)
......
......@@ -1186,10 +1186,10 @@ def add_vm_configvars():
config.add(
"vm__lazy",
"Useful only for the vm linkers. When lazy is None,"
"Useful only for the VM Linkers. When lazy is None,"
" auto detect if lazy evaluation is needed and use the appropriate"
" version. If lazy is True/False, force the version used between"
" Loop/LoopGC and Stack.",
" version. If the C loop isn't being used and lazy is True, use "
"the Stack VM; otherwise, use the Loop VM.",
ConfigParam("None", apply=_filter_vm_lazy),
in_c_key=False,
)
......
......@@ -4,7 +4,7 @@ import re
import sys
import traceback
import warnings
from collections import Counter
from collections import Counter, defaultdict
from keyword import iskeyword
from operator import itemgetter
from tempfile import NamedTemporaryFile
......@@ -793,3 +793,22 @@ def {fgraph_name}({", ".join(fgraph_input_names)}):
)
return fgraph_def
def get_destroy_dependencies(fgraph: FunctionGraph) -> Dict[Apply, List[Variable]]:
    """Map each node to the variables it implicitly depends on via `Op.destroy_map` and `Op.view_map`.

    Unlike ``node.inputs``, which are a node's *explicit* dependencies, these
    are implicit ones: all the outputs of every prerequisite node given by
    ``fgraph.orderings()``.  After a key node's thunk has run, the associated
    variables can no longer be computed, because the thunk destroys a shared
    input needed by the nodes that own those variables.

    Returns
    -------
    A ``defaultdict`` mapping each `Apply` node to the list of variables it
    implicitly depends on (empty for nodes with no such prerequisites).
    """
    prereq_order = fgraph.orderings()
    result = defaultdict(list)
    for node in fgraph.apply_nodes:
        for prereq in prereq_order.get(node, ()):
            result[node].extend(prereq.outputs)
    return result
......@@ -5,33 +5,77 @@ A VM is not actually different from a Linker, we just decided
VM was a better name at some point.
"""
import logging
import platform
import sys
import time
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from itertools import zip_longest
from typing import (
TYPE_CHECKING,
Any,
DefaultDict,
Dict,
List,
Optional,
Sequence,
Tuple,
)
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.graph.basic import Apply, Constant, Variable
from aesara.link.basic import Container, LocalLinker
from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import gc_helper, map_storage, raise_with_op
logger = logging.getLogger(__name__)
from aesara.link.utils import (
gc_helper,
get_destroy_dependencies,
map_storage,
raise_with_op,
)
if TYPE_CHECKING:
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import (
BasicThunkType,
ComputeMapType,
StorageCellType,
StorageMapType,
)
def calculate_reallocate_info(
order: Sequence[Apply],
fgraph: "FunctionGraph",
storage_map: "StorageMapType",
compute_map_re: "ComputeMapType",
dependencies: Dict[Variable, List[Variable]],
) -> Dict[Variable, List[Variable]]:
"""Finds pairs of computed variables that can share a storage cell.
This apparently reduces memory allocations, but its scope is very limited
(e.g. only scalars, only used by the Python VMs without lazy computations).
Parameters
----------
order
List of nodes in compute order.
fgraph
The `FunctionGraph`.
storage_map
Map from variables to their storage cells.
compute_map_re
Reallocation map. TODO
dependencies
Map from variables to the variables that depend on them.
def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, dependencies):
"""
WRITEME : explain the parameters
"""
reallocated_info = {}
viewed_by = {}
viewed_by: Dict[Variable, List[Variable]] = {}
for var in fgraph.variables:
viewed_by[var] = []
view_of = {}
view_of: Dict[Variable, Variable] = {}
pre_allocated = set()
allocated = set()
......@@ -43,20 +87,20 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
idx_o = 0
for out in node.outputs:
for var in node.outputs:
compute_map_re[var][0] = 1
compute_map_re[var][0] = True
ins = None
if dmap and idx_o in dmap:
idx_v = dmap[idx_o]
assert len(idx_v) == 1, (
"Here we only support the possibility" " to destroy one input"
)
assert (
len(idx_v) == 1
), "Here we only support the possibility to destroy one input"
ins = node.inputs[idx_v[0]]
if vmap and idx_o in vmap:
assert ins is None
idx_v = vmap[idx_o]
assert len(idx_v) == 1, (
"Here we only support the possibility" " to view one input"
)
assert (
len(idx_v) == 1
), "Here we only support the possibility to view one input"
ins = node.inputs[idx_v[0]]
if ins is not None:
assert isinstance(ins, Variable)
......@@ -68,24 +112,24 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for ins in node.inputs:
assert not (ins in view_of and viewed_by[ins])
if (
getattr(ins, "ndim", None) == 0
getattr(ins.type, "ndim", None) == 0
and not storage_map[ins][0]
and ins not in fgraph.outputs
and ins.owner
and all(compute_map_re[v][0] for v in dependencies.get(ins, []))
and ins not in allocated
):
# Constant Memory cannot be changed
# Constant memory cannot be changed
# Constant and shared variables' storage_map value is not empty
reuse_out = None
if ins not in view_of and not viewed_by.get(ins, []):
# where gc
for i in range(idx + 1, len(order)):
if reuse_out is not None:
break
break # type: ignore
for out in order[i].outputs:
if (
getattr(out, "ndim", None) == 0
getattr(out.type, "ndim", None) == 0
and out not in pre_allocated
and out.type.in_same_class(ins.type)
):
......@@ -108,7 +152,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
break
for out in order[i].outputs:
if (
getattr(out, "ndim", None) == 0
getattr(out.type, "ndim", None) == 0
and out not in pre_allocated
and (out.type.in_same_class(ins.type))
):
......@@ -136,17 +180,6 @@ class VM(ABC):
advantage of lazy computation, although they still produce the correct
output for lazy nodes.
Parameters
----------
fgraph : FunctionGraph
The `FunctionGraph` associated with `nodes` and `thunks`.
nodes
A list of nodes in toposort order.
thunks
A list of thunks to execute those nodes, in toposort order.
pre_call_clear
A list of containers to empty at the beginning of each call.
Attributes
----------
call_counts
......@@ -165,7 +198,27 @@ class VM(ABC):
"""
def __init__(self, fgraph, nodes, thunks, pre_call_clear):
need_update_inputs = True
def __init__(
self,
fgraph: "FunctionGraph",
nodes: List[Apply],
thunks: List["BasicThunkType"],
pre_call_clear: List["StorageCellType"],
):
r"""
Parameters
----------
fgraph
The `FunctionGraph` associated with `nodes` and `thunks`.
nodes
A list of nodes in toposort order.
thunks
A list of thunks to execute those nodes, in toposort order.
pre_call_clear
A list of containers to empty at the beginning of each call.
"""
if len(nodes) != len(thunks):
raise ValueError("`nodes` and `thunks` must be the same length")
......@@ -178,11 +231,6 @@ class VM(ABC):
self.call_times = [0] * len(nodes)
self.time_thunks = False
# This variable (self.need_update_inputs) is overshadowed by
# CLazyLinker in CVM which has an attribute of the same name that
# defaults to 0 (aka False).
self.need_update_inputs = True
@abstractmethod
def __call__(self):
r"""Run the virtual machine.
......@@ -232,53 +280,109 @@ class VM(ABC):
self.call_counts[i] = 0
class Loop(VM):
"""
Unconditional start-to-finish program execution in Python.
No garbage collection is allowed on intermediate results.
class UpdatingVM(VM):
"""A `VM` that performs updates on its graph's inputs."""
"""
need_update_inputs = False
allow_gc = False
def __init__(
self,
fgraph,
nodes,
thunks,
pre_call_clear,
storage_map: "StorageMapType",
input_storage: List["StorageCellType"],
output_storage: List["StorageCellType"],
update_vars: Dict[Variable, Variable],
):
r"""
Parameters
----------
storage_map
A ``dict`` mapping `Variable`\s to single-element lists where a
computed value for each `Variable` may be found.
input_storage
Storage cells for each input.
output_storage
Storage cells for each output.
update_vars
A ``dict`` from input to output variables that specify
output-to-input in-place storage updates that occur after
evaluation of the entire graph (i.e. all the thunks).
"""
super().__init__(fgraph, nodes, thunks, pre_call_clear)
def __call__(self):
if self.time_thunks:
for cont in self.pre_call_clear:
cont[0] = None
try:
for i, (thunk, node) in enumerate(zip(self.thunks, self.nodes)):
t0 = time.time()
thunk()
t1 = time.time()
self.call_counts[i] += 1
self.call_times[i] += t1 - t0
except Exception:
raise_with_op(self.fgraph, node, thunk)
else:
for cont in self.pre_call_clear:
cont[0] = None
try:
for thunk, node in zip(self.thunks, self.nodes):
thunk()
except Exception:
raise_with_op(self.fgraph, node, thunk)
self.storage_map = storage_map
self.input_storage = input_storage
self.output_storage = output_storage
self.inp_storage_and_out_idx = tuple(
(inp_storage, self.fgraph.outputs.index(update_vars[inp]))
for inp, inp_storage in zip(self.fgraph.inputs, self.input_storage)
if inp in update_vars
)
def perform_updates(self) -> List[Any]:
"""Perform the output-to-input updates and return the output values."""
class LoopGC(VM):
# The outputs need to be collected *before* the updates that follow
outputs = [cell[0] for cell in self.output_storage]
for inp_storage, out_idx in self.inp_storage_and_out_idx:
inp_storage[0] = outputs[out_idx]
return outputs
class Loop(UpdatingVM):
"""Unconditional start-to-finish program execution in Python.
Garbage collection is possible on intermediate results.
Garbage collection is possible on intermediate results when the
`post_thunk_clear` constructor argument is non-``None``.
"""
def __init__(self, fgraph, nodes, thunks, pre_call_clear, post_thunk_clear):
super().__init__(fgraph, nodes, thunks, pre_call_clear)
self.post_thunk_clear = post_thunk_clear
# Some other part of Aesara query that information
self.allow_gc = True
if not (len(nodes) == len(thunks) == len(post_thunk_clear)):
raise ValueError(
"`nodes`, `thunks` and `post_thunk_clear` are not the same lengths"
)
def __init__(
self,
fgraph,
nodes,
thunks,
pre_call_clear,
storage_map,
input_storage,
output_storage,
update_vars,
post_thunk_clear: Optional[List["StorageCellType"]] = None,
):
r"""
Parameters
----------
post_thunk_clear
A list of storage cells for each thunk that should be cleared after
each thunk is evaluated. This is the "garbage collection"
functionality.
"""
super().__init__(
fgraph,
nodes,
thunks,
pre_call_clear,
storage_map,
input_storage,
output_storage,
update_vars,
)
if post_thunk_clear is not None:
if not (len(nodes) == len(thunks) == len(post_thunk_clear)):
raise ValueError(
"`nodes`, `thunks` and `post_thunk_clear` are not the same lengths"
)
# Some other part of Aesara use this information
self.allow_gc = True
self.post_thunk_clear = post_thunk_clear
else:
self.allow_gc = False
self.post_thunk_clear = []
def __call__(self):
if self.time_thunks:
......@@ -286,8 +390,8 @@ class LoopGC(VM):
cont[0] = None
try:
i = 0
for thunk, node, old_storage in zip(
self.thunks, self.nodes, self.post_thunk_clear
for thunk, node, old_storage in zip_longest(
self.thunks, self.nodes, self.post_thunk_clear, fillvalue=()
):
t0 = time.time()
thunk()
......@@ -303,8 +407,8 @@ class LoopGC(VM):
for cont in self.pre_call_clear:
cont[0] = None
try:
for thunk, node, old_storage in zip(
self.thunks, self.nodes, self.post_thunk_clear
for thunk, node, old_storage in zip_longest(
self.thunks, self.nodes, self.post_thunk_clear, fillvalue=()
):
thunk()
for old_s in old_storage:
......@@ -312,8 +416,10 @@ class LoopGC(VM):
except Exception:
raise_with_op(self.fgraph, node, thunk)
return self.perform_updates()
class Stack(VM):
class Stack(UpdatingVM):
"""Finish-to-start evaluation order of thunks.
This supports lazy evaluation of subtrees and partial computations of
......@@ -348,53 +454,53 @@ class Stack(VM):
thunks,
pre_call_clear,
storage_map,
compute_map,
allow_gc,
n_updates,
dependencies=None,
input_storage,
output_storage,
update_vars,
compute_map: "ComputeMapType",
allow_gc: bool,
dependencies: Optional[Dict[Variable, List[Variable]]] = None,
callback=None,
callback_input=None,
):
super().__init__(fgraph, nodes, thunks, pre_call_clear)
r"""
Parameters
----------
allow_gc
Determines whether or not garbage collection is performed.
dependencies
TODO
callback
TODO
callback_input
TODO
"""
super().__init__(
fgraph,
nodes,
thunks,
pre_call_clear,
storage_map,
input_storage,
output_storage,
update_vars,
)
self.update_vars = update_vars
self.compute_map = compute_map
self.allow_gc = allow_gc
self.message = ""
self.base_apply_stack = [o.owner for o in fgraph.outputs if o.owner]
self.outputs = fgraph.outputs
self.storage_map = storage_map
self.variable_shape = {} # Variable -> shape
self.variable_strides = {} # Variable -> strides
self.variable_offset = {} # Variable -> offset
self.compute_map = compute_map
self.node_idx = node_idx = {}
self.variable_shape: Dict[Variable, Any] = {} # Variable -> shape
self.variable_strides: Dict[Variable, Any] = {} # Variable -> strides
self.variable_offset: Dict[Variable, Any] = {} # Variable -> offset
node_idx = {node: i for i, node in enumerate(self.nodes)}
self.node_idx = node_idx
self.callback = callback
self.callback_input = callback_input
self.n_updates = n_updates
ords = fgraph.orderings()
for i, node in enumerate(self.nodes):
node_idx[node] = i
# XXX: inconsistent style - why modify node here rather
# than track destroy_dependencies with dictionary like
# storage_map?
#
# destroy_dependencies
# --------------------
# The destroy_dependencies is a list of variables that are implicit
# dependencies induced by destroy_map and view_map (compared to
# node.inputs which are *explicit* dependencies). The variables in
# destroy_dependencies would be impossible to compute after the
# current `node` runs, because node.thunk() is going to destroy a
# common input variable needed by whatever node owns each variable
# in destroy_depenencies.
node.destroy_dependencies = []
if node in ords:
for prereq in ords[node]:
node.destroy_dependencies += prereq.outputs
self.destroy_dependencies = get_destroy_dependencies(fgraph)
self.dependencies = dependencies
if self.allow_gc and self.dependencies is None:
......@@ -444,10 +550,14 @@ class Stack(VM):
# apply_stack contains nodes
if output_subset is not None:
first_updated = len(self.outputs) - self.n_updates
output_subset = output_subset + list(
range(first_updated, len(self.outputs))
)
# Add the outputs that are needed for the in-place updates of the
# inputs in `self.update_vars`
output_subset = list(output_subset)
for inp, out in self.update_vars.items():
out_idx = self.fgraph.outputs.index(out)
if out_idx not in output_subset:
output_subset.append(out_idx)
apply_stack = [
self.outputs[i].owner for i in output_subset if self.outputs[i].owner
]
......@@ -489,7 +599,7 @@ class Stack(VM):
current_apply = apply_stack.pop()
current_inputs = current_apply.inputs
current_outputs = current_apply.outputs
current_deps = current_inputs + current_apply.destroy_dependencies
current_deps = current_inputs + self.destroy_dependencies[current_apply]
computed_ins = all(compute_map[v][0] for v in current_deps)
computed_outs = all(compute_map[v][0] for v in current_outputs)
......@@ -671,6 +781,8 @@ class Stack(VM):
self.node_cleared_order.append(final_index)
return self.perform_updates()
class VMLinker(LocalLinker):
"""Class that satisfies the `Linker` interface by acting as a `VM` factory.
......@@ -697,7 +809,7 @@ class VMLinker(LocalLinker):
Aesara flag ``vm__lazy`` value. Then if we have a ``None`` (default) we
auto detect if lazy evaluation is needed and use the appropriate
version. If `lazy` is ``True`` or ``False``, we force the version used
between `Loop`/`LoopGC` and `Stack`.
between `Loop` and `Stack`.
c_thunks
If ``None`` or ``True``, don't change the default. If ``False``, don't
compile C code for the thunks.
......@@ -768,7 +880,7 @@ class VMLinker(LocalLinker):
TODO: change the logic to remove the reference at the end
of the call instead of the start. This will request all VM
implementation (Loop, LoopGC, Stack, CVM).__call__ to
implementation (Loop, Stack, CVM).__call__ to
return the user outputs as Function.__call__ won't be able
to find them anymore.
......@@ -804,9 +916,11 @@ class VMLinker(LocalLinker):
"""Records in the `Linker` which variables have update expressions.
It does not imply that the `Linker` will actually implement these updates
(see `need_update_inputs`). This mechanism is admittedly confusing, and
(see `VM.need_update_inputs`). This mechanism is admittedly confusing, and
it could use some cleaning up. The base `Linker` object should probably
go away completely.
TODO: Remove this after refactoring the `VM`/`Linker` interfaces.
"""
self.updated_vars = updated_vars
......@@ -853,6 +967,40 @@ class VMLinker(LocalLinker):
dependencies[k] += ls
return dependencies
def reduce_storage_allocations(
    self, storage_map: "StorageMapType", order: Sequence[Apply]
) -> Tuple[Variable, ...]:
    """Reuse storage cells in a storage map.

    `storage_map` is updated in-place.

    When this feature is used, `storage_map` will no longer have a
    one-to-one mapping with the original variables, because--for
    example--some outputs may share storage with intermediate values.

    Parameters
    ----------
    storage_map
        Map from variables to their single-element storage cells; mutated
        in-place so that some variables end up sharing a cell.
    order
        The nodes in computation order.

    Returns
    -------
    A tuple of the variables that were reallocated.
    """
    # Collect Reallocation Info
    # Seed a fresh compute map with the graph inputs marked as computed.
    compute_map_re: DefaultDict[Variable, List[bool]] = defaultdict(lambda: [False])
    for var in self.fgraph.inputs:
        compute_map_re[var][0] = True

    # Reuse dependencies already computed by profiling when available;
    # otherwise derive the garbage-collection dependencies here.
    if getattr(self.fgraph.profile, "dependencies", None):
        dependencies = self.fgraph.profile.dependencies
    else:
        dependencies = self.compute_gc_dependencies(storage_map)

    reallocated_info: Dict[Variable, List[Variable]] = calculate_reallocate_info(
        order, self.fgraph, storage_map, compute_map_re, dependencies
    )

    # Each value is a pair of variables whose storage can be shared; alias
    # the second variable's cell to the first's.
    # NOTE(review): the pair ordering appears to be ``[reused, reuser]`` --
    # confirm against `calculate_reallocate_info`.
    for pair in reallocated_info.values():
        storage_map[pair[1]] = storage_map[pair[0]]

    return tuple(reallocated_info.keys())
def make_vm(
self,
nodes,
......@@ -883,12 +1031,12 @@ class VMLinker(LocalLinker):
if self.use_cloop and (
self.callback is not None or self.callback_input is not None
):
logger.warning("CVM does not support callback, using Stack VM.")
warnings.warn("CVM does not support callback, using Stack VM.")
if self.use_cloop and config.profile_memory:
warnings.warn("CVM does not support memory profile, using Stack VM.")
warnings.warn("CVM does not support memory profiling, using Stack VM.")
if not self.use_cloop and self.allow_partial_eval:
warnings.warn(
"LoopGC does not support partial evaluation, " "using Stack VM."
"Loop VM does not support partial evaluation, using Stack VM."
)
# Needed for allow_gc=True, profiling and storage_map reuse
deps = self.compute_gc_dependencies(storage_map)
......@@ -898,9 +1046,11 @@ class VMLinker(LocalLinker):
thunks,
pre_call_clear,
storage_map,
input_storage,
output_storage,
updated_vars,
compute_map,
self.allow_gc,
len(updated_vars),
dependencies=deps,
callback=self.callback,
callback_input=self.callback_input,
......@@ -1021,7 +1171,7 @@ class VMLinker(LocalLinker):
if platform.python_implementation() == "CPython" and c0 != sys.getrefcount(
node_n_inputs
):
logger.warning(
warnings.warn(
"Detected reference count inconsistency after CVM construction"
)
else:
......@@ -1032,21 +1182,17 @@ class VMLinker(LocalLinker):
lazy = any(th.lazy for th in thunks)
if not lazy:
# there is no conditional in the graph
if self.allow_gc:
vm = LoopGC(
self.fgraph,
nodes,
thunks,
pre_call_clear,
post_thunk_clear,
)
else:
vm = Loop(
self.fgraph,
nodes,
thunks,
pre_call_clear,
)
vm = Loop(
self.fgraph,
nodes,
thunks,
pre_call_clear,
storage_map,
input_storage,
output_storage,
updated_vars,
post_thunk_clear if self.allow_gc else None,
)
else:
# Needed when allow_gc=True and profiling
deps = self.compute_gc_dependencies(storage_map)
......@@ -1056,9 +1202,11 @@ class VMLinker(LocalLinker):
thunks,
pre_call_clear,
storage_map,
input_storage,
output_storage,
updated_vars,
compute_map,
self.allow_gc,
len(updated_vars),
dependencies=deps,
)
return vm
......@@ -1082,19 +1230,6 @@ class VMLinker(LocalLinker):
thunks = []
# Collect Reallocation Info
compute_map_re = defaultdict(lambda: [0])
for var in fgraph.inputs:
compute_map_re[var][0] = 1
if getattr(fgraph.profile, "dependencies", None):
dependencies = fgraph.profile.dependencies
else:
dependencies = self.compute_gc_dependencies(storage_map)
reallocated_info = calculate_reallocate_info(
order, fgraph, storage_map, compute_map_re, dependencies
)
t0 = time.time()
linker_make_thunk_time = {}
impl = None
......@@ -1140,8 +1275,9 @@ class VMLinker(LocalLinker):
or self.callback
or self.callback_input
):
for pair in reallocated_info.values():
storage_map[pair[1]] = storage_map[pair[0]]
reallocated_vars = self.reduce_storage_allocations(storage_map, order)
else:
reallocated_vars = ()
computed, last_user = gc_helper(order)
if self.allow_gc:
......@@ -1153,7 +1289,7 @@ class VMLinker(LocalLinker):
input in computed
and input not in fgraph.outputs
and node == last_user[input]
and input not in reallocated_info
and input not in reallocated_vars
):
clear_after_this_thunk.append(storage_map[input])
post_thunk_clear.append(clear_after_this_thunk)
......
......@@ -1360,6 +1360,9 @@ def pydotprint(
# it, we must copy it.
outputs = list(outputs)
if isinstance(fct, Function):
# TODO: Get rid of all this `expanded_inputs` nonsense and use
# `fgraph.update_mapping`
function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs)
for i, fg_ii in reversed(list(function_inputs)):
if i.update is not None:
......
......@@ -15,6 +15,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Constant
from aesara.graph.opt import OpKeyOptimizer, PatternSub
from aesara.graph.utils import MissingInputError
from aesara.link.vm import VMLinker
from aesara.tensor.math import dot
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh
......@@ -227,19 +228,36 @@ class TestFunction:
# got multiple values for keyword argument 'x'
f(5.0, x=9)
def test_state_access(self):
a = scalar() # the a is for 'anonymous' (un-named).
@pytest.mark.parametrize(
"mode",
[
Mode(
linker=VMLinker(allow_gc=True, use_cloop=False, c_thunks=False),
optimizer="fast_compile",
),
Mode(
linker=VMLinker(allow_gc=True, use_cloop=False, c_thunks=False),
optimizer="fast_run",
),
Mode(linker="cvm", optimizer="fast_compile"),
Mode(linker="cvm", optimizer="fast_run"),
],
)
def test_state_access(self, mode):
a = scalar()
x, s = scalars("xs")
f = function(
[x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)],
s + a * x,
mode=mode,
)
assert f[a] == 1.0
assert f[s] == 0.0
assert f(3.0) == 3.0
assert f[s] == 3.0
assert f(3.0, a=2.0) == 9.0 # 3.0 + 2*3.0
assert (
......
......@@ -15,7 +15,7 @@ from aesara.ifelse import ifelse
from aesara.link.c.basic import OpWiseCLinker
from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import map_storage
from aesara.link.vm import VM, Loop, LoopGC, Stack, VMLinker
from aesara.link.vm import VM, Loop, Stack, VMLinker
from aesara.tensor.math import cosh, tanh
from aesara.tensor.type import lscalar, scalar, scalars, vector, vectors
from aesara.tensor.var import TensorConstant
......@@ -307,11 +307,6 @@ class RunOnce(Op):
def test_vm_gc():
# This already caused a bug in the trunk of Aesara.
#
# The bug was introduced in the trunk on July 5th, 2012 and fixed on
# July 30th.
x = vector()
p = RunOnce()(x)
mode = Mode(linker=VMLinker(lazy=True))
......@@ -445,7 +440,7 @@ def test_VM_exception():
SomeVM(fg, fg.apply_nodes, [], [])
def test_LoopGC_exception():
def test_Loop_exception():
a = scalar()
fg = FunctionGraph(outputs=[SomeOp()(a)])
......@@ -460,9 +455,99 @@ def test_LoopGC_exception():
for k in storage_map:
compute_map[k] = [k.owner is None]
thunks = [
node.op.make_thunk(node, storage_map, compute_map, True) for node in nodes
]
thunks = [node.op.make_thunk(node, storage_map, compute_map, []) for node in nodes]
with pytest.raises(ValueError, match="`nodes`, `thunks` and `post_thunk_clear`.*"):
LoopGC(fg, fg.apply_nodes, thunks, [], [])
Loop(
fg,
fg.apply_nodes,
thunks,
[],
storage_map,
input_storage,
output_storage,
{},
[],
)
def test_Loop_updates():
    """Check that `Loop` performs the output-to-input updates after evaluation."""
    a = scalar("a")
    a_plus_1 = a + 1

    fg = FunctionGraph(outputs=[a, a_plus_1], clone=False)

    nodes = fg.toposort()
    input_storage, output_storage, storage_map = map_storage(
        fg, nodes, None, None, None
    )
    compute_map = {var: [var.owner is None] for var in storage_map}
    thunks = [
        node.op.make_thunk(node, storage_map, compute_map, []) for node in nodes
    ]

    assert a in storage_map

    vm = Loop(
        fg,
        fg.apply_nodes,
        thunks,
        [],
        storage_map,
        input_storage,
        output_storage,
        {a: a_plus_1},
    )

    storage_map[a][0] = np.array(1.0, dtype=config.floatX)

    res = vm()

    assert res == [np.array(1.0), np.array(2.0)]
    # The update must have written the second output back into `a`'s cell.
    assert storage_map[a][0] == np.array(2.0)
def test_Stack_updates():
    """Check that `Stack` performs the output-to-input updates after evaluation."""
    a = scalar("a")
    a_plus_1 = a + 1

    fg = FunctionGraph(outputs=[a, a_plus_1], clone=False)

    nodes = fg.toposort()
    input_storage, output_storage, storage_map = map_storage(
        fg, nodes, None, None, None
    )
    compute_map = {var: [var.owner is None] for var in storage_map}
    thunks = [
        node.op.make_thunk(node, storage_map, compute_map, []) for node in nodes
    ]

    assert a in storage_map

    vm = Stack(
        fg,
        fg.apply_nodes,
        thunks,
        [],
        storage_map,
        input_storage,
        output_storage,
        {a: a_plus_1},
        compute_map,
        False,
    )

    storage_map[a][0] = np.array(1.0, dtype=config.floatX)

    res = vm()

    assert res == [np.array(1.0), np.array(2.0)]
    # The update must have written the second output back into `a`'s cell.
    assert storage_map[a][0] == np.array(2.0)
......@@ -1482,6 +1482,14 @@ class TestScan:
@pytest.mark.slow
def test_grad_multiple_outs_taps_backwards(self):
"""
This test is special because when the inner-graph compilation is set to
"fast_compile", and the `Loop` `VM` is used, its inner-graph will
reallocate storage cells, which is a good test for correct, direct
storage use in `Scan.perform`.
TODO: Create a much more direct test.
"""
n = 5
rng = np.random.default_rng(utt.fetch_seed())
vW_in2 = asarrayX(rng.uniform(-0.2, 0.2, size=(2,)))
......@@ -1591,6 +1599,7 @@ class TestScan:
)
def reset_rng_fn(fn, *args):
# TODO: Get rid of all this `expanded_inputs` nonsense
for idx, arg in enumerate(fn.maker.expanded_inputs):
if arg.value and isinstance(arg.value.data, np.random.Generator):
obj = fn.maker.expanded_inputs[idx].value
......@@ -1673,6 +1682,7 @@ class TestScan:
)
def reset_rng_fn(fn, *args):
# TODO: Get rid of all this `expanded_inputs` nonsense
for idx, arg in enumerate(fn.maker.expanded_inputs):
if arg.value and isinstance(arg.value.data, np.random.Generator):
obj = fn.maker.expanded_inputs[idx].value
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论