提交 d850cd95 authored 作者: rman@rpad's avatar rman@rpad

Added some tests and fixed bugs in the new scan interface

上级 16ff6ebb
import numpy
import theano
import theano.sandbox.scan
# generator network, only one output , type scalar ; no sequence or
# non sequence arguments
def test_1():
    """Generator network: one scalar output, no sequence or non-sequence
    arguments; scan iterates for an explicit number of steps."""
    def f_pow2(x_tm1):
        # doubles the previous state; the empty dict means "no updates"
        return (2 * x_tm1, {})

    s = theano.tensor.dvector()
    n_steps = theano.tensor.dscalar()
    Y = theano.sandbox.scan.scan(f_pow2, [], s, [], n_steps=n_steps)
    f1 = theano.function([s, n_steps], Y)
    # BUG FIX: numpy.all, not numpy.any -- `any` would let the test pass
    # if just a single element matched the expected sequence.
    assert numpy.all(f1([1], 3) == [2, 4, 8])
# simple rnn, one input, one state, weights for each; input/state are
# vectors, weights are scalars
def test_2():
    """Simple RNN: one input sequence, one state; both are vectors and the
    weights are scalars passed as non-sequence arguments."""
    def f_rnn(u_t, x_tm1, W_in, W):
        return (u_t * W_in + x_tm1 * W, {})

    u = theano.tensor.dvector()
    x0 = theano.tensor.dvector()
    W_in = theano.tensor.dscalar()
    W = theano.tensor.dscalar()
    Y = theano.sandbox.scan.scan(f_rnn, u, x0, [W_in, W])
    f2 = theano.function([u, x0, W_in, W], Y)
    # BUG FIX: original used numpy.any (passes on a single matching element)
    # and exact float equality (fragile: .1 is not exactly representable).
    assert numpy.allclose(f2([1, 2, 3, 4], [1], .1, 1),
                          numpy.array([1.1, 1.3, 1.6, 2.]))
# simple rnn, one input, one state, weights for each; input/state are
# vectors, weights are scalars; using shared variables
def test_3():
    """Same simple RNN as test_2, but with the weights held in theano
    shared variables instead of being passed as non-sequence inputs."""
    u = theano.tensor.dvector()
    x0 = theano.tensor.dvector()
    W_in = theano.shared(.1, name='w_in')
    W = theano.shared(1., name='w')

    def f_rnn_shared(u_t, x_tm1):
        # W_in / W are picked up from the enclosing scope (shared vars)
        return (u_t * W_in + x_tm1 * W, {})

    Y = theano.sandbox.scan.scan(f_rnn_shared, u, x0, [])
    f3 = theano.function([u, x0], Y)
    # BUG FIX: numpy.any -> numpy.allclose (see test_2 for rationale)
    assert numpy.allclose(f3([1, 2, 3, 4], [1]),
                          numpy.array([1.1, 1.3, 1.6, 2.]))
# some rnn with multiple outputs and multiple inputs; other dimension
# instead of scalars/vectors
def test_4():
    """RNN with multiple inputs and multiple outputs, exercising matrix and
    vector arguments instead of just scalars/vectors."""
    W_in2 = theano.shared(numpy.array([1., 2.]), name='win2')
    W = theano.shared(numpy.array([[2., 1.], [1., 1.]]), name='w')
    W_out = theano.shared(numpy.array([.5, 1.]), name='wout')
    W_in1 = theano.tensor.dmatrix('win')
    u1 = theano.tensor.dmatrix('u1')
    u2 = theano.tensor.dvector('u2')
    x0 = theano.tensor.dmatrix('x0')
    y0 = theano.tensor.dvector('y0')

    ## Why dot doesn;t work with scalars !??
    ## Why * doesn't support SharedVariable and TensorVariable
    def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, W_in1):
        # returns (updates, [out1, out2]): the scan interface accepts the
        # updates dict and the outputs list in either order
        return ({}, [theano.dot(u1_t, W_in1) + u2_t * W_in2 +
                     theano.dot(x_tm1, W),
                     theano.dot(x_tm1, W_out)])

    Y = theano.sandbox.scan.scan(f_rnn_cmpl, [u1, u2], [x0, y0], W_in1)
    f4 = theano.function([u1, u2, x0, y0, W_in1], Y)
    (x, y) = f4(numpy.array([[1, 2], [1, 2], [1, 2]]),
                numpy.array([1, 2, 3]),
                numpy.array([[0, 0]]),
                numpy.array([1]),
                numpy.array([[1, 1], [1, 1]]))
    # ROBUSTNESS FIX: allclose instead of exact float equality -- the
    # expected values come from float dot products.
    assert numpy.allclose(x, numpy.array([[4., 5.], [18., 16.], [58., 43.]]))
    assert numpy.allclose(y, numpy.array([0., 7., 25.]))
