提交 67beeb9d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

added a make_thunk to scan

It helps by compiling the inner graph only once, making compilation much faster
上级 bb7e191a
...@@ -155,22 +155,6 @@ class Scan(Op): ...@@ -155,22 +155,6 @@ class Scan(Op):
# function that we set in case none was given # function that we set in case none was given
self.info['name'] = self.name self.info['name'] = self.name
# If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of
# scan is done
slices = ( self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot +
self.n_nit_sot )
wrapped_inputs = [Param(x, borrow=True) for x in inputs ]
wrapped_outputs = [Out(x, borrow=True) for x in
outputs[:slices] ]
wrapped_outputs += outputs[slices:]
self.fn = function(wrapped_inputs,
wrapped_outputs,
mode = self.mode_instance,
name = self.name )
# Pre-computing some values to speed up perform # Pre-computing some values to speed up perform
self.mintaps = [ numpy.min(x) for x in self.tap_array] self.mintaps = [ numpy.min(x) for x in self.tap_array]
self.mintaps += [ 0 for x in xrange(self.n_nit_sot) ] self.mintaps += [ 0 for x in xrange(self.n_nit_sot) ]
...@@ -183,7 +167,10 @@ class Scan(Op): ...@@ -183,7 +167,10 @@ class Scan(Op):
self.n_shared_outs ) self.n_shared_outs )
self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
self.n_tap_outs = self.n_mit_mot + self.n_mit_sot self.n_tap_outs = self.n_mit_mot + self.n_mit_sot
self._cmodule_key = gof.CLinker.cmodule_key_(self.fn.maker.env,[]) tmp_in, tmp_out = scan_utils.reconstruct_graph(self.inputs,
self.outputs)
local_env = gof.Env(tmp_in, tmp_out)
self._cmodule_key = gof.CLinker.cmodule_key_(local_env,[])
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
...@@ -347,6 +334,68 @@ class Scan(Op): ...@@ -347,6 +334,68 @@ class Scan(Op):
scan_utils.hash_listsDictsTuples(self.info) ) scan_utils.hash_listsDictsTuples(self.info) )
def make_thunk(self, node, storage_map, compute_map, no_recycling):
"""
:param node: something previously returned by self.make_node
:param storage_map: dict variable -> one-element-list where a computed
value for this variable may be found.
:param compute_map: dict variable -> one-element-list where a boolean
value will be found. The boolean indicates whether the
variable's storage_map container contains a valid value (True)
or if it has not been computed yet (False).
:param no_recycling: list of variables for which it is forbidden to
reuse memory allocated by a previous call.
:note: If the thunk consults the storage_map on every call, it is safe
for it to ignore the no_recycling argument, because elements of the
no_recycling list will have a value of None in the storage map. If
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
"""
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
node_input_compute = [compute_map[r] for r in node.inputs]
node_output_compute = [compute_map[r] for r in node.outputs]
#logger.debug('Compiling node %i of graph' % node_idx)
# If a shared variable is the result of a ViewOp it is a clear
# indication that we need to copy that value after the perform of
# scan is done
slices = ( self.n_mit_mot_outs +
self.n_mit_sot +
self.n_sit_sot +
self.n_nit_sot )
wrapped_inputs = [Param(x, borrow=True) for x in self.inputs ]
wrapped_outputs = [Out(x, borrow=True) for x in
self.outputs[:slices] ]
wrapped_outputs += self.outputs[slices:]
profile = None
if theano.config.profile or type(self.profile) is str:
profile = ScanProfileStats(name = self.name)
elif self.profile:
profile = self.profile
self.fn = function(wrapped_inputs,
wrapped_outputs,
mode = self.mode_instance,
name = self.name,
profile = profile)
p = self.perform
# default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
return r
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.perform = p
rval.lazy = False
return rval
def perform( self, node, args, outs): def perform( self, node, args, outs):
""" """
The args are packed like this: The args are packed like this:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论