提交 beb510e6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Change the update_storage parameter in CLazyLinker

This changes the `update_storage` parameter from a list containing the input indices that are to be updated with the last N-many outputs to a tuple of tuples specifying input/output indices. Now, arbitrary output-to-input update pairings are possible, instead of forcing graphs and code to compensate for this unnecessary restriction.
上级 29b1ff7b
......@@ -116,12 +116,14 @@ def fgraph_updated_vars(fgraph, expanded_inputs):
"""
updated_vars = {}
potential_values = list(fgraph.outputs) # copy the list
if len(expanded_inputs) != len(fgraph.inputs):
raise ValueError("expanded_inputs must match len(fgraph.inputs)")
for e_input, ivar in reversed(list(zip(expanded_inputs, fgraph.inputs))):
if e_input.update is not None:
updated_vars[ivar] = potential_values.pop()
for out_idx, in_idx in fgraph.update_mapping.items():
assert expanded_inputs[in_idx].update is not None
updated_vars[fgraph.inputs[in_idx]] = fgraph.outputs[out_idx]
return updated_vars
......
......@@ -16,7 +16,7 @@ from aesara.link.c.cmodule import GCC_compiler
_logger = logging.getLogger(__file__)
force_compile = False
version = 0.211 # must match constant returned in function get_version()
version = 0.212 # must match constant returned in function get_version()
lazylinker_ext: Optional[ModuleType] = None
......
......@@ -801,12 +801,14 @@ class VMLinker(LocalLinker):
return self
def accept_var_updates(self, updated_vars):
"""Records in the `Linker` which variables have update expressions.
It does not imply that the `Linker` will actually implement these updates
(see `need_update_inputs`). This mechanism is admittedly confusing, and
it could use some cleaning up. The base `Linker` object should probably
go away completely.
"""
self.updated_vars = updated_vars
# This method simply records in the linker which variables have update
# expressions. It does not imply that the linker will actually
# implement these updates (see need_update_inputs). This mechanism is
# admittedly confusing, and it could use some cleaning up. The base
# Linker object should probably go away completely.
def compute_gc_dependencies(self, variables):
"""
......@@ -978,18 +980,14 @@ class VMLinker(LocalLinker):
prereq_var_idxs.sort() # TODO: why sort?
node_prereqs.append(prereq_var_idxs)
# Builds the list of input storage to update (according to update
# rules) when the outputs are computed.
# They are in the same order as the second part of output_vars
# (output_vars contains first the returned outputs, then the
# values of the update expressions).
update_storage = []
update_in_from_out = {}
for (ivar, ovar) in updated_vars.items():
update_in_from_out[vars_idx[ovar]] = vars_idx[ivar]
for oidx in output_vars:
if oidx in update_in_from_out:
update_storage.append(update_in_from_out[oidx])
# This is essentially a version of `self.fgraph.update_mapping`.
# It specifies the outputs-to-inputs updates via the pairs
# `(input_idx, output_idx)` (i.e. the input at index `input_idx`
# takes the value of the output at index `output_idx`).
update_storage = tuple(
(vars_idx[in_var], self.fgraph.outputs.index(out_var))
for in_var, out_var in updated_vars.items()
)
# PyPy has no sys.getrefcount, so ignore this check if not running
# under CPython.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论