# basic ESN using updates
def test_5():
    """Basic ESN: the reservoir state `x` is a shared variable carried
    between steps through the updates dictionary rather than as an output."""
    W_in = theano.shared(numpy.array([1., 1.]), name='win')
    W = theano.shared(numpy.array([[.1, 0.], [.0, .1]]), name='w')
    W_out = theano.shared(numpy.array([.5, 1.]), name='wout')
    u = theano.tensor.dvector('u')
    x = theano.shared(numpy.array([0., 0.]), 'x')
    y0 = theano.tensor.dvector('y0')

    def f_ESN(u_t):
        # output reads the *old* x; the update computes the new x
        return (theano.dot(x, W_out),
                {x: W_in * u_t + theano.dot(x, W)})

    # outputs_taps={0: []}: the output feeds nothing back to the function
    Y = theano.sandbox.scan.scan(f_ESN, u, y0, [], outputs_taps={0: []})
    f5 = theano.function([u, y0], Y)
    # BUG FIX: `assert(array == array)` raises ValueError because the truth
    # value of a multi-element array is ambiguous; compare with allclose.
    assert numpy.allclose(f5(numpy.array([1, 2, 3]), numpy.array([0])),
                          numpy.array([0., 1.4, 3.15]))
# basic ESN using updates ; moving backwards
def test_6():
    """Same ESN as test_5, but iterating over the input sequence backwards
    (go_backwards=True)."""
    W_in = theano.shared(numpy.array([1., 1.]), name='win')
    W = theano.shared(numpy.array([[.1, 0.], [.0, .1]]), name='w')
    W_out = theano.shared(numpy.array([.5, 1.]), name='wout')
    u = theano.tensor.dvector('u')
    x = theano.shared(numpy.array([0., 0.]), 'x')
    y0 = theano.tensor.dvector('y0')

    def f_ESN(u_t):
        return (theano.dot(x, W_out),
                {x: W_in * u_t + theano.dot(x, W)})

    Y = theano.sandbox.scan.scan(f_ESN, u, y0, [], outputs_taps={0: []},
                                 go_backwards=True)
    f6 = theano.function([u, y0], Y)
    # BUG FIX: `assert(array == array)` raises ValueError (ambiguous truth
    # value of a multi-element array); compare with allclose.
    assert numpy.allclose(f6(numpy.array([1, 2, 3]), numpy.array([0])),
                          numpy.array([0., 4.5, 3.45]))
'''
TO TEST:
- test taps (for sequences and outputs )
- test gradient (one output)
- test gradient (multiple outputs)
- test gradient (go_backwards)
- test gradient (multiple outputs / some uncomputable )
- test gradient (truncate_gradient)
- test gradient (force_gradient)
- test inplace map
'''
if __name__ == '__main__':
    # Run every test in definition order.
    for _test in (test_1, test_2, test_3, test_4, test_5, test_6):
        _test()
"""Provide Scan and related functions
Scanning a function over sequential input(s) producing sequential output(s).
Scanning is a general form of recurrence, which can be used for looping.
The idea is that you 'scan' a function along some input sequence, producing
an output at each time-step that can be seen (but not modified) by the
function at the next time-step. (Technically, the function can see the
previous K time-steps.)
So for example, ``sum()`` could be computed by scanning the ``z+x_i``
function over a list, given an initial state of ``z=0``.
Special cases:
- A ``reduce()`` operation can be performed by returning only the last
output of a scan.
- A ``map()`` operation can be performed by applying a function that
ignores each previous output.
Often a for loop can be expressed as a scan() operation, and scan is the
closest that theano comes to looping.
This module provides scanning functionality with the `Scan` Op.
"""
__docformat__ = 'restructedtext en'
import numpy
import theano
from theano.tensor import opt
......@@ -49,7 +18,7 @@ def info(*msg):
def hash_list(list):
    """XOR-fold the hashes of the elements of `list` into a single int.

    NOTE(review): the parameter shadows the builtin `list`; the name is
    kept because it is part of the existing call signature.
    """
    hash_value = 0
    for v in list:
        # BUG FIX: the source interleaved the diff's old line
        # (`hash_value ^= v`) with the corrected one; only the corrected
        # `hash(v)` form is kept, so non-integer elements hash too.
        hash_value ^= hash(v)
    return hash_value
