提交 d3715b59 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

created state_buffers and dealt with non-numeric states

I've cerated similarly to aux inpute a list of state_buffers, the difference being that know we keep track of the length of each argument as well..
上级 eec6ce70
...@@ -222,21 +222,29 @@ class ScanOp(PureOp): ...@@ -222,21 +222,29 @@ class ScanOp(PureOp):
# 2.2. Next the states (numeric) and the outputs # 2.2. Next the states (numeric) and the outputs
updates = {} updates = {}
state_buffers = []
n_numeric_values = len(self.lengths) n_numeric_values = len(self.lengths)
for pos, (mem_buf, var, expr) in enumerate( for pos in xrange(n_numeric_values):
izip(node_output_storage, base_inputs, self.outputs)): var = base_inputs[pos]
givens[var] = theano.shared(mem_buf[0], name=var.name, mem_buf = base_buffers[pos]
borrow=True) expr = self.outputs[pos]
givens[var] = fake_shared(var)
state_buffers.append((givens[var], self.lengths[pos], mem_buf))
updates[givens[var]] = expr updates[givens[var]] = expr
if pos < n_numeric_values:
self.lengths[pos].set_value(mem_buf[0].shape[0])
givens[self.lengths[pos]] = \ #2.3 Non-numeric states
tensor.constant(mem_buf[0].shape[0]) n_non_numeric = len(self.outputs) - n_numeric_values
fn_outs = self.outputs[n_numeric_values:]
# 3.3 Add the update for the index of scan for var in base_inputs[n_numeric_values:]:
updates[self.t] = self.t + numpy.int64(1) givens[var] = var.type()
# 4.1 Construct the inner function of scan non_tensor_args.append(givens[var])
fn_outs = [] non_numeric_states_bufs = base_buffers[n_numeric_values:]
# 2.4 Add the update for the index of scan
updates[self.index] = self.index + numpy.int64(1)
# 3.1 Construct the inner function of scan
if self.as_repeatUntil is not None: if self.as_repeatUntil is not None:
fn_outs = self.as_repeatUntil fn_outs = self.as_repeatUntil
self.fn = theano.function([], fn_outs, self.fn = theano.function([], fn_outs,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论