提交 abfb6448 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

the body of the two cases of perform

上级 cfdab794
......@@ -270,21 +270,72 @@ class ScanOp(PureOp):
# reset all switches if any
for sw in self.switches:
sw.set_value(numpy.int8(0), borrow=True)
# set aux shared variables
for var, val in aux_buffers:
var.set_value(val[0], borrow=True)
# set state shared variables
for var, length, val in state_buffers:
var.set_value(val[0], borrow=True)
length.set_value(val[0].shape[0], borrow=True)
# grab fixed arguments
fix_args = [x[0] for x in non_tensor_buffers]
while cont and pos < node_input_storage[0][0]:
cont = self.fn()
extra_args = [x[0] for x in non_numeric_states_bufs]
rvals = self.fn(*(fix_args + extra_args))
for buf, rval in izip(non_numeric_states_bufs, rvals):
buf[0] = rval
cont = rvals[-1]
pos = pos + 1
# We need to trim the outputs if they are longer
for pos, membuf in enumerate(
node_output_storage[:n_numeric_values]):
if membuf[0].shape[0] > pos + self.mintaps[pos]:
membuf[0] = membuf[0][:pos + self.mintaps[pos]]
for pos in xrange(n_numeric_values):
buf = state_buffers[pos][2][0]
mintap = self.mintaps[pos]
if buf.shape[0] > pos + self.mintaps[pos]:
node_output_storage[pos][0] = buf[:pos + mintap]
else:
node_output_storage[pos][0] = buf
for out_buf, in_buf in izip(
node_output_storage[n_numeric_values:],
non_numeric_states_bufs):
out_buf[0] = in_buf[0]
else:
# 3.2.2 as a for
def p(node, args, outs):
# copy inputs if not inplace
if not self.inplace:
for _, _, val in state_buffers:
val[0] = val[0].copy()
for buf in non_numeric_states_bufs:
buf[0] = buf[0].copy()
# reset all switches if any
for sw in self.switches:
sw.set_value(numpy.int8(0), borrow=True)
self.fn.fn(n_calls=node_input_storage[0][0])
# set aux shared variables
for var, val in aux_buffers:
var.set_value(val[0], borrow=True)
# set state shared variables
for var, length, val in state_buffers:
var.set_value(val[0], borrow=True)
length.set_value(val[0].shape[0], borrow=True)
# grab fixed arguments
fix_args = [x[0] for x in non_tensor_buffers]
for dx in xrange(node_input_storage[0][0]):
extra_args = [x[0] for x in non_numeric_states_bufs]
rvals = self.fn(*(fix_args + extra_args))
for buf, rval in izip(non_numeric_states_bufs, rvals):
buf[0] = rval
for pos in xrange(n_numeric_values):
buf = state_buffers[pos][2][0]
mintap = self.mintaps[pos]
node_output_storage[pos][0] = buf
for out_buf, in_buf in izip(
node_output_storage[n_numeric_values:],
non_numeric_states_bufs):
out_buf[0] = in_buf[0]
# 3.3 construct the rval function
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
r = perform(n, [x[0] for x in i], o)
for o in node.outputs:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论