提交 aa156cc3 authored 作者: James Bergstra's avatar James Bergstra

more CVM vs. Linker hackiness

This commit introduces 'accept_var_updates' that the VM_Linker uses to get update expressions from the FunctionMaker class. The previous code didn't work right with ProfileMode, as triggered by the test_modes unit test.
上级 96827176
...@@ -68,6 +68,23 @@ def infer_reuse_pattern(env, outputs_to_disown): ...@@ -68,6 +68,23 @@ def infer_reuse_pattern(env, outputs_to_disown):
return rval return rval
def env_updated_vars(env, expanded_inputs):
"""
Reconstruct the full "updates" dictionary, mapping from Env input
variables to the env outputs that will replace their values.
:rtype: dict variable -> variable
"""
updated_vars = {}
potential_values = list(env.outputs) # copy the list
if len(expanded_inputs) != len(env.inputs):
raise ValueError('expanded_inputs must match len(env.inputs)')
for e_input, ivar in reversed(zip(expanded_inputs, env.inputs)):
if e_input.update is not None:
updated_vars[ivar] = potential_values.pop()
return updated_vars
class Supervisor: class Supervisor:
""" """
Listener for Env events which makes sure that no operation overwrites the Listener for Env events which makes sure that no operation overwrites the
...@@ -1030,8 +1047,10 @@ class FunctionMaker(object): ...@@ -1030,8 +1047,10 @@ class FunctionMaker(object):
else: else:
self.linker = linker.accept(env) self.linker = linker.accept(env)
#hacky thing so VMLinker if hasattr(linker, 'accept_var_updates'):
self.linker.expanded_inputs = expanded_inputs # hacky thing so VMLinker knows about updates
self.linker.accept_var_updates(
env_updated_vars(env, expanded_inputs))
self.indices = indices self.indices = indices
self.inputs = inputs self.inputs = inputs
......
...@@ -61,6 +61,10 @@ class VM(object): ...@@ -61,6 +61,10 @@ class VM(object):
self.call_counts = [0]*len(nodes) self.call_counts = [0]*len(nodes)
self.call_times = [0]*len(nodes) self.call_times = [0]*len(nodes)
self.time_thunks = False self.time_thunks = False
# This variable (self.need_update_inputs) is overshadowed by
# CLazyLinker in CVM which has an attribute of the same name that
# defaults to 0 (aka False).
self.need_update_inputs = True self.need_update_inputs = True
def __call__(self): def __call__(self):
...@@ -405,6 +409,7 @@ class VM_Linker(link.LocalLinker): ...@@ -405,6 +409,7 @@ class VM_Linker(link.LocalLinker):
self.allow_gc = allow_gc self.allow_gc = allow_gc
self.use_cloop = use_cloop self.use_cloop = use_cloop
self.callback = callback self.callback = callback
self.updated_vars = {}
def accept(self, env, no_recycling = []): def accept(self, env, no_recycling = []):
""" """
...@@ -420,6 +425,14 @@ class VM_Linker(link.LocalLinker): ...@@ -420,6 +425,14 @@ class VM_Linker(link.LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def accept_var_updates(self, updated_vars):
self.updated_vars = updated_vars
# This method simply records in the linker which variables have update
# expressions. It does not imply that the linker will actually
# implement these updates (see need_update_inputs). This mechanism is
# admittedly confusing, and it could use some cleaning up. The base
# Linker object should probably go away completely.
def make_vm(self, nodes, thunks, def make_vm(self, nodes, thunks,
input_storage, output_storage, storage_map, input_storage, output_storage, storage_map,
post_thunk_clear, post_thunk_clear,
...@@ -559,7 +572,6 @@ class VM_Linker(link.LocalLinker): ...@@ -559,7 +572,6 @@ class VM_Linker(link.LocalLinker):
def make_all(self, profiler = None, input_storage = None, def make_all(self, profiler = None, input_storage = None,
output_storage = None, output_storage = None,
): ):
expanded_inputs=self.expanded_inputs # hacky argumentpassing workaround
env = self.env env = self.env
order = list(env.toposort()) order = list(env.toposort())
no_recycling = self.no_recycling no_recycling = self.no_recycling
...@@ -590,24 +602,12 @@ class VM_Linker(link.LocalLinker): ...@@ -590,24 +602,12 @@ class VM_Linker(link.LocalLinker):
else: else:
post_thunk_clear = None post_thunk_clear = None
# calculate the update_storage map whose keys are shared var inputs
# and whose values are the outputs that hold their updates
updated_vars = {}
if expanded_inputs:
# Update the inputs that have an update function
potential_values = list(env.outputs)
assert len(expanded_inputs)==len(env.inputs)
for e_input, ivar in reversed(zip(expanded_inputs, env.inputs)):
if e_input.update is not None:
updated_vars[ivar] = potential_values.pop()
vm = self.make_vm(order, thunks, vm = self.make_vm(order, thunks,
input_storage, output_storage, storage_map, input_storage, output_storage, storage_map,
post_thunk_clear, post_thunk_clear,
computed, computed,
compute_map, compute_map,
updated_vars self.updated_vars
) )
return (vm, return (vm,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论