提交 b4d3b368 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

make thunk function for scan

上级 8de65286
......@@ -161,10 +161,109 @@ class ScanOp(PureOp):
return input_shapes[1: n_outs + 1]
def make_thunk(self, node, storage_map, compute_map, no_recycling):
pass
"""
:param node: the Apply node returned by the ``make_node`` function
of the scan op class
def infer_shape(self, node, input_shapes):
pass
:param storage_map: dict variable -> one-element-list where a computed
value for this variable may be found.
:param compute_map: dict variable -> one-element-list where a boolean
value will be found. The boolean indicates whether the
variable's storage_map container contains a valid value (True)
or if it has not been computed yet (False).
:param no_recycling: list of variables for which it is forbidden to
reuse memory allocated by a previous call.
:note: If the thunk consults the storage_map on every call, it is safe
for it to ignore the no_recycling argument, because elements of the
no_recycling list will have a value of None in the storage map. If
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
"""
# 1. Collect all memory buffers
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
node_input_compute = [compute_map[r] for r in node.inputs]
node_output_compute = [compute_map[r] for r in node.outputs]
# 2. If the op is not inplace we need to copy over the initial values
if not self.inplace:
for membuf1, membuf2 in izip(
node_output_storage,
node_input_storage[1: 1 + len(node_output_storage)]):
membuf1[0][:] = membuf2[0]
# 3. Construct fake shared variables around every argument of scan
givens = {}
base_inputs = self.inputs[:len(self.outputs)]
aux_inputs = self.inputs[len(self.outputs):]
# 3.1 First the auxiliary arguments, those that are parameters or
# input
for mem_buf, var in izip(ndoe_input_storage[1 + len(base_inputs):],
aux_inputs):
givens[var] = theano.shared(mem_buf[0], name=var.name,
borrow=True)
# 3.2. Next the states (numeric or not) and the outputs
updates = {}
n_numeric_values = len(self.lengths)
for pos, (mem_buf, var, expr) in enumerate(
izip(node_output_storage, base_inputs, self.outputs)):
givens[var] = theano.shared(mem_buf[0], name=var.name,
borrow=True)
updates[givens[var]] = expr
if pos < n_numeric_values:
self.lengths[pos].set_value(mem_buf[0].shape[0])
givens[self.lengths[pos]] = \
tensor.constant(mem_buf[0].shape[0])
# 3.3 Add the update for the index of scan
updates[self.t] = self.t + numpy.int64(1)
# 4.1 Construct the inner function of scan
fn_outs = []
if self.as_repeatUntil is not None:
fn_outs = self.as_repeatUntil
self.fn = theano.function([], fn_outs,
givens=givens,
updates=updates,
mode=self.mode_instance,
name=self.name,
profile=self.profile)
# Construct the perform
if self.as_repeatUntil is not None:
def p(node, args, outs):
pos = 0
cont = 1
# reset all switches if any
for sw in self.swithces:
sw.set_value(numpy.int8(0), borrow=True)
while cont and pos < node_input_storage[0][0]:
cont = self.fn()
pos = pos + 1
# We need to trim the outputs if they are longer
for pos, membuf in enumerate(
node_output_storage[:n_numeric_values]):
if membuf[0].shape[0] > pos + self.mintaps[pos]:
membuf[0] = membuf[0][:pos + self.mintaps[pos]]
else:
def p(node, args, outs):
for sw in self.switches:
sw.set_value(numpy.int8(0), borrow=True)
self.fn.fn(n_calls=node_input_storage[0][0])
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = perform(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
return r
rval.inputs = node_input_storage
rval.outputs = node_output_storage
rval.perform = p
rval.lazy = False
return rval
def grad(self, args, g_outs):
pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论