提交 c232106b authored 作者: Razvan Pascanu's avatar Razvan Pascanu

new feature for scan op

上级 4f6c4303
......@@ -32,7 +32,7 @@ def hash_listsDictsTuples(x):
return hash_value
def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
sequences_taps={}, outputs_taps = {},
sequences_taps={}, outputs_taps = {}, keep_outputs = {},
n_steps = theano.tensor.zero(), force_gradient = False,
truncate_gradient = -1, go_backwards = False, mode = 'FAST_RUN'):
......@@ -77,6 +77,16 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
outputs_taps.__delitem__(i)
elif not(type(outputs_taps[i]) in (list,tuple)):
outputs_taps[i] = [outputs_taps[i]]
# update keep_outputs list
for i in xrange(n_outs):
if not keep_outputs.has_key(i):
keep_outputs[i] = True
elif not keep_outputs[i]:
if outputs_taps[i] != [-1]:
keep_outputs[i] = True
warning('You need to keep past value of outputs if you use'\
'past taps of output different from -1')
......@@ -95,13 +105,27 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
args += [init_out[0].type() ]
args += non_seqs
t1,t2 = fn(*args)
t = fn(*args)
if type(t) in (list,tuple):
if len(t) == 2 :
if (type(t[0]) in (list,tuple,dict)) or (type(t[1]) in (list,tuple,dict)):
t1 = t[0]
t2 = t[1]
else:
t1 = t
t2 = {}
else:
t1 = t
t2 = {}
else:
t1 = t
t2 = {}
# check to see which is the updates list and which is the list of outs
if not ( (type(t1) in (list,tuple)) or (type(t1) == dict)) :
if not ( type(t1) in (list,tuple,dict) ) :
next_outs = [t1]
updates = t2
elif not ( (type(t2) in (list,tuple)) or (type(t2) == dict)) :
elif not ( type(t2) in (list,tuple, dict)) :
next_outs = [t2]
updates = t1
elif type(t1) == dict :
......@@ -117,11 +141,10 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
next_outs = t1
updates = t2
# Create the Scan op object
local_op = Scan( (args,next_outs, updates), n_seqs,n_outs,inplace_map,
sequences_taps, outputs_taps, force_gradient, truncate_gradient,
go_backwards, mode)
go_backwards, keep_outputs, mode)
# Call the object on the input sequences, initial values for outs,
# and non sequences
......@@ -137,7 +160,8 @@ class Scan(theano.Op):
def __init__(self,(inputs, outputs, updates),n_seqs, n_outs,
inplace_map={}, seqs_taps={}, outs_taps={},
force_gradient = False, truncate_gradient = -1,
go_backwards = False, mode = 'FAST_RUN', inplace=False):
go_backwards = False, keep_outputs = {},
mode = 'FAST_RUN', inplace=False):
# check inplace map
......@@ -178,9 +202,10 @@ class Scan(theano.Op):
self.seqs_taps = seqs_taps
self.outs_taps = outs_taps
self.n_seqs = n_seqs
self.n_outs = n_outs
self.n_outs = n_outs
self.n_args = n_seqs+n_outs+1
self.inplace_map = inplace_map
self.keep_outputs = keep_outputs
self.inplace = inplace
self.inputs = inputs
self.outputs = outputs
......@@ -188,7 +213,7 @@ class Scan(theano.Op):
self.force_gradient = force_gradient
self.truncate_gradient = truncate_gradient
self.go_backwards = go_backwards
self.fn = theano.function(inputs,outputs, \
updates = updates, mode = mode)
......@@ -228,8 +253,16 @@ class Scan(theano.Op):
out_types = []
for i in xrange(self.n_seqs+1, self.n_seqs+self.n_outs+1):
if not (inputs[i] == []):
out_types += [theano.tensor.Tensor(dtype=inputs[i].dtype,\
broadcastable=(False,)+inputs[i].broadcastable[1:])()]
if self.outs_taps.has_key(i-1-self.n_seqs) and \
(self.outs_taps[i-self.n_seqs-1]==[-1]) and \
(self.keep_outputs[i-1-self.n_seqs]):
out_types += [theano.tensor.Tensor(dtype=inputs[i].dtype, \
broadcastable=(False,)+inputs[i].broadcastable)()]
elif not self.keep_outputs[i-1-self.n_seqs]:
out_types += [ inputs[i].type()]
else:
out_types += [theano.tensor.Tensor(dtype=inputs[i].dtype,\
broadcastable=(False,)+inputs[i].broadcastable[1:])()]
else:
raise ValueError(('You need to provide initial state for outputs'
' such that scan can infer what dataype they are'))
......@@ -242,6 +275,7 @@ class Scan(theano.Op):
rval = (self.inputs == other.inputs) and \
(self.outputs == other.outputs) and \
(self.updates == other.updates) and \
(self.keep_outputs == other.keep_outputs) and \
(self.g_ins == other.g_ins) and \
(self.g_outs == other.g_outs) and \
(self.seqs_taps == other.seqs_taps) and \
......@@ -272,7 +306,8 @@ class Scan(theano.Op):
hash_listsDictsTuples(self.g_outs) ^ \
hash_listsDictsTuples(self.seqs_taps) ^\
hash_listsDictsTuples(self.outs_taps) ^\
hash_listsDictsTuples(self.updates)
hash_listsDictsTuples(self.updates) ^\
hash_listsDictsTuples(self.keep_outputs)
......@@ -345,7 +380,10 @@ class Scan(theano.Op):
if inplace_map.has_key(i) and (inplace_map[i] >= 0):
y += [args[inplace_map[i]]]
else:
y_shape = (n_steps,)+args[i+n_seqs].shape[1:]
if self.keep_outputs[i]:
y_shape = (n_steps,)+args[i+n_seqs].shape[1:]
else:
y_shape = args[i+n_seqs].shape[1:]
y += [numpy.empty(y_shape,
dtype=args[i+n_seqs].dtype)]
seqs_mins = {}
......@@ -395,22 +433,30 @@ class Scan(theano.Op):
else:
fn_args += [args[j+n_seqs][k]]
else:
fn_args += [y[j][i + tap_value]]
if self.keep_outputs[j]:
fn_args += [y[j][i + tap_value]]
else:
fn_args += [y[j] ]
# get the non-iterable sequences
fn_args += list(args[(n_seqs+n_outs):])
# compute output
something = fn(*fn_args)
#update outputs
for j in xrange(n_outs):
y[j][i] = something[j]
if self.keep_outputs[j]:
y[j][i] = something[j]
else:
y[j] = something[j]
return y
def grad(self, args, g_outs):
if (not self.force_gradient) and \
((self.updates.keys() != []) or (self.inplace_map.keys() != [])):
((self.updates.keys() != []) or (self.inplace_map.keys() != [])\
or numpy.any(self.keep_outputs)):
warning('Can not compute gradients if inplace or updates ' \
'are used. Use force_gradient if you know for sure '\
'are used or if you do not keep past value of outputs.'\
'Use force_gradient if you know for sure '\
'that the gradient can be computed automatically.')
return [None for i in args]
else:
......
......@@ -159,8 +159,8 @@ class T_Scan(unittest.TestCase):
y0 = theano.tensor.dscalar('y0')
def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, W_in1):
return ({}, [theano.dot(u1_t,W_in1) + u2_t* W_in2 + \
theano.dot(x_tm1, W), theano.dot(x_tm1, W_out)])
return [theano.dot(u1_t,W_in1) + u2_t* W_in2 + \
theano.dot(x_tm1, W), theano.dot(x_tm1, W_out)]
Y = theano.sandbox.scan.scan(f_rnn_cmpl,[u1,u2],[x0,y0],W_in1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论