Commit b7589469, authored and committed by Brandon T. Willard

Refactor VMs and add variable update functionality to all subclasses

All VM implementations now perform variable updates themselves. This leaves some now redundant update code in `Function`, but it also removes some from `Scan.perform`.
Parent commit: 0dd73dc4
...@@ -2152,6 +2152,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2152,6 +2152,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
fgraph.name = name fgraph.name = name
self.indices = indices self.indices = indices
self.inputs = inputs self.inputs = inputs
# TODO: Get rid of all this `expanded_inputs` nonsense
self.expanded_inputs = inputs self.expanded_inputs = inputs
self.outputs = outputs self.outputs = outputs
self.unpack_single = unpack_single self.unpack_single = unpack_single
......
...@@ -110,6 +110,9 @@ def fgraph_updated_vars(fgraph, expanded_inputs): ...@@ -110,6 +110,9 @@ def fgraph_updated_vars(fgraph, expanded_inputs):
Reconstruct the full "updates" dictionary, mapping from FunctionGraph input Reconstruct the full "updates" dictionary, mapping from FunctionGraph input
variables to the fgraph outputs that will replace their values. variables to the fgraph outputs that will replace their values.
TODO: Get rid of all this `expanded_inputs` nonsense and use
only `fgraph.update_mapping`.
Returns Returns
------- -------
dict variable -> variable dict variable -> variable
...@@ -555,6 +558,7 @@ class Function: ...@@ -555,6 +558,7 @@ class Function:
self._value = ValueAttribute() self._value = ValueAttribute()
self._container = ContainerAttribute() self._container = ContainerAttribute()
# TODO: Get rid of all this `expanded_inputs` nonsense
assert len(self.maker.expanded_inputs) == len(self.input_storage) assert len(self.maker.expanded_inputs) == len(self.input_storage)
# This is used only when `fn.need_update_inputs` is `False`, because # This is used only when `fn.need_update_inputs` is `False`, because
...@@ -1048,6 +1052,8 @@ class Function: ...@@ -1048,6 +1052,8 @@ class Function:
# WARNING: This circumvents the 'readonly' attribute in x # WARNING: This circumvents the 'readonly' attribute in x
o_container.storage[0] = None o_container.storage[0] = None
# TODO: Get rid of this and `expanded_inputs`, since all the VMs now
# perform the updates themselves
if getattr(self.fn, "need_update_inputs", True): if getattr(self.fn, "need_update_inputs", True):
# Update the inputs that have an update function # Update the inputs that have an update function
for input, storage in reversed( for input, storage in reversed(
...@@ -1565,11 +1571,15 @@ class FunctionMaker: ...@@ -1565,11 +1571,15 @@ class FunctionMaker:
self.linker = linker.accept(fgraph, profile=profile) self.linker = linker.accept(fgraph, profile=profile)
if hasattr(linker, "accept_var_updates"): if hasattr(linker, "accept_var_updates"):
# hacky thing so VMLinker knows about updates # TODO: This is a hack that makes `VMLinker` aware of updates;
# clean this up.
self.linker.accept_var_updates(fgraph_updated_vars(fgraph, inputs)) self.linker.accept_var_updates(fgraph_updated_vars(fgraph, inputs))
fgraph.name = name fgraph.name = name
self.indices = indices self.indices = indices
self.inputs = inputs self.inputs = inputs
# TODO: Get rid of all this `expanded_inputs` nonsense
self.expanded_inputs = inputs self.expanded_inputs = inputs
self.outputs = outputs self.outputs = outputs
self.unpack_single = unpack_single self.unpack_single = unpack_single
......
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
import aesara import aesara
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
from aesara.link.utils import get_destroy_dependencies
logger = logging.getLogger("aesara.compile.profiling") logger = logging.getLogger("aesara.compile.profiling")
...@@ -1035,12 +1036,14 @@ class ProfileStats: ...@@ -1035,12 +1036,14 @@ class ProfileStats:
if isinstance(val, Constant): if isinstance(val, Constant):
compute_map[val][0] = 1 compute_map[val][0] = 1
destroy_dependencies = get_destroy_dependencies(fgraph)
# Initial executable_nodes # Initial executable_nodes
executable_nodes = set() executable_nodes = set()
for var in fgraph.inputs: for var in fgraph.inputs:
for c, _ in fgraph.clients[var]: for c, _ in fgraph.clients[var]:
if c != "output": if c != "output":
deps = c.inputs + c.destroy_dependencies deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps): if all(compute_map[v][0] for v in deps):
executable_nodes.add(c) executable_nodes.add(c)
...@@ -1163,10 +1166,12 @@ class ProfileStats: ...@@ -1163,10 +1166,12 @@ class ProfileStats:
# smaller, stop this iteration, move to next node # smaller, stop this iteration, move to next node
done_dict[frozen_set] = max_mem_count done_dict[frozen_set] = max_mem_count
destroy_dependencies = get_destroy_dependencies(fgraph)
for var in node.outputs: for var in node.outputs:
for c, _ in fgraph.clients[var]: for c, _ in fgraph.clients[var]:
if c != "output": if c != "output":
deps = c.inputs + c.destroy_dependencies deps = c.inputs + destroy_dependencies[c]
if all(compute_map[v][0] for v in deps): if all(compute_map[v][0] for v in deps):
new_exec_nodes.add(c) new_exec_nodes.add(c)
......
...@@ -1186,10 +1186,10 @@ def add_vm_configvars(): ...@@ -1186,10 +1186,10 @@ def add_vm_configvars():
config.add( config.add(
"vm__lazy", "vm__lazy",
"Useful only for the vm linkers. When lazy is None," "Useful only for the VM Linkers. When lazy is None,"
" auto detect if lazy evaluation is needed and use the appropriate" " auto detect if lazy evaluation is needed and use the appropriate"
" version. If lazy is True/False, force the version used between" " version. If the C loop isn't being used and lazy is True, use "
" Loop/LoopGC and Stack.", "the Stack VM; otherwise, use the Loop VM.",
ConfigParam("None", apply=_filter_vm_lazy), ConfigParam("None", apply=_filter_vm_lazy),
in_c_key=False, in_c_key=False,
) )
......
...@@ -4,7 +4,7 @@ import re ...@@ -4,7 +4,7 @@ import re
import sys import sys
import traceback import traceback
import warnings import warnings
from collections import Counter from collections import Counter, defaultdict
from keyword import iskeyword from keyword import iskeyword
from operator import itemgetter from operator import itemgetter
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
...@@ -793,3 +793,22 @@ def {fgraph_name}({", ".join(fgraph_input_names)}): ...@@ -793,3 +793,22 @@ def {fgraph_name}({", ".join(fgraph_input_names)}):
) )
return fgraph_def return fgraph_def
def get_destroy_dependencies(fgraph: FunctionGraph) -> Dict[Apply, List[Variable]]:
"""Construct a ``dict`` of nodes to variables that are implicit dependencies induced by `Op.destroy_map` and `Op.view_map`
These variable dependencies are in contrast to each node's inputs, which
are _explicit_ dependencies.
The variables in the values of this ``dict`` would be impossible to compute
after the current key nodes are evaluated, because node.thunk() is going to
destroy a common input variable needed by whatever node owns each variable
in destroy_dependencies.
"""
order = fgraph.orderings()
destroy_dependencies = defaultdict(lambda: [])
for node in fgraph.apply_nodes:
for prereq in order.get(node, []):
destroy_dependencies[node].extend(prereq.outputs)
return destroy_dependencies
差异被折叠。
...@@ -1360,6 +1360,9 @@ def pydotprint( ...@@ -1360,6 +1360,9 @@ def pydotprint(
# it, we must copy it. # it, we must copy it.
outputs = list(outputs) outputs = list(outputs)
if isinstance(fct, Function): if isinstance(fct, Function):
# TODO: Get rid of all this `expanded_inputs` nonsense and use
# `fgraph.update_mapping`
function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs) function_inputs = zip(fct.maker.expanded_inputs, fgraph.inputs)
for i, fg_ii in reversed(list(function_inputs)): for i, fg_ii in reversed(list(function_inputs)):
if i.update is not None: if i.update is not None:
......
...@@ -15,6 +15,7 @@ from aesara.configdefaults import config ...@@ -15,6 +15,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Constant from aesara.graph.basic import Constant
from aesara.graph.opt import OpKeyOptimizer, PatternSub from aesara.graph.opt import OpKeyOptimizer, PatternSub
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from aesara.link.vm import VMLinker
from aesara.tensor.math import dot from aesara.tensor.math import dot
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.math import tanh from aesara.tensor.math import tanh
...@@ -227,19 +228,36 @@ class TestFunction: ...@@ -227,19 +228,36 @@ class TestFunction:
# got multiple values for keyword argument 'x' # got multiple values for keyword argument 'x'
f(5.0, x=9) f(5.0, x=9)
def test_state_access(self): @pytest.mark.parametrize(
a = scalar() # the a is for 'anonymous' (un-named). "mode",
[
Mode(
linker=VMLinker(allow_gc=True, use_cloop=False, c_thunks=False),
optimizer="fast_compile",
),
Mode(
linker=VMLinker(allow_gc=True, use_cloop=False, c_thunks=False),
optimizer="fast_run",
),
Mode(linker="cvm", optimizer="fast_compile"),
Mode(linker="cvm", optimizer="fast_run"),
],
)
def test_state_access(self, mode):
a = scalar()
x, s = scalars("xs") x, s = scalars("xs")
f = function( f = function(
[x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)], [x, In(a, value=1.0, name="a"), In(s, value=0.0, update=s + a * x)],
s + a * x, s + a * x,
mode=mode,
) )
assert f[a] == 1.0 assert f[a] == 1.0
assert f[s] == 0.0 assert f[s] == 0.0
assert f(3.0) == 3.0 assert f(3.0) == 3.0
assert f[s] == 3.0
assert f(3.0, a=2.0) == 9.0 # 3.0 + 2*3.0 assert f(3.0, a=2.0) == 9.0 # 3.0 + 2*3.0
assert ( assert (
......
...@@ -15,7 +15,7 @@ from aesara.ifelse import ifelse ...@@ -15,7 +15,7 @@ from aesara.ifelse import ifelse
from aesara.link.c.basic import OpWiseCLinker from aesara.link.c.basic import OpWiseCLinker
from aesara.link.c.exceptions import MissingGXX from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import map_storage from aesara.link.utils import map_storage
from aesara.link.vm import VM, Loop, LoopGC, Stack, VMLinker from aesara.link.vm import VM, Loop, Stack, VMLinker
from aesara.tensor.math import cosh, tanh from aesara.tensor.math import cosh, tanh
from aesara.tensor.type import lscalar, scalar, scalars, vector, vectors from aesara.tensor.type import lscalar, scalar, scalars, vector, vectors
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
...@@ -307,11 +307,6 @@ class RunOnce(Op): ...@@ -307,11 +307,6 @@ class RunOnce(Op):
def test_vm_gc(): def test_vm_gc():
# This already caused a bug in the trunk of Aesara.
#
# The bug was introduced in the trunk on July 5th, 2012 and fixed on
# July 30th.
x = vector() x = vector()
p = RunOnce()(x) p = RunOnce()(x)
mode = Mode(linker=VMLinker(lazy=True)) mode = Mode(linker=VMLinker(lazy=True))
...@@ -445,7 +440,7 @@ def test_VM_exception(): ...@@ -445,7 +440,7 @@ def test_VM_exception():
SomeVM(fg, fg.apply_nodes, [], []) SomeVM(fg, fg.apply_nodes, [], [])
def test_LoopGC_exception(): def test_Loop_exception():
a = scalar() a = scalar()
fg = FunctionGraph(outputs=[SomeOp()(a)]) fg = FunctionGraph(outputs=[SomeOp()(a)])
...@@ -460,9 +455,99 @@ def test_LoopGC_exception(): ...@@ -460,9 +455,99 @@ def test_LoopGC_exception():
for k in storage_map: for k in storage_map:
compute_map[k] = [k.owner is None] compute_map[k] = [k.owner is None]
thunks = [ thunks = [node.op.make_thunk(node, storage_map, compute_map, []) for node in nodes]
node.op.make_thunk(node, storage_map, compute_map, True) for node in nodes
]
with pytest.raises(ValueError, match="`nodes`, `thunks` and `post_thunk_clear`.*"): with pytest.raises(ValueError, match="`nodes`, `thunks` and `post_thunk_clear`.*"):
LoopGC(fg, fg.apply_nodes, thunks, [], []) Loop(
fg,
fg.apply_nodes,
thunks,
[],
storage_map,
input_storage,
output_storage,
{},
[],
)
def test_Loop_updates():
    """Check that `Loop` writes updated values back into its storage cells."""
    a = scalar("a")
    a_plus_1 = a + 1

    fg = FunctionGraph(outputs=[a, a_plus_1], clone=False)
    order = fg.toposort()

    input_storage, output_storage, storage_map = map_storage(
        fg, order, None, None, None
    )
    compute_map = {var: [var.owner is None] for var in storage_map}
    thunks = [
        node.op.make_thunk(node, storage_map, compute_map, []) for node in order
    ]

    assert a in storage_map

    vm = Loop(
        fg,
        fg.apply_nodes,
        thunks,
        [],
        storage_map,
        input_storage,
        output_storage,
        {a: a_plus_1},
    )

    storage_map[a][0] = np.array(1.0, dtype=config.floatX)

    res = vm()

    assert res == [np.array(1.0), np.array(2.0)]
    # The update `a -> a + 1` must be reflected in `a`'s storage cell
    assert storage_map[a][0] == np.array(2.0)
def test_Stack_updates():
    """Check that `Stack` writes updated values back into its storage cells."""
    a = scalar("a")
    a_plus_1 = a + 1

    fg = FunctionGraph(outputs=[a, a_plus_1], clone=False)
    order = fg.toposort()

    input_storage, output_storage, storage_map = map_storage(
        fg, order, None, None, None
    )
    compute_map = {var: [var.owner is None] for var in storage_map}
    thunks = [
        node.op.make_thunk(node, storage_map, compute_map, []) for node in order
    ]

    assert a in storage_map

    vm = Stack(
        fg,
        fg.apply_nodes,
        thunks,
        [],
        storage_map,
        input_storage,
        output_storage,
        {a: a_plus_1},
        compute_map,
        False,
    )

    storage_map[a][0] = np.array(1.0, dtype=config.floatX)

    res = vm()

    assert res == [np.array(1.0), np.array(2.0)]
    # The update `a -> a + 1` must be reflected in `a`'s storage cell
    assert storage_map[a][0] == np.array(2.0)
...@@ -1482,6 +1482,14 @@ class TestScan: ...@@ -1482,6 +1482,14 @@ class TestScan:
@pytest.mark.slow @pytest.mark.slow
def test_grad_multiple_outs_taps_backwards(self): def test_grad_multiple_outs_taps_backwards(self):
"""
This test is special because when the inner-graph compilation is set to
"fast_compile", and the `Loop` `VM` is used, its inner-graph will
reallocate storage cells, which is a good test for correct, direct
storage use in `Scan.perform`.
TODO: Create a much more direct test.
"""
n = 5 n = 5
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
vW_in2 = asarrayX(rng.uniform(-0.2, 0.2, size=(2,))) vW_in2 = asarrayX(rng.uniform(-0.2, 0.2, size=(2,)))
...@@ -1591,6 +1599,7 @@ class TestScan: ...@@ -1591,6 +1599,7 @@ class TestScan:
) )
def reset_rng_fn(fn, *args): def reset_rng_fn(fn, *args):
# TODO: Get rid of all this `expanded_inputs` nonsense
for idx, arg in enumerate(fn.maker.expanded_inputs): for idx, arg in enumerate(fn.maker.expanded_inputs):
if arg.value and isinstance(arg.value.data, np.random.Generator): if arg.value and isinstance(arg.value.data, np.random.Generator):
obj = fn.maker.expanded_inputs[idx].value obj = fn.maker.expanded_inputs[idx].value
...@@ -1673,6 +1682,7 @@ class TestScan: ...@@ -1673,6 +1682,7 @@ class TestScan:
) )
def reset_rng_fn(fn, *args): def reset_rng_fn(fn, *args):
# TODO: Get rid of all this `expanded_inputs` nonsense
for idx, arg in enumerate(fn.maker.expanded_inputs): for idx, arg in enumerate(fn.maker.expanded_inputs):
if arg.value and isinstance(arg.value.data, np.random.Generator): if arg.value and isinstance(arg.value.data, np.random.Generator):
obj = fn.maker.expanded_inputs[idx].value obj = fn.maker.expanded_inputs[idx].value
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论