Commit b7589469 authored by Brandon T. Willard, committed by Brandon T. Willard

Refactor VMs and add variable update functionality to all subclasses

All VM implementations now perform variable updates themselves. This leaves some now redundant update code in `Function`, but it also removes some from `Scan.perform`.
Parent 0dd73dc4
......@@ -2152,6 +2152,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
fgraph.name = name
self.indices = indices
self.inputs = inputs
# TODO: Get rid of all this `expanded_inputs` nonsense
self.expanded_inputs = inputs
self.outputs = outputs
self.unpack_single = unpack_single
......
......@@ -110,6 +110,9 @@ def fgraph_updated_vars(fgraph, expanded_inputs):
Reconstruct the full "updates" dictionary, mapping from FunctionGraph input
variables to the fgraph outputs that will replace their values.
TODO: Get rid of all this `expanded_inputs` nonsense and use
only `fgraph.update_mapping`.
Returns
-------
dict variable -> variable
......@@ -555,6 +558,7 @@ class Function:
self._value = ValueAttribute()
self._container = ContainerAttribute()
# TODO: Get rid of all this `expanded_inputs` nonsense
assert len(self.maker.expanded_inputs) == len(self.input_storage)
# This is used only when `fn.need_update_inputs` is `False`, because
......@@ -1048,6 +1052,8 @@ class Function:
# WARNING: This circumvents the 'readonly' attribute in x
o_container.storage[0] = None
# TODO: Get rid of this and `expanded_inputs`, since all the VMs now
# perform the updates themselves
if getattr(self.fn, "need_update_inputs", True):
# Update the inputs that have an update function
for input, storage in reversed(
......@@ -1565,11 +1571,15 @@ class FunctionMaker:
self.linker = linker.accept(fgraph, profile=profile)
if hasattr(linker, "accept_var_updates"):
# hacky thing so VMLinker knows about updates
# TODO: This is a hack that makes `VMLinker` aware of updates;
# clean this up.
self.linker.accept_var_updates(fgraph_updated_vars(fgraph, inputs))
fgraph.name = name
self.indices = indices
self.inputs = inputs
# TODO: Get rid of all this `expanded_inputs` nonsense
self.expanded_inputs = inputs
self.outputs = outputs
self.unpack_single = unpack_single
......
......@@ -22,6 +22,7 @@ import numpy as np
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable
from aesara.link.utils import get_destroy_dependencies
logger = logging.getLogger("aesara.compile.profiling")
......@@ -1035,12 +1036,14 @@ class ProfileStats:
if isinstance(val, Constant):
compute_map[val][0] = 1
destroy_dependencies = get_destroy_dependencies(fgraph)
# Initial executable_nodes
executable_nodes = set()
for var in fgraph.inputs:
for c, _ in fgraph.clients[var]:
if c != "output":
deps = c.inputs + c.destroy_dependencies
deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps):
executable_nodes.add(c)
......@@ -1163,10 +1166,12 @@ class ProfileStats:
# smaller, stop this iteration, move to next node
done_dict[frozen_set] = max_mem_count
destroy_dependencies = get_destroy_dependencies(fgraph)
for var in node.outputs:
for c, _ in fgraph.clients[var]:
if c != "output":
deps = c.inputs + c.destroy_dependencies
deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps):
new_exec_nodes.add(c)
......
......@@ -1186,10 +1186,10 @@ def add_vm_configvars():
config.add(
"vm__lazy",
"Useful only for the vm linkers. When lazy is None,"
"Useful only for the VM Linkers. When lazy is None,"
" auto detect if lazy evaluation is needed and use the appropriate"
" version. If lazy is True/False, force the version used between"
" Loop/LoopGC and Stack.",
" version. If the C loop isn't being used and lazy is True, use "
"the Stack VM; otherwise, use the Loop VM.",
ConfigParam("None", apply=_filter_vm_lazy),
in_c_key=False,
)
......
......@@ -4,7 +4,7 @@ import re
import sys
import traceback
import warnings
from collections import Counter
from collections import Counter, defaultdict
from keyword import iskeyword
from operator import itemgetter
from tempfile import NamedTemporaryFile
......@@ -793,3 +793,22 @@ def {fgraph_name}({", ".join(fgraph_input_names)}):
)
return fgraph_def
def get_destroy_dependencies(fgraph: "FunctionGraph") -> Dict["Apply", List["Variable"]]:
    """Construct a ``dict`` mapping nodes to variables that are implicit dependencies induced by `Op.destroy_map` and `Op.view_map`.

    These variable dependencies are in contrast to each node's inputs, which
    are *explicit* dependencies.

    The variables in the values of this ``dict`` would be impossible to
    compute after the corresponding key node is evaluated, because
    ``node.thunk()`` is going to destroy a common input variable needed by
    whatever node owns each variable in the result.

    Parameters
    ----------
    fgraph
        The graph whose `FunctionGraph.orderings` are consulted to determine
        the destroy/view-induced prerequisite nodes.

    Returns
    -------
    A ``defaultdict`` mapping each node in ``fgraph.apply_nodes`` to a
    (possibly empty) list of the outputs of its prerequisite nodes.

    """
    order = fgraph.orderings()
    # `defaultdict(list)` instead of `defaultdict(lambda: [])`: same behavior,
    # standard idiom, and avoids an unnecessary lambda allocation.
    destroy_dependencies = defaultdict(list)
    for node in fgraph.apply_nodes:
        for prereq in order.get(node, []):
            destroy_dependencies[node].extend(prereq.outputs)
    return destroy_dependencies
