提交 530f442a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add optional callback on inputs in StackVM.

上级 49590e13
......@@ -327,7 +327,7 @@ class Stack(VM):
def __init__(self, nodes, thunks, pre_call_clear,
storage_map, compute_map, fgraph, allow_gc,
dependencies=None, callback=None):
dependencies=None, callback=None, callback_input=None):
super(Stack, self).__init__(nodes, thunks, pre_call_clear)
self.allow_gc = allow_gc
......@@ -340,6 +340,7 @@ class Stack(VM):
self.compute_map = compute_map
self.node_idx = node_idx = {}
self.callback = callback
self.callback_input = callback_input
ords = fgraph.orderings()
......@@ -406,6 +407,8 @@ class Stack(VM):
for k in self.storage_map:
compute_map[k][0] = (k.owner is None)
if self.callback_input and compute_map[k][0]:
self.callback_input(k, self.storage_map[k][0])
# apply_stack contains nodes
if output_subset is not None:
......@@ -679,6 +682,10 @@ class VM_Linker(link.LocalLinker):
A callable object to call after each call to a thunk within
the virtual machine. It will be called with four arguments called
'node', 'thunk', 'storage_map', and 'compute_map'.
callback_input
A callable object to call on each input to the graph
(variables with no owner). It will be called with two
arguments: 'var', 'value'.
lazy
Useful only when use_cloop is False. When lazy is None, use the
theano flag vm.lazy value. Then if we have a None (default) we auto
......@@ -695,8 +702,8 @@ class VM_Linker(link.LocalLinker):
"""
def __init__(self, allow_gc=None, use_cloop=False, callback=None,
lazy=None, schedule=None, c_thunks=None,
allow_partial_eval=None):
callback_input=None, lazy=None, schedule=None,
c_thunks=None, allow_partial_eval=None):
# Note: if more parameters are added to __init__, make sure to forward
# them in the "type(self)(...)" call in the "accept" method below.
if allow_gc is None:
......@@ -705,6 +712,7 @@ class VM_Linker(link.LocalLinker):
self.allow_gc = allow_gc
self.use_cloop = use_cloop
self.callback = callback
self.callback_input = callback_input
self.lazy = lazy
self.c_thunks = c_thunks
self.allow_partial_eval = allow_partial_eval
......@@ -752,9 +760,11 @@ class VM_Linker(link.LocalLinker):
allow_gc=self.allow_gc,
use_cloop=self.use_cloop,
callback=self.callback,
callback_input=self.callback_input,
lazy=self.lazy,
schedule=self.schedule,
c_thunks=self.c_thunks,
allow_partial_eval=self.allow_partial_eval
).accept(fgraph, no_recycling)
self.fgraph = fgraph
self.no_recycling = no_recycling
......@@ -821,16 +831,17 @@ class VM_Linker(link.LocalLinker):
pre_call_clear = [storage_map[v] for v in self.no_recycling]
if (self.callback is not None or
if (self.callback is not None or self.callback_input is not None or
(config.profile and config.profile_memory) or
getattr(self, 'allow_partial_eval', False)):
self.allow_partial_eval):
if self.use_cloop and self.callback is not None:
if self.use_cloop and (self.callback is not None or
self.callback_input is not None):
logger.warn('CVM does not support callback, using Stack VM.')
if self.use_cloop and config.profile_memory:
warnings.warn(
'CVM does not support memory profile, using Stack VM.')
if self.use_cloop and getattr(self, 'allow_partial_eval', False):
if self.use_cloop and self.allow_partial_eval:
warnings.warn(
'CVM does not support partial evaluation yet, '
'using Stack VM.')
......@@ -841,7 +852,8 @@ class VM_Linker(link.LocalLinker):
storage_map, compute_map,
self.fgraph, self.allow_gc,
dependencies=deps,
callback=self.callback)
callback=self.callback,
callback_input=self.callback_input)
elif self.use_cloop:
# create a map from nodes to ints and vars to ints
nodes_idx = {}
......@@ -1038,7 +1050,7 @@ class VM_Linker(link.LocalLinker):
if lazy is None:
lazy = not all([(not th.lazy) for th in thunks])
if not (lazy or (config.profile and config.profile_memory) or
self.use_cloop or self.callback):
self.use_cloop or self.callback or self.callback_input):
for pair in itervalues(reallocated_info):
storage_map[pair[1]] = storage_map[pair[0]]
......@@ -1080,3 +1092,7 @@ class VM_Linker(link.LocalLinker):
self.__dict__.update(d)
if not hasattr(self, 'c_thunks'):
self.c_thunks = True
if not hasattr(self, 'allow_partial_eval'):
self.allow_partial_eval = None
if not hasattr(self, 'callback_input'):
self.callback_input = None
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论