提交 d3715b59 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

created state_buffers and dealt with non-numeric states

I've cerated similarly to aux inpute a list of state_buffers, the difference being that know we keep track of the length of each argument as well..
上级 eec6ce70
......@@ -222,21 +222,29 @@ class ScanOp(PureOp):
# 2.2. Next the states (numeric) and the outputs
updates = {}
state_buffers = []
n_numeric_values = len(self.lengths)
for pos, (mem_buf, var, expr) in enumerate(
izip(node_output_storage, base_inputs, self.outputs)):
givens[var] = theano.shared(mem_buf[0], name=var.name,
borrow=True)
for pos in xrange(n_numeric_values):
var = base_inputs[pos]
mem_buf = base_buffers[pos]
expr = self.outputs[pos]
givens[var] = fake_shared(var)
state_buffers.append((givens[var], self.lengths[pos], mem_buf))
updates[givens[var]] = expr
if pos < n_numeric_values:
self.lengths[pos].set_value(mem_buf[0].shape[0])
givens[self.lengths[pos]] = \
tensor.constant(mem_buf[0].shape[0])
# 3.3 Add the update for the index of scan
updates[self.t] = self.t + numpy.int64(1)
# 4.1 Construct the inner function of scan
fn_outs = []
#2.3 Non-numeric states
n_non_numeric = len(self.outputs) - n_numeric_values
fn_outs = self.outputs[n_numeric_values:]
for var in base_inputs[n_numeric_values:]:
givens[var] = var.type()
non_tensor_args.append(givens[var])
non_numeric_states_bufs = base_buffers[n_numeric_values:]
# 2.4 Add the update for the index of scan
updates[self.index] = self.index + numpy.int64(1)
# 3.1 Construct the inner function of scan
if self.as_repeatUntil is not None:
fn_outs = self.as_repeatUntil
self.fn = theano.function([], fn_outs,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论