提交 a996911a authored 作者: Razvan Pascanu's avatar Razvan Pascanu

bug fixed scan

上级 c710fb2d
......@@ -211,15 +211,21 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \
args += [init_out.type() ]
else:
args += [init_out[0].type() ]
for non_seq in non_seqs :
if not isinstance(non_seq, theano.compile.sharedvalue.SharedVariable):
args += [non_seq]
else:
tmp_var = theano.tensor.Tensor(dtype = non_seq.dtype,
broadcastable = non_seq.broadcastable)()
args += [ tmp_var ]
args += non_seqs
next_outs = fn(*args)
if not (type(next_outs) in (list,tuple)):
next_outs = [next_outs]
# Create the Scan op object
local_op = Scan( (args,next_outs), n_seqs,n_outs,inplace_map,
local_op = Scan( (args,next_outs ), n_seqs,n_outs,inplace_map,
sequences_taps, outputs_taps, truncate_gradient,
go_backwards, stored_steps_output, mode)
......@@ -362,8 +368,13 @@ class Scan(theano.Op):
else:
out_types += [theano.tensor.Tensor(dtype = inputs[i].dtype, \
broadcastable = (False,)+inputs[i].broadcastable[1:])()]
else:
out_types += [inputs[i].type()]
else:
if self.stored_steps_output[i-1-self.n_seqs] != 1 :
out_types += [ theano.tensor.Tensor(dtype = inputs[i].dtype,
broadcastable = (False,)+inputs[i].broadcastable)()]
else:
out_types += [ theano.tensor.Tensor(dtype = inputs[i].dtype,
broadcastable = inputs[i].broadcastable)()]
else:
raise ValueError(('You need to provide initial state for outputs'
' such that scan can infer what dataype they are'))
......@@ -374,7 +385,7 @@ class Scan(theano.Op):
rval = type(self) == type(other)
if rval:
rval = (self.inputs == other.inputs) and \
(self.outputs == other.outputs) and \
(self.outputs == other.outputs) and \
(self.stored_steps_output == other.stored_steps_output) and \
(self.seqs_taps == other.seqs_taps) and \
(self.outs_taps == other.outs_taps) and \
......@@ -408,7 +419,6 @@ class Scan(theano.Op):
def perform(self,node,args, outs):
n_steps = 0
if (self.n_seqs ==0 ) and (args[0] == 0):
raise ValueError('Scan does not know over how many steps it '
......@@ -605,7 +615,7 @@ def scan_make_inplace(node):
op = node.op
if isinstance(op, Scan) and (not op.inplace) \
and (op.inplace_map.keys() != []):
return Scan((op.inputs, op.outputs) , op.n_seqs,
return Scan((op.inputs, op.outputs ) , op.n_seqs,
op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps,
op.truncate_gradient, op.go_backwards, op.stored_steps_output,
inplace=True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论