提交 c7a9ac80 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

typo

上级 8ca0abd9
......@@ -183,17 +183,17 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \
elif not(type(outputs_taps[i]) in (list,tuple)):
outputs_taps[i] = [outputs_taps[i]]
'''
# update keep_outputs list
# update stored_steps_output list
for i in xrange(n_outs):
if not keep_outputs.has_key(i):
keep_outputs[i] = True
elif not keep_outputs[i]:
if not stored_steps_output.has_key(i):
stored_steps_output[i] = True
elif not stored_steps_output[i]:
if outputs_taps[i] != [-1]:
keep_outputs[i] = True
stored_steps_output[i] = True
warning('You need to keep past value of outputs if you use'\
'past taps of output different from -1')
'''
keep_outputs = [ 0 for i in xrange(n_outs)]
stored_steps_output = [ 0 for i in xrange(n_outs)]
......@@ -221,7 +221,7 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \
# Create the Scan op object
local_op = Scan( (args,next_outs), n_seqs,n_outs,inplace_map,
sequences_taps, outputs_taps, truncate_gradient,
go_backwards, keep_outputs, mode)
go_backwards, stored_steps_output, mode)
# Call the object on the input sequences, initial values for outs,
# and non sequences
......@@ -237,7 +237,7 @@ class Scan(theano.Op):
def __init__(self,(inputs, outputs),n_seqs, n_outs,
inplace_map={}, seqs_taps={}, outs_taps={},
truncate_gradient = -1,
go_backwards = False, keep_outputs = {},
go_backwards = False, stored_steps_output = {},
mode = 'FAST_RUN', inplace=False):
'''
:param (inputs,outputs): inputs and outputs Theano variables that
......@@ -258,7 +258,7 @@ class Scan(theano.Op):
:param go_bacwards: see scan funcion above
:param keep_outputs: a list of booleans of same size as the number of
:param stored_steps_output: a list of booleans of same size as the number of
outputs; the value at position ``i`` in the list corresponds to the
``i-th`` output, and it tells how many steps (from the end towards
the begining) of the outputs you really need and should return;
......@@ -308,7 +308,7 @@ class Scan(theano.Op):
self.n_outs = n_outs
self.n_args = n_seqs+n_outs+1
self.inplace_map = inplace_map
self.keep_outputs = keep_outputs
self.stored_steps_output = stored_steps_output
self.inplace = inplace
self.inputs = inputs
self.outputs = outputs
......@@ -352,16 +352,18 @@ class Scan(theano.Op):
out_types = []
for i in xrange(self.n_seqs+1, self.n_seqs+self.n_outs+1):
if not (inputs[i] == []):
if self.outs_taps.has_key(i-1-self.n_seqs) and \
(self.outs_taps[i-self.n_seqs-1]==[-1]) and \
(self.keep_outputs[i-1-self.n_seqs] != 1):
out_types += [theano.tensor.Tensor(dtype=inputs[i].dtype, \
broadcastable=(False,)+inputs[i].broadcastable)()]
elif not self.keep_outputs[i-1-self.n_seqs] == 1:
out_types += [ inputs[i].type()]
if self.outs_taps.has_key(i-1-self.n_seqs):
if (self.outs_taps[i-self.n_seqs-1] == [-1]) and \
(self.stored_steps_output[i-1-self.n_seqs] != 1):
out_types += [ theano.tensor.Tensor(dtype=inputs[i].dtype,
broadcastable = (False,)+inputs[i].broadcastable)()]
elif not self.stored_steps_output[i-1-self.n_seqs] ==1 :
out_types += [inputs[i].type()]
else:
out_types += [theano.tensor.Tensor(dtype=inputs[i].dtype,\
broadcastable=(False,)+inputs[i].broadcastable[1:])()]
out_types += [theano.tensor.Tensor(dtype = inputs[i].dtype, \
broadcastable = (False,)+inputs[i].broadcastable[1:])()]
else:
out_types += [inputs[i].type()]
else:
raise ValueError(('You need to provide initial state for outputs'
' such that scan can infer what dataype they are'))
......@@ -373,7 +375,7 @@ class Scan(theano.Op):
if rval:
rval = (self.inputs == other.inputs) and \
(self.outputs == other.outputs) and \
(self.keep_outputs == other.keep_outputs) and \
(self.stored_steps_output == other.stored_steps_output) and \
(self.seqs_taps == other.seqs_taps) and \
(self.outs_taps == other.outs_taps) and \
(self.inplace_map == other.inplace_map) and \
......@@ -400,7 +402,7 @@ class Scan(theano.Op):
hash_listsDictsTuples(self.g_outs) ^ \
hash_listsDictsTuples(self.seqs_taps) ^\
hash_listsDictsTuples(self.outs_taps) ^\
hash_listsDictsTuples(self.keep_outputs)
hash_listsDictsTuples(self.stored_steps_output)
......@@ -473,12 +475,12 @@ class Scan(theano.Op):
if inplace_map.has_key(i) and (inplace_map[i] >= 0):
y += [args[inplace_map[i]]]
else:
if self.keep_outputs[i] < 1 :
if self.stored_steps_output[i] < 1 :
y_shape = (n_steps,)+args[i+n_seqs].shape[1:]
elif self.keep_outputs[i] == 1:
elif self.stored_steps_output[i] == 1:
y_shape = args[i+n_seqs].shape[1:]
else:
y_shape = (self.keep_outputs[i],)+args[i+n_seqs].shape[1:]
y_shape = (self.stored_steps_output[i],)+args[i+n_seqs].shape[1:]
y += [numpy.empty(y_shape,
......@@ -530,9 +532,9 @@ class Scan(theano.Op):
else:
fn_args += [args[j+n_seqs][k]]
else:
if self.keep_outputs[j] < 1:
if self.stored_steps_output[j] < 1:
fn_args += [y[j][i + tap_value]]
elif self.keep_outputs[j] == 1:
elif self.stored_steps_output[j] == 1:
fn_args += [y[j] ]
else:
raise NotImplementedError('This will be implemented in the near future')
......@@ -542,9 +544,9 @@ class Scan(theano.Op):
something = fn(*fn_args)
#update outputs
for j in xrange(n_outs):
if self.keep_outputs[j] <1:
if self.stored_steps_output[j] <1:
y[j][i] = something[j]
elif self.keep_outputs[j] == 1:
elif self.stored_steps_output[j] == 1:
y[j] = something[j]
else:
raise NotImplementedError('This will be implemented in the near future')
......@@ -557,7 +559,7 @@ class Scan(theano.Op):
'''
if True:
#((self.updates.keys() != []) or (self.inplace_map.keys() != [])\
# or numpy.any(self.keep_outputs)):
# or numpy.any(self.stored_steps_output)):
# warning('Can not compute gradients if inplace or updates ' \
# 'are used or if you do not keep past value of outputs.'\
# 'Use force_gradient if you know for sure '\
......@@ -602,7 +604,7 @@ def scan_make_inplace(node):
and (op.inplace_map.keys() != []):
return Scan((op.inputs, op.outputs) , op.n_seqs,
op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps,
op.truncate_gradient, op.go_backwards, op.keep_outputs,
op.truncate_gradient, op.go_backwards, op.stored_steps_output,
inplace=True
).make_node(*node.inputs).outputs
return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论