Diff collapsed.
......@@ -1360,6 +1360,9 @@ def pydotprint(
# it, we must copy it.
outputs = list(outputs)
if isinstance(fct, Function):
# TODO: Get rid of all this `expanded_inputs` nonsense and use
# `fgraph.update_mapping`
function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs)
for i, fg_ii in reversed(list(function_inputs)):
if i.update is not None:
......
......@@ -15,6 +15,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Constant
from aesara.graph.opt import OpKeyOptimizer, PatternSub
from aesara.graph.utils import MissingInputError
from aesara.link.vm import VMLinker
from aesara.tensor.math import dot
from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh
......@@ -227,19 +228,36 @@ class TestFunction:
# got multiple values for keyword argument 'x'
f(5.0, x=9)
def test_state_access(self):
a = scalar() # the a is for 'anonymous' (un-named).
@pytest.mark.parametrize(
"mode",
[
Mode(
linker=VMLinker(allow_gc=True, use_cloop=False, c_thunks=False),
optimizer="fast_compile",
),
Mode(
linker=VMLinker(allow_gc=True, use_cloop=False, c_thunks=False),
optimizer="fast_run",
),
Mode(linker="cvm", optimizer="fast_compile"),
Mode(linker="cvm", optimizer="fast_run"),
],
)
def test_state_access(self, mode):
a = scalar()
x, s = scalars("xs")
f = function(
[x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)],
s + a * x,
mode=mode,
)
assert f[a] == 1.0
assert f[s] == 0.0
assert f(3.0) == 3.0
assert f[s] == 3.0
assert f(3.0, a=2.0) == 9.0 # 3.0 + 2*3.0
assert (
......
......@@ -15,7 +15,7 @@ from aesara.ifelse import ifelse
from aesara.link.c.basic import OpWiseCLinker
from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import map_storage
from aesara.link.vm import VM, Loop, LoopGC, Stack, VMLinker
from aesara.link.vm import VM, Loop, Stack, VMLinker
from aesara.tensor.math import cosh, tanh
from aesara.tensor.type import lscalar, scalar, scalars, vector, vectors
from aesara.tensor.var import TensorConstant
......@@ -307,11 +307,6 @@ class RunOnce(Op):
def test_vm_gc():
# This already caused a bug in the trunk of Aesara.
#
# The bug was introduced in the trunk on July 5th, 2012 and fixed on
# July 30th.
x = vector()
p = RunOnce()(x)
mode = Mode(linker=VMLinker(lazy=True))
......@@ -445,7 +440,7 @@ def test_VM_exception():
SomeVM(fg, fg.apply_nodes, [], [])
def test_LoopGC_exception():
def test_Loop_exception():
a = scalar()
fg = FunctionGraph(outputs=[SomeOp()(a)])
......@@ -460,9 +455,99 @@ def test_LoopGC_exception():
for k in storage_map:
compute_map[k] = [k.owner is None]
thunks = [
node.op.make_thunk(node, storage_map, compute_map, True) for node in nodes
]
thunks = [node.op.make_thunk(node, storage_map, compute_map, []) for node in nodes]
with pytest.raises(ValueError, match="`nodes`, `thunks` and `post_thunk_clear`.*"):
LoopGC(fg, fg.apply_nodes, thunks, [], [])
Loop(
fg,
fg.apply_nodes,
thunks,
[],
storage_map,
input_storage,
output_storage,
{},
[],
)
def test_Loop_updates():
    """Check that the `Loop` VM applies its ``update_vars`` mapping to the input storage."""
    a = scalar("a")
    a_plus_1 = a + 1

    fg = FunctionGraph(outputs=[a, a_plus_1], clone=False)
    order = fg.toposort()

    input_storage, output_storage, storage_map = map_storage(
        fg, order, None, None, None
    )
    # Inputs (owner-less variables) are marked as already computed.
    compute_map = {var: [var.owner is None] for var in storage_map}

    thunks = [
        node.op.make_thunk(node, storage_map, compute_map, []) for node in order
    ]

    assert a in storage_map

    loop_vm = Loop(
        fg,
        fg.apply_nodes,
        thunks,
        [],
        storage_map,
        input_storage,
        output_storage,
        {a: a_plus_1},
    )

    storage_map[a][0] = np.array(1.0, dtype=config.floatX)
    results = loop_vm()

    assert results == [np.array(1.0), np.array(2.0)]
    # The VM itself should have written the updated value back into `a`'s cell.
    assert storage_map[a][0] == np.array(2.0)
def test_Stack_updates():
    """Check that the `Stack` VM applies its ``update_vars`` mapping to the input storage."""
    a = scalar("a")
    a_plus_1 = a + 1

    fg = FunctionGraph(outputs=[a, a_plus_1], clone=False)
    order = fg.toposort()

    input_storage, output_storage, storage_map = map_storage(
        fg, order, None, None, None
    )
    # Inputs (owner-less variables) are marked as already computed.
    compute_map = {var: [var.owner is None] for var in storage_map}

    thunks = [
        node.op.make_thunk(node, storage_map, compute_map, []) for node in order
    ]

    assert a in storage_map

    stack_vm = Stack(
        fg,
        fg.apply_nodes,
        thunks,
        [],
        storage_map,
        input_storage,
        output_storage,
        {a: a_plus_1},
        compute_map,
        False,
    )

    storage_map[a][0] = np.array(1.0, dtype=config.floatX)
    results = stack_vm()

    assert results == [np.array(1.0), np.array(2.0)]
    # The VM itself should have written the updated value back into `a`'s cell.
    assert storage_map[a][0] == np.array(2.0)
......@@ -1482,6 +1482,14 @@ class TestScan:
@pytest.mark.slow
def test_grad_multiple_outs_taps_backwards(self):
"""
This test is special because when the inner-graph compilation is set to
"fast_compile", and the `Loop` `VM` is used, its inner-graph will
reallocate storage cells, which is a good test for correct, direct
storage use in `Scan.perform`.
TODO: Create a much more direct test.
"""
n = 5
rng = np.random.default_rng(utt.fetch_seed())
vW_in2 = asarrayX(rng.uniform(-0.2, 0.2, size=(2,)))
......@@ -1591,6 +1599,7 @@ class TestScan:
)
def reset_rng_fn(fn, *args):
# TODO: Get rid of all this `expanded_inputs` nonsense
for idx, arg in enumerate(fn.maker.expanded_inputs):
if arg.value and isinstance(arg.value.data, np.random.Generator):
obj = fn.maker.expanded_inputs[idx].value
......@@ -1673,6 +1682,7 @@ class TestScan:
)
def reset_rng_fn(fn, *args):
# TODO: Get rid of all this `expanded_inputs` nonsense
for idx, arg in enumerate(fn.maker.expanded_inputs):
if arg.value and isinstance(arg.value.data, np.random.Generator):
obj = fn.maker.expanded_inputs[idx].value
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment