提交 746b8dd9 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

commented version of scan; I took out the updates ( which could have resulted in many issues)

上级 6ff2fca5
"""Provide Scan and related functions
Scanning a function over sequential input(s) producing sequential output(s).
Scanning is a general form of recurrence, which can be used for looping.
The idea is that you 'scan' a function along some input sequence, producing
an output at each time-step that can be seen (but not modified) by the
function at the next time-step. (Technically, the function can see the
previous K time-steps of your outputs and L time steps of your inputs,
future and past.)
So for example, ``sum()`` could be computed by scanning the ``z+x_i``
function over a list, given an initial state of ``z=0``.
Special cases:
- A ``reduce()`` operation can be performed by returning only the last
output of a scan.
- A ``map()`` operation can be performed by applying a function that
ignores each previous output.
Often a for loop can be expressed as a scan() operation, and scan is the
closest that theano comes to looping.
This module provides scanning functionality with the `Scan` Op.
"""
__docformat__ = 'restructedtext en'
import numpy
import theano
from theano.tensor import opt
......@@ -32,10 +64,88 @@ def hash_listsDictsTuples(x):
return hash_value
def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
sequences_taps={}, outputs_taps = {}, keep_outputs = {},
n_steps = theano.tensor.zero(), force_gradient = False,
sequences_taps={}, outputs_taps = {},
n_steps = theano.tensor.zero(),
truncate_gradient = -1, go_backwards = False, mode = 'FAST_RUN'):
'''Function that constructs and applies a scan op
:param fn: given variables representing all the slices of input and
past values of outputs and other non sequences parameters, fn should
produce variables describing the output of one time step of scan.
The order in which the arguments to this function are given is very
important. You should have the following order:
* all time slices of the first sequence (as given in the ``sequences``
list) ordered chronologically
* all time slices of the second sequence (as given in the
``sequences`` list) ordered chronologically
..
* all time slices of the first output (as given in the
``initial_state`` list) ordered chronologically
* all time slices of the second output (as given in the
``initial_state`` list) ordered chronologically
...
* all other parameters over which scan doesn't iterate given in
the same order as in ``non_sequences``
The outputs of this function should have the same order as in the list
``initial_states``
:param sequences: list of Theano variables over which scan needs to
iterate
:param initial_states: list of Theano variables containing the initial
state used for the output. Note that if the function applied recursively
uses only the previous value of the output or none, this initial state
should have same shape as one time step of the output; otherwise, the
initial state should have the same number of dimensions as the output. This
can easily be understood through an example. For computing y(t) let us
assume that we need y(t-1), y(t-2) and y(t-4). Through an abuse of notation,
when t = 0, we would need values for y(-1), y(-2) and y(-4). These values
are provided by the initial state of y, which should have the same number
of dimensions as y, where the first dimension should be 4 in this case.
If init_y is the initial values of y, then init_y[0] corresponds to
y[-4], init_y[1] corresponds to y[-3], init_y[2] corresponds to y[-2],
init_y[3] corresponds to y[-1]. By default, scan is set to use the
last time step for each output.
:param non_sequences: These are parameters used by the recursive function
over which scan shouldn't iterate
:param inplace_map: It is a dictionary where keys are output indexes,
and values are sequence indexes. Assigning to a key a value, means that
the output represented by key will be computed inplace (in the same
memory buffer) as the input represented by the value
:param sequences_taps: At each step you can use different time slices
of sequences, and this dictionary lets you define exactly that. The
keys of the dictionary are sequence indexes, the values are list of
numbers. Having the following entry : i : [t_1,t_2,t_3], means that
at time step k, for sequence x, that has the index i in the list of
sequences, you would use the values x[k+t_1], x[k+t_2], x[k+t_3].
t_1, t_2, t_3 values can be positive or negative. If you do not want
to use any time slice of the sequence you need to give to that entry
the empty list. By default, for each entry the dictionary will contain
the list [0].
:param outputs_taps: This has the same meaning as the parameter
sequences_taps, with the only differences that these taps are for
the outputs, and that they have to be negative (smaller than 0). To
enforce an output to not use any past values, you have to specify
in the dictionary for that entry the empty list; otherwise, by default,
scan will expect to use the last time step
:param n_steps: in case you do not have any sequences over which you want
to iterate, but rather apply some set of computation for a number of
steps, or when you want to restrict to a certain length, you provide
that length as n_steps. It can be a theano scalar or a value.
:param truncate_gradient: if you compute gradients through a scan op,
this can be computed using backpropagation through time. As such you
have the option to truncate the BPTT to a given number of steps (to
increase speed). If set to -1 no truncation is done.
:param go_backwards: This gives you the option to move backwards through
your sequences instead of forward
'''
# check if inputs are just single variables instead of lists
if not (type(sequences) in (list, tuple)):
......@@ -77,6 +187,7 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
outputs_taps.__delitem__(i)
elif not(type(outputs_taps[i]) in (list,tuple)):
outputs_taps[i] = [outputs_taps[i]]
'''
# update keep_outputs list
for i in xrange(n_outs):
if not keep_outputs.has_key(i):
......@@ -86,6 +197,8 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
keep_outputs[i] = True
warning('You need to keep past value of outputs if you use'\
'past taps of output different from -1')
'''
keep_outputs = [ 0 for i in xrange(n_outs)]
......@@ -104,46 +217,15 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
else:
args += [init_out[0].type() ]
args += non_seqs
t = fn(*args)
if type(t) in (list,tuple):
if len(t) == 2 :
if (type(t[0]) in (list,tuple,dict)) or (type(t[1]) in (list,tuple,dict)):
t1 = t[0]
t2 = t[1]
else:
t1 = t
t2 = {}
else:
t1 = t
t2 = {}
else:
t1 = t
t2 = {}
# check to see which is the updates list and which is the list of outs
if not ( type(t1) in (list,tuple,dict) ) :
next_outs = [t1]
updates = t2
elif not ( type(t2) in (list,tuple, dict)) :
next_outs = [t2]
updates = t1
elif type(t1) == dict :
next_outs = t2
updates = t1
elif type(t2) == dict :
next_outs = t1
updates = t2
elif type(t1[0]) in (list,tuple):
next_outs = t2
updates = t1
else:
next_outs = t1
updates = t2
args += non_seqs
next_outs = fn(*args)
if not (type(next_outs) in (list,tuple)):
next_outs = [next_outs]
# Create the Scan op object
local_op = Scan( (args,next_outs, updates), n_seqs,n_outs,inplace_map,
sequences_taps, outputs_taps, force_gradient, truncate_gradient,
local_op = Scan( (args,next_outs), n_seqs,n_outs,inplace_map,
sequences_taps, outputs_taps, truncate_gradient,
go_backwards, keep_outputs, mode)
# Call the object on the input sequences, initial values for outs,
......@@ -157,11 +239,37 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
class Scan(theano.Op):
def __init__(self,(inputs, outputs, updates),n_seqs, n_outs,
def __init__(self,(inputs, outputs),n_seqs, n_outs,
inplace_map={}, seqs_taps={}, outs_taps={},
force_gradient = False, truncate_gradient = -1,
truncate_gradient = -1,
go_backwards = False, keep_outputs = {},
mode = 'FAST_RUN', inplace=False):
'''
:param (inputs,outputs): inputs and outputs Theano variables that
describe the function that is applied recursively
:param n_seqs: number of sequences over which scan will have to iterate
:param n_outs: number of outputs of the scan op
:param inplace_map: see scan function above
:param seqs_taps: see scan function above
:param outs_taps: see scan function above
:param truncate_gradient: number of steps after which scan should truncate
-1 implies no truncation
:param go_backwards: see scan function above
:param keep_outputs: a list of booleans of same size as the number of
outputs; the value at position ``i`` in the list corresponds to the
``i-th`` output, and it tells how many steps (from the end towards
the beginning) of the outputs you really need and should return;
given this information, scan can know (if possible) to allocate only
the amount of memory needed to compute that many entries
'''
# check inplace map
......@@ -209,14 +317,10 @@ class Scan(theano.Op):
self.inplace = inplace
self.inputs = inputs
self.outputs = outputs
self.updates = updates
self.force_gradient = force_gradient
self.truncate_gradient = truncate_gradient
self.go_backwards = go_backwards
self.fn = theano.function(inputs,outputs, \
updates = updates, mode = mode)
self.fn = theano.function(inputs,outputs, mode = mode)
g_y = [outputs[0].type()]
def compute_gradient(y, g_y):
......@@ -255,10 +359,10 @@ class Scan(theano.Op):
if not (inputs[i] == []):
if self.outs_taps.has_key(i-1-self.n_seqs) and \
(self.outs_taps[i-self.n_seqs-1]==[-1]) and \
(self.keep_outputs[i-1-self.n_seqs]):
(self.keep_outputs[i-1-self.n_seqs] != 1):
out_types += [theano.tensor.Tensor(dtype=inputs[i].dtype, \
broadcastable=(False,)+inputs[i].broadcastable)()]
elif not self.keep_outputs[i-1-self.n_seqs]:
elif not self.keep_outputs[i-1-self.n_seqs] == 1:
out_types += [ inputs[i].type()]
else:
out_types += [theano.tensor.Tensor(dtype=inputs[i].dtype,\
......@@ -274,7 +378,6 @@ class Scan(theano.Op):
if rval:
rval = (self.inputs == other.inputs) and \
(self.outputs == other.outputs) and \
(self.updates == other.updates) and \
(self.keep_outputs == other.keep_outputs) and \
(self.g_ins == other.g_ins) and \
(self.g_outs == other.g_outs) and \
......@@ -285,7 +388,6 @@ class Scan(theano.Op):
(self.inplace == other.inplace) and\
(self.go_backwards == other.go_backwards) and\
(self.truncate_gradient == other.truncate_gradient) and\
(self.force_gradient == other.force_gradient) and\
(self.n_outs == other.n_outs) and\
(self.n_args == other.n_args)
return rval
......@@ -295,7 +397,6 @@ class Scan(theano.Op):
return hash(type(self)) ^ \
hash(self.n_seqs) ^ \
hash(self.n_outs) ^ \
hash(self.force_gradient) ^\
hash(self.inplace) ^\
hash(self.go_backwards) ^\
hash(self.truncate_gradient) ^\
......@@ -306,7 +407,6 @@ class Scan(theano.Op):
hash_listsDictsTuples(self.g_outs) ^ \
hash_listsDictsTuples(self.seqs_taps) ^\
hash_listsDictsTuples(self.outs_taps) ^\
hash_listsDictsTuples(self.updates) ^\
hash_listsDictsTuples(self.keep_outputs)
......@@ -380,10 +480,14 @@ class Scan(theano.Op):
if inplace_map.has_key(i) and (inplace_map[i] >= 0):
y += [args[inplace_map[i]]]
else:
if self.keep_outputs[i]:
if self.keep_outputs[i] < 1 :
y_shape = (n_steps,)+args[i+n_seqs].shape[1:]
else:
elif self.keep_outputs[i] == 1:
y_shape = args[i+n_seqs].shape[1:]
else:
y_shape = (self.keep_outputs[i],)+args[i+n_seqs].shape[1:]
y += [numpy.empty(y_shape,
dtype=args[i+n_seqs].dtype)]
seqs_mins = {}
......@@ -433,32 +537,37 @@ class Scan(theano.Op):
else:
fn_args += [args[j+n_seqs][k]]
else:
if self.keep_outputs[j]:
if self.keep_outputs[j] < 1:
fn_args += [y[j][i + tap_value]]
else:
elif self.keep_outputs[j] == 1:
fn_args += [y[j] ]
else:
raise NotImplementedError('in the near future')
# get the non-iterable sequences
fn_args += list(args[(n_seqs+n_outs):])
# compute output
something = fn(*fn_args)
#update outputs
for j in xrange(n_outs):
if self.keep_outputs[j]:
if self.keep_outputs[j] <1:
y[j][i] = something[j]
else:
elif self.keep_outputs[j] == 1:
y[j] = something[j]
else:
raise NotImplementedError('in the near future')
return y
def grad(self, args, g_outs):
if (not self.force_gradient) and \
((self.updates.keys() != []) or (self.inplace_map.keys() != [])\
or numpy.any(self.keep_outputs)):
warning('Can not compute gradients if inplace or updates ' \
'are used or if you do not keep past value of outputs.'\
'Use force_gradient if you know for sure '\
'that the gradient can be computed automatically.')
return [None for i in args]
if True:
#((self.updates.keys() != []) or (self.inplace_map.keys() != [])\
# or numpy.any(self.keep_outputs)):
# warning('Can not compute gradients if inplace or updates ' \
# 'are used or if you do not keep past value of outputs.'\
# 'Use force_gradient if you know for sure '\
# 'that the gradient can be computed automatically.')
warning('Gradient not fully tested yet !')
return [None for i in args]
else:
# forward pass
y = self(*args)
......@@ -495,7 +604,7 @@ def scan_make_inplace(node):
op = node.op
if isinstance(op, Scan) and (not op.inplace) \
and (op.inplace_map.keys() != []):
return Scan((op.inputs, op.outputs, op.updates), op.n_seqs, \
return Scan((op.inputs, op.outputs) , op.n_seqs, \
op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps, \
op.force_gradient, op.truncate_gradient, \
op.go_backwards, inplace=True \
......
......@@ -95,7 +95,7 @@ class T_Scan(unittest.TestCase):
# non sequence arguments
def test_1(self):
def f_pow2(x_tm1):
return (2*x_tm1, {})
return 2*x_tm1
s = theano.tensor.dscalar()
n_steps = theano.tensor.dscalar()
......@@ -109,7 +109,7 @@ class T_Scan(unittest.TestCase):
# vectors, weights are scalars
def test_2(self):
def f_rnn(u_t,x_tm1,W_in, W):
return (u_t*W_in+x_tm1*W, {})
return u_t*W_in+x_tm1*W
u = theano.tensor.dvector()
x0 = theano.tensor.dscalar()
......@@ -134,7 +134,7 @@ class T_Scan(unittest.TestCase):
W = theano.shared(1., name ='w')
def f_rnn_shared(u_t,x_tm1):
return (u_t*W_in+x_tm1*W, {})
return u_t*W_in+x_tm1*W
Y = theano.sandbox.scan.scan(f_rnn_shared, u,x0,[])
......@@ -177,7 +177,7 @@ class T_Scan(unittest.TestCase):
assert( compareArrays(x,v_x))
assert( compareArrays(y,v_y))
'''
# basic ESN using updates
def test_5(self):
W_in = theano.shared(numpy.array([1.,1.]), name='win')
......@@ -225,7 +225,7 @@ class T_Scan(unittest.TestCase):
out = f6(v_u, v_y0)
assert( compareArrays(out, v_out))
'''
# simple rnn, one input, one state, weights for each; input/state are
# vectors, weights are scalars; using shared variables and past
# taps (sequences and outputs)
......@@ -237,7 +237,7 @@ class T_Scan(unittest.TestCase):
W = theano.shared(1., name ='w')
def f_rnn_shared(u_tm2, x_tm1, x_tm2):
return (u_tm2*W_in+x_tm1*W+x_tm2, {})
return u_tm2*W_in+x_tm1*W+x_tm2
Y = theano.sandbox.scan.scan(f_rnn_shared, u,x0, [], \
sequences_taps = {0:[-2]}, outputs_taps = {0:[-1,-2]})
......@@ -259,7 +259,7 @@ class T_Scan(unittest.TestCase):
W = theano.shared(1., name ='w')
def f_rnn_shared(u_tm2,u_tp2, x_tm1, x_tm2):
return ((u_tm2+u_tp2)*W_in+x_tm1*W+x_tm2, {})
return (u_tm2+u_tp2)*W_in+x_tm1*W+x_tm2
Y = theano.sandbox.scan.scan(f_rnn_shared, u,x0, [], \
sequences_taps = {0:[-2,2]}, outputs_taps = {0:[-1,-2]})
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论