Commit cfaf9c94, authored by James Bergstra

scan - several changes to the implementation (made during discussion with Razvan). The interface remains unchanged.

Parent commit: db59d113
...@@ -26,15 +26,13 @@ The Scan Op should typically be used by calling the ``scan()`` function. ...@@ -26,15 +26,13 @@ The Scan Op should typically be used by calling the ``scan()`` function.
__docformat__ = 'restructedtext en' __docformat__ = 'restructedtext en'
import theano import theano
from theano.tensor import opt from theano.tensor import opt, TensorType
from theano import gof from theano import gof, Apply
from theano.compile import optdb from theano.compile import optdb
import theano.tensor.shared_randomstreams as shared_random import theano.tensor.shared_randomstreams as shared_random
import numpy import numpy
# Logging function for sending warning or info # Logging function for sending warning or info
import logging import logging
_logger = logging.getLogger('theano.scan') _logger = logging.getLogger('theano.scan')
...@@ -62,7 +60,40 @@ def hash_listsDictsTuples(x): ...@@ -62,7 +60,40 @@ def hash_listsDictsTuples(x):
pass pass
return hash_value return hash_value
def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \ def _map(fn, sequences, non_sequences=[]):
#TODO
#UGLY HACK: instead of figuring out how many outputs there are, we
# will assume there are less than 100 of them
return scan(fn, sequences=sequences,
outputs_taps=dict([(i,[]) for i in xrange(100)]))
# CONSIDER ALTERNATE CALLING CONVENTIONS:
# simple:
# scan(fn, [a,b], [c])
# complex:
# scan(fn, [dict(input=a, taps=[0,-1,-2]), b], [dict(initial=c, taps=[-1,-3], inplace=a)])
#
#
# So for example, if we wanted a scan that took a window of 3 inputs, and produced
# x - a sequence that we need one previous value of, and only need to return the last value;
# y - a sequence that we need no previous values of;
# z - a sequence that we need two previous values of
# and we want z to be computed inplace using the storage of 'a'.
#
# scan(fn, [dict(input=a, taps=[-1,0,1])],
# [dict(initial=x_init, taps=[-1], ????????),
# None
# dict(initial=z_init, taps=[-2,-1], inplace=a,)])
#
# QUESTION:
# The larger (in absolute value) the sequence taps, the shorter the output,
# right? If sequence_taps = {0: [-10, 10]} and I pass an input with 22
# rows, then the scan will output something of length <= 2, right?
#
def scan(fn, sequences=[], initial_states=[], non_sequences=[], inplace_map={}, \
sequences_taps={}, outputs_taps = {}, n_steps = 0, \ sequences_taps={}, outputs_taps = {}, n_steps = 0, \
truncate_gradient = -1, go_backwards = False, truncate_gradient = -1, go_backwards = False,
mode = None): mode = None):
...@@ -204,7 +235,7 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \ ...@@ -204,7 +235,7 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \
# compute number of sequences and number of seqs # compute number of sequences and number of seqs
n_seqs = len(seqs) n_seqs = len(seqs)
n_outs = len(init_outs) n_init_outs = len(init_outs)
# update sequences_taps[idx] to contain 0 if it is not defined # update sequences_taps[idx] to contain 0 if it is not defined
for i in xrange(n_seqs): for i in xrange(n_seqs):
...@@ -216,14 +247,13 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \ ...@@ -216,14 +247,13 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \
elif not (type(sequences_taps[i]) in (list,tuple)): elif not (type(sequences_taps[i]) in (list,tuple)):
sequences_taps[i] = [sequences_taps[i]] sequences_taps[i] = [sequences_taps[i]]
# update outputs_taps[idx] to contain -1 if it is not defined # update outputs_taps[idx] to contain -1 if it is not defined
for i in xrange(n_outs): for i in xrange(n_init_outs):
if not outputs_taps.has_key(i): if not outputs_taps.has_key(i):
outputs_taps.update({i:[-1]}) outputs_taps.update({i:[-1]})
elif outputs_taps[i] == []: elif outputs_taps[i] == []:
outputs_taps.__delitem__(i) outputs_taps.__delitem__(i)
elif not(type(outputs_taps[i]) in (list,tuple)): elif not(type(outputs_taps[i]) in (list,tuple)):
outputs_taps[i] = [outputs_taps[i]] outputs_taps[i] = [outputs_taps[i]]
stored_steps_output = [ 0 for i in xrange(n_outs)]
...@@ -299,10 +329,9 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \ ...@@ -299,10 +329,9 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \
theano.compile.mode.Mode(linker = 'py', optimizer = None) ) theano.compile.mode.Mode(linker = 'py', optimizer = None) )
ls_outputs = [ sout.variable for sout in dummy_f.maker.outputs] ls_outputs = [ sout.variable for sout in dummy_f.maker.outputs]
update_map = {} update_map = {}
n_actual_outs = n_outs n_actual_outs = len(dummy_f.maker.outputs)
shared_outs = [] shared_outs = []
shared_non_seqs = [] shared_non_seqs = []
givens = {} givens = {}
...@@ -310,7 +339,11 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \ ...@@ -310,7 +339,11 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \
ls_inputs=[inp.variable for inp in \ ls_inputs=[inp.variable for inp in \
dummy_f.maker.expanded_inputs[:_ins+_outs]] dummy_f.maker.expanded_inputs[:_ins+_outs]]
fromIdx = _ins + _outs fromIdx = _ins + _outs
stored_steps_output = [ 0 for i in xrange(n_actual_outs)]
# add shared variable that act as outputs # add shared variable that act as outputs
#
n_outs = n_actual_outs
for inp in dummy_f.maker.expanded_inputs[fromIdx:] : for inp in dummy_f.maker.expanded_inputs[fromIdx:] :
if isinstance(inp.variable, theano.compile.SharedVariable) and inp.update: if isinstance(inp.variable, theano.compile.SharedVariable) and inp.update:
ls_inputs.append(inp.variable.type()) ls_inputs.append(inp.variable.type())
...@@ -363,6 +396,10 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \ ...@@ -363,6 +396,10 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={}, \
class Scan(theano.Op): class Scan(theano.Op):
#
# OLD DOCUMENTATION CAN BE FOUND NEAR REVISION 2581
#
def __init__(self,(inputs, outputs, givens),n_seqs, n_outs, def __init__(self,(inputs, outputs, givens),n_seqs, n_outs,
inplace_map={}, seqs_taps={}, outs_taps={}, inplace_map={}, seqs_taps={}, outs_taps={},
truncate_gradient = -1, truncate_gradient = -1,
...@@ -418,12 +455,22 @@ class Scan(theano.Op): ...@@ -418,12 +455,22 @@ class Scan(theano.Op):
#check outputs past taps #check outputs past taps
for k,v in outs_taps.iteritems(): for k,v in outs_taps.iteritems():
if k > n_outs: if k > n_outs:
raise ValueError(('Sequences past taps dictionary reffers to ' raise ValueError(('Output past taps dictionary reffers to '
'an unexisting sequence %d')%k) 'an unexisting sequence %d')%k)
if max(v) > -1: if v and (max(v) > -1):
raise ValueError(('Can not require future value %d of output' \ raise ValueError(('Can not require future value %d of output' \
' %d')%(k,max(v))) ' %d')%(k,max(v)))
# build a list of output types for any Apply node using this op.
self.apply_output_types = []
for i, o in enumerate(outputs):
if 1 == stored_steps_output[i]:
self.apply_output_types.append(o.type)
else:
expanded_otype = TensorType(
broadcastable=(False,)+o.type.broadcastable,
dtype=o.type.dtype)
self.apply_output_types.append(expanded_otype)
self.destroy_map = {} self.destroy_map = {}
...@@ -448,92 +495,14 @@ class Scan(theano.Op): ...@@ -448,92 +495,14 @@ class Scan(theano.Op):
self.fn = theano.function(inputs,outputs, mode = mode, givens = givens) self.fn = theano.function(inputs,outputs, mode = mode, givens = givens)
def make_node(self,*inputs): def make_node(self,*inputs):
n_args = len(inputs) assert all(isinstance(i, theano.Variable) for i in inputs)
if n_args < self.n_args : return Apply(self, inputs, [t() for t in self.apply_output_types])
err = 'There should be at least '+str(self.n_args)+ 'arguments'
raise ValueError(err)
# return a new variable of same type and same shape
def new_same_dim(var):
try:
nw_var = theano.tensor.as_tensor_variable(var)
return nw_var.type()
except TypeError:
if isinstance(var, shared_random.RandomStateSharedVariable):
return var.type()
else:
raise TypeError("Could not convert %s to suitable type"%var,
type(var))
# return a new variable of same type but with an extra dimension
def new_add_one_dim(var):
nw_var = theano.tensor.as_tensor_variable(var)
return theano.tensor.Tensor( dtype = nw_var.dtype, \
broadcastable = (False,)+nw_var.broadcastable)()
def new_replace_one_dim(var):
nw_var = theano.tensor.as_tensor_variable(var)
return theano.tensor.Tensor( dtype = nw_var.dtype, \
broadcastable = (False,)+nw_var.broadcastable[1:])()
def new_remove_one_dim(var):
nw_var = theano.tensor.as_tensor_variable(var)
return theano.tensor.Tensor( dtype = nw_var.dtype, \
broadcastable = nw_var.broadcastable[1:])()
# Create list of output datatypes
out_types = []
for i in xrange(self.n_seqs+1, self.n_seqs+self.n_outs+1):
out_idx = i - 1 - self.n_seqs
if not (inputs[i] == []):
## CASES :
# outs_taps[i] == [-1] or == [] => inputs[i] no extra dim
# outs_taps anything else => inputs[i] remove one dim
#
# stored_steps_outputs = 1 ==> outs no extra dim
# anything else --> needs extra dim
sw_inputs = self.outs_taps.get(out_idx, [-1]) == [-1]
sw_outputs = self.stored_steps_output[out_idx] == 1
if sw_inputs:
if sw_outputs:
# You need to output something identical to the
# input.. which can even be a non tensor
out_types += [ new_same_dim(inputs[i]) ]
else:
# You need to output a list of things identical to
# the input .. (here we force it to be a tensor )
out_types += [ new_add_one_dim(inputs[i]) ]
else:
if sw_outputs:
# your input has one dimension more, so you need
# to strip it by its first dimension
out_types += [new_remove_one_dim(inputs[i])]
else:
# input and output have the same # of dimensions,
# just that you need to "refresh" the first one
# this is important only in the corner case that
# the first dimension of the input is 1, in which
# case the output broadcastable pattern does not
# match the input broadcastable pattern
#
# Note that this should in practice never happen !!
# I add it here just for safety
out_types += [new_replace_one_dim(inputs[i])]
else:
raise ValueError(('You need to provide initial state for outputs'
' such that scan can infer what dataype they are'))
return theano.Apply(self,inputs, out_types)
def __eq__(self,other): def __eq__(self,other):
# the self.apply_output_types are a function of all these things
# no need to compare it as well
rval = type(self) == type(other) rval = type(self) == type(other)
if rval: if rval:
rval = (self.inputs == other.inputs) and \ rval = (self.inputs == other.inputs) and \
...@@ -553,6 +522,8 @@ class Scan(theano.Op): ...@@ -553,6 +522,8 @@ class Scan(theano.Op):
def __hash__(self): def __hash__(self):
# the self.apply_output_types are a function of all these things
# no need to compare it as well
return hash(type(self)) ^ \ return hash(type(self)) ^ \
hash(self.n_seqs) ^ \ hash(self.n_seqs) ^ \
hash(self.n_outs) ^ \ hash(self.n_outs) ^ \
...@@ -571,6 +542,26 @@ class Scan(theano.Op): ...@@ -571,6 +542,26 @@ class Scan(theano.Op):
def perform(self,node,args, outs): def perform(self,node,args, outs):
"""
The args are packed like this:
n_steps
X sequence inputs x_1, x_2, ... x_<self.n_seqs>
Y initial states (u_1, u_2, ... u_<self.n_outs>) for our outputs. Each must have appropriate length (T_1, T_2, ..., T_Y).
W other inputs w_1, w_2, ... w_W
There are at least 1 + self.n_seqs + self.n_outs inputs, and the ones above this number
are passed to the scanned function as non-sequential inputs.
The outputs are more straightforward:
Y sequence outputs y_1, y_2, ... y_<self.n_outs>
"""
n_steps = 0 n_steps = 0
if (self.n_seqs ==0 ) and (args[0] == 0): if (self.n_seqs ==0 ) and (args[0] == 0):
raise ValueError('Scan does not know over how many steps it ' raise ValueError('Scan does not know over how many steps it '
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论