......@@ -57,137 +26,54 @@ def hash_list(list):
# as values either numbers or list of numbers
def hash_dict(dictionary):
    """XOR-fold a dictionary's keys and values into a single int.

    Values may be numbers or lists/tuples of numbers; list/tuple values
    are folded with hash_list (they are unhashable themselves).
    """
    hash_value = 0
    for k, v in dictionary.iteritems():
        # BUG FIX: the source interleaved the diff's old buggy lines
        # (`dictionary,iteritems()`, raw `^= k` / `^= v`) with the
        # corrected ones; only the corrected lines are kept.
        hash_value ^= hash(k)
        if type(v) in (list, tuple):
            hash_value ^= hash_list(v)
        else:
            hash_value ^= hash(v)
    return hash_value
def scan(fn, sequnces, non_sequences, seed_values, inplace_map={},
def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
sequences_taps={}, outputs_taps = {},
len = theano.tensor.zero(), force_gradient = False,
n_steps = theano.tensor.zero(), force_gradient = False,
truncate_gradient = -1, go_backwards = False, mode = 'FAST_RUN'):
'''The function creates a more intuitive interface to the scan op.
This function first creates a scan op object, and afterwards applies it
to the input data. The scan operation iterates over X sequences producing
Y outputs. The function that is applied recursively may consult several
previous outputs from the past as well as past values and future values
of the input. You can see it as havin the inputs :
X sequences inptus x_1, x_2, .. x_X
Y seeds/initial values ( u_1, u_2, .. u_Y) for the outputs
W non sequences inputs w_1, w_2, .. w_W
Outputs :
Y sequence outputs y_1, y_2, .. y_Y
Each otuput y_j computed one time step at a time according to the
formula:
.. code-block:: python
(y_1[t], y_2[t], .. y_Y[t]) = f(
x_1[t-K_1],.. x_1[t],x_1[t+1],.. x_1[t+L_1], # x_1 past and future
#values
x_2[t-K-2],.. x_2[t],x_2[t+1],.. x_2[t+L_2], # x_2 past and future
# values
... # ...
y_1[t-1], y_1[t-2], .. y[t - T_1], # past values of y_1
y_2[t-1], y_2[t-2], .. y[t - T_2],, # past values of y_2
...
w_1, w_2, .., w_W) # 'timeless' inputs
:param fn: fn is a lambda expression or a function that given a list of
symbolic inputs returns the update list and symbolic outputs list of the
function that shall be applied recursively.
:param sequences:list of sequences over which the scan op should iterate;
sequnces length should also cover past and future taps; for example if
you also use for a sequence the past tap -3 and future tap +4, to total
length should be n+7, where first 3 values of sequence are those
corresponding to -3 -2 -1 and the last 4 values correspond to n+1 n+2
n+3 and n+4
:param non_sequences: list of inputs over which it shouldn't iterate
:param seed_values: seeds (initial values) of the outputs; if past taps
are this seeds should contain enough values to cover this past values;
note that index 0 of a seed belongs to the largest past tap
:param inplace_map: a dictionary telling which output should be
computed in place of which input sequence ; input sequence has to be
of the same shape as the output
:param sequence_taps: a dictionary telling for each sequence what past
and future taps it should use; past values should be negative, future
taps positives; by default 0 is added in this dictionary (current value)
if nothing is provided
:param outputs_taps: a dictionary telling for each output what past
taps it should use (negative values); by default -1 is added to this
dictionary if nothing is provided
:param len: a value (or theano scalar) describing for how many steps
the scan should iterate; 0 means that it should iterate over the entire
length of the input sequence(s)
:param force_gradient: a flag telling scan op that the gradient can be
computed even though inplace or updates are used - use this on your own
risk
:param truncate_gradient: tells for how many steps should scan go
back in time on the backward pass of backpropagation through time
:param go_backwards: a flag indicating if scan should iterate back from
the end of the sequence to the begining (if it is true) or from 0 to
the end
:param mode: indicates the mode that should be used to compile the
function that will be applied recursively
'''
# check if inputs are just single variables instead of lists
if not (type(sequences) in (list, tuple)):
seqs = [sequences]
elif seqs = sequences
else:
seqs = sequences
if not type(seed_values) in (list,tuple)):
seeds = [seed_values]
elif
seeds = seed_values
if not (type(initial_states) in (list,tuple)):
init_outs = [initial_states]
else:
init_outs = initial_states
if not (type(non_sequences) in (list,tuple)):
non_seqs = [non_sequences]
elif
else:
non_seqs = non_sequences
# compute number of sequences and number of seeds
# compute number of sequences and number of seqs
n_seqs = len(seqs)
# see if there are outputs that do not feed anything back to the function
# applied recursively
outs_tapkeys = outputs_taps.keys()
for k in outs_tapkeys.sort():
if outputs_taps[k] == []
# add empty lists where you have outputs that do not have past
# values
seeds = seeds[:k] + [[]] + seeds[k:]
#outs_tapkeys = outputs_taps.keys()
#outs_tapkeys.sort()
#for k in outs_tapkeys:
# if outputs_taps[k] == []:
# # add empty lists where you have outputs that do not have past
# # values
# init_outs = init_outs[:k] + [[]] + init_outs[k:]
n_seeds = len(seeds)
n_outs = len(init_outs)
# update sequences_taps[idx] to contain 0 if it is not defined
......@@ -197,93 +83,79 @@ def scan(fn, sequnces, non_sequences, seed_values, inplace_map={},
# if input sequence is not actually used by the recursive function
elif sequences_taps[i] == []:
sequences_taps.__delitem__(i)
elif not (sequences_taps[i] in (list,tuple)):
elif not (type(sequences_taps[i]) in (list,tuple)):
sequences_taps[i] = [sequences_taps[i]]
# update outputs_taps[idx] to contain -1 if it is not defined
for i in xrange(n_seeds):
for i in xrange(n_outs):
if not outputs_taps.has_key(i):
outputs_taps.update({i:-1})
outputs_taps.update({i:[-1]})
# if output sequence is not actually used as input to the recursive
# function
elif outputs_taps[i] == []:
outputs_taps.__delitem__(i)
elif not(outputs_taps[i] in (list,tuple)):
elif not(type(outputs_taps[i]) in (list,tuple)):
outputs_taps[i] = [outputs_taps[i]]
# create theano inputs for the recursive function
args = []
for (i,seq) in enumerate(seqs):
if sequences_taps.has_key(i):
for k in len(sequences_taps[i]):
for k in xrange(len(sequences_taps[i])):
args += [seq[0].type() ]
for (i,seed) in enumerate(seeds):
for (i,init_out) in enumerate(init_outs):
if outputs_taps.has_key(i):
for k in len(outputs_taps[i]):
args += [seed[0].type() ]
for k in xrange(len(outputs_taps[i])):
args += [init_out[0].type() ]
args += non_seqs
next_outs, updates = fn(*args)
t1,t2 = fn(*args)
# check to see which is the updates list and which is the list of outs
if not ( (type(t1) in (list,tuple)) or (type(t1) == dict)) :
next_outs = [t1]
updates = t2
elif not ( (type(t2) in (list,tuple)) or (type(t2) == dict)) :
next_outs = [t2]
updates = t1
elif type(t1) == dict :
next_outs = t2
updates = t1
elif type(t2) == dict :
next_outs = t1
updates = t2
elif type(t1[0]) in (list,tuple):
next_outs = t2
updates = t1
else:
next_outs = t1
updates = t2
# Create the Scan op object
local_op = Scan( (args,next_outs, updates), n_seqs,n_seeds,inplace_map,
local_op = Scan( (args,next_outs, updates), n_seqs,n_outs,inplace_map,
sequences_taps, outputs_taps, force_gradient, truncate_gradient,
go_backwards, mode)
# Call the object on the input sequences, seeds, and non sequences
return local_op( *( [thenao.tensor.as_tensor(len)] \
# Call the object on the input sequences, initial values for outs,
# and non sequences
return local_op( *( [theano.tensor.as_tensor(n_steps)] \
+ seqs \
+ seeds \
+ init_outs \
+ non_seqs))
''' The class implementing the scan op
The actual class. I would not recommend using it directly unless you really
know what you are doing'
'''
class Scan(theano.Op):
def __init__(self,(inputs, outputs, updates),n_seqs, n_seeds,
def __init__(self,(inputs, outputs, updates),n_seqs, n_outs,
inplace_map={}, seqs_taps={}, outs_taps={},
force_gradient = False, truncate_gradient = -1,
go_backwards = False, inplace=False):
'''
:param inputs: list of symbolic inputs of the function that will
be applied recursively
:param outputs: list of symbolic outputs for the function applied
recursively
:param updates: list of updates for the function applied recursively
:param n_seqs: number of sequences in the input over which it needs
to iterate
:param n_seeds: number of outputs (same as the number of seeds)
:param inplace_map: dictionary discribing which output should be
computed inplace of which input
:param seqs_taps: dictionary discribing which past and future taps
of the input sequences are used by the recursive function
:param outs_taps: dictionary discribing which past taps of the
outputs the recursive function is using
:param force_gradient: a flag indicating if the gradient is still
computable even though inplace operation or updates are used
:param truncate_gradient: if different from -1 it tells after how
many steps in the backward pass of BPTT
'''
go_backwards = False, mode = 'FAST_RUN', inplace=False):
# check inplace map
for _out,_in in inplace_map.iteritems():
if _out > n_seeds:
if _out > n_outs:
raise ValueError(('Inplace map reffers to an unexisting'\
'output %d')% _out)
if _in > n_seqs:
......@@ -295,19 +167,19 @@ class Scan(theano.Op):
#check sequences past taps
for k,v in seqs_taps.map_iteritems():
for k,v in seqs_taps.iteritems():
if k > n_seqs:
raise ValueError(('Sequences past taps dictionary reffers to '
'an unexisting sequence %d')%k)
#check outputs past taps
for k,v in outs_taps.map_iteritems():
if k > n_seeds:
for k,v in outs_taps.iteritems():
if k > n_outs:
raise ValueError(('Sequences past taps dictionary reffers to '
'an unexisting sequence %d')%k)
if max(v) > -1:
raise ValueError(('Can not require future value %d of output'
'%d')%(k,max(v)))
raise ValueError(('Can not require future value %d of output' \
' %d')%(k,max(v)))
......@@ -318,8 +190,8 @@ class Scan(theano.Op):
self.seqs_taps = seqs_taps
self.outs_taps = outs_taps
self.n_seqs = n_seqs
self.n_seeds = n_seeds
self.n_args = n_seqs+n_seeds+1
self.n_outs = n_outs
self.n_args = n_seqs+n_outs+1
self.inplace_map = inplace_map
self.inplace = inplace
self.inputs = inputs
......@@ -328,8 +200,7 @@ class Scan(theano.Op):
self.force_gradient = force_gradient
self.truncate_gradient = truncate_gradient
self.go_backwards = go_backwards
self.fn = theano.function(inputs,outputs, \
updates = updates, mode = mode)
......@@ -355,9 +226,13 @@ class Scan(theano.Op):
# Create list of output datatypes
out_types = []
for i in xrange(self.n_seqs+1, self.n_seqs+self.n_seeds+1):
out_types += [theano.tensor.Tensor(dtype=inputs[i].dtype,\
for i in xrange(self.n_seqs+1, self.n_seqs+self.n_outs+1):
if not (inputs[i] == []):
out_types += [theano.tensor.Tensor(dtype=inputs[i].dtype,\
broadcastable=(False,)+inputs[i].broadcastable[1:])()]
else:
raise ValueError(('You need to provide initial state for outputs'
' such that scan can infer what dataype they are'))
return theano.Apply(self,inputs, out_types)
......@@ -376,15 +251,15 @@ class Scan(theano.Op):
(self.inplace == other.inplace) and\
(self.go_backwards == other.go_backwards) and\
(self.truncate_gradient == other.truncate_gradient) and\
(self.force_gradient = other.force_gradient) and\
(self.n_seeds == other.n_seeds) and\
(self.force_gradient == other.force_gradient) and\
(self.n_outs == other.n_outs) and\
(self.n_args == other.n_args)
return rval
def __hash__(self):
return hash(type(self)) ^ \
hash(self.n_seqs) ^ \
hash(self.n_seeds) ^ \
hash(self.n_outs) ^ \
hash(self.force_gradient) ^\
hash(self.inplace) ^\
hash(self.go_backwards) ^\
......@@ -392,11 +267,10 @@ class Scan(theano.Op):
hash(self.n_args) ^ \
hash_list(self.outputs) ^ \
hash_list(self.inputs) ^ \
hash_list(g_ins) ^ \
hash_list(h_outs) ^ \
hash_list(self.g_ins) ^ \
hash_list(self.g_outs) ^ \
hash_dict(self.seqs_taps) ^\
hash_dict(self.outs_taps) ^\
hash_dict(self.inplace_map) ^\
hash_dict(self.updates)
......@@ -405,7 +279,7 @@ class Scan(theano.Op):
def perform(self,node,args, outs):
n_steps = 0
if (self.n_seqs ==0 ) and (args[0] == 0)
if (self.n_seqs ==0 ) and (args[0] == 0):
raise ValueError('Scan does not know over how many steps it '
'should iterate! No input sequence or number of steps to '
'iterate given !')
......@@ -417,10 +291,10 @@ class Scan(theano.Op):
if self.seqs_taps.has_key(i):
# compute actual length of the sequence ( we need to see what
# past taps this sequence has, and leave room for them
seq_len = args[i+1].shape[0] + min(self.seqs_taps[i+1])
if self.seqs_taps[i+1][2] > 0:
seq_len = args[i+1].shape[0] + min(self.seqs_taps[i])
if max( self.seqs_taps[i]) > 0:
# using future values, so need to end the sequence earlier
seq_len -= self.seqs_taps[i+1][2]
seq_len -= max(self.seqs_taps[i])
if n_steps == 0 :
# length of the sequences, leaving room for the largest
n_steps = seq_len
......@@ -437,9 +311,9 @@ class Scan(theano.Op):
inplace_map = {}
# check lengths of seeds
# check lengths of init_outs
for i in xrange(self.n_seqs+1, \
self.n_seqs+self.n_seeds+1):
self.n_seqs+self.n_outs+1):
if self.outs_taps.has_key(i-self.n_seqs-1):
req_size = abs(min(self.outs_taps[i-self.n_seqs-1]))-1
if args[i].shape[0] < req_size:
......@@ -448,83 +322,82 @@ class Scan(theano.Op):
' for missing values')%(i-self.n_iterable-1,req_size))
self.n_steps = n_steps
y = self.scan(self.fn, args[1:],self.n_seqs, self.n_seeds,
y = self.scan(self.fn, args[1:],self.n_seqs, self.n_outs,
self.seqs_taps, self.outs_taps, n_steps, self.go_backwards,
inplace_map)
# write to storage
for i in xrange(self.n_seeds):
for i in xrange(self.n_outs):
outs[i][0]=y[i]
def scan(fn, args, n_seqs, n_seeds, seqs_taps, outs_taps, n_steps,
def scan(self,fn, args, n_seqs, n_outs, seqs_taps, outs_taps, n_steps,
go_backwards, inplace_map):
y = []
for i in xrange(self.n_seeds):
for i in xrange(n_outs):
if inplace_map.has_key(i) and (inplace_map[i] >= 0):
y += [args[inplace_map[i]]]
else:
y_shape = (n_steps,)+args[i+self.n_seqs].shape[1:]
y_shape = (n_steps,)+args[i+n_seqs].shape[1:]
y += [numpy.empty(y_shape,
dtype=args[i+self.n_seqs].dtype)]
#iterate
if go_backwards:
the_range = xrange(n_steps-1,-1,-1)
else:
the_range = xrange(n_steps)
dtype=args[i+n_seqs].dtype)]
seqs_mins = {}
for j in xrange(self.n_seqs):
for j in xrange(n_seqs):
if seqs_taps.has_key(j):
seqs_mins.update({j: min(seqs_taps[j])})
outs_mins = {}
seed_size = {}
for j in xrange(self.n_seeds):
initOuts_size = {}
for j in xrange(n_outs):
if outs_taps.has_key(j):
outs_mins.update({j: min(outs_taps[j])})
seed_size.update({j: args[n_seqs+j].shape[0]})
initOuts_size.update({j: args[n_seqs+j].shape[0]})
for i in the_range:
for i in xrange(n_steps):
fn_args = []
# sequences over which scan iterates
for j in xrange(self.n_seqs):
# check to see if we are scaning them backwards or no
_i = i
if go_backwards:
_i = n_steps-1-i
for j in xrange(n_seqs):
if seqs_taps.has_key(j):
ls_taps = seqs_taps[j]
min_tap = seqs_mins[j]
for tap_value in ls_taps:
k = i - min_tap + tap_value
k = _i - min_tap + tap_value
fn_args += [args[j][k]]
# seeds or past values of outputs
for j in xrange(self.n_seeds):
# past values of outputs
for j in xrange(n_outs):
if outs_taps.has_key(j):
ls_taps = outs_taps[j]
min_tap = outs_mins[j]
seed_sz = seed_size[j]
sz = initOuts_size[j]
for tap_value in ls_taps:
if i + tap_value < 0:
k = i + seed_sz + tap_value
if k < 0
k = i + sz + tap_value
if k < 0:
# past value not provided.. issue a warning and use 0s
fn_args += [numpy.zeros(args[j][0].shape)]
warning('Past value %d for output %d not given in seeds' %
(j,tap_value))
fn_args += [numpy.zeros(args[j+n_seqs][0].shape)]
warning(('Past value %d for output %d not given in inital '
'out') % (j,tap_value))
else:
fn_args += [args[j][k]]
fn_args += [args[j+n_seqs][k]]
else:
fn_args += [y[j][i + tap_value]]
# get the non-iterable sequences
fn_args += list(args[(self.n_seqs+self.n_seedss):]
fn_args += list(args[(n_seqs+n_outs):])
# compute output
something = fn(*fn_args)
#update outputs
for j in xrange(self.n_seeds):
for j in xrange(n_outs):
y[j][i] = something[j]
return y
......@@ -560,7 +433,7 @@ class Scan(theano.Op):
g_scan = ScanGrad((self.g_ins,self.g_outs), self.n_seqs, \
self.n_seeds,self.seqs_taps, self.outs_taps,
self.n_outs,self.seqs_taps, self.outs_taps,
self.truncate_gradient)
return g_scan(g_args)
......@@ -573,7 +446,7 @@ def scan_make_inplace(node):
if isinstance(op, Scan) and (not op.inplace) \
and (op.inplace_map.keys() != []):
return Scan((op.inputs, op.outputs, op.updates), op.n_seqs, \
op.n_seeds, op.inplace_map, op.seqs_taps, op.outs_taps, \
op.n_outs, op.inplace_map, op.seqs_taps, op.outs_taps, \
op.force_gradient, op.truncate_gradient, \
op.go_backwards, inplace=True \
).make_node(*node.inputs).outputs
......@@ -673,12 +546,12 @@ class ScanGrad(theano.Op):
for j in xrange(self.n_outs):
if self.outs_taps.has_key(j):
outs_mins.update({j: min(self.outs_taps[j])})
seed_size.update({j: g_seeds[j]..shape[0]})
seed_size.update({j: g_seeds[j].shape[0]})
for i in the_range:
# time slice of inputs
_ins = []
for j in xrange(self.n_seqs)
for j in xrange(self.n_seqs):
if self.seqs_taps.has_key(j):
ls_taps = self.seqs_taps[j]
min_tap = seqs_mins[j]
......@@ -701,7 +574,7 @@ class ScanGrad(theano.Op):
warning('Past value %d for output $d not given' \
%(j,tap_value))
else:
_outs += [seeds[j][[k]]
_outs += [seeds[j][k]]
else:
_outs += [outs[j][i + tap_value]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论