提交 a404f83f authored 作者: Razvan Pascanu's avatar Razvan Pascanu

subtle changes to scan op

上级 c22f4e84
......@@ -73,13 +73,13 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
for i in xrange(n_outs):
if not outputs_taps.has_key(i):
outputs_taps.update({i:[-1]})
# if output sequence is not actually used as input to the recursive
# function
elif outputs_taps[i] == []:
outputs_taps.__delitem__(i)
elif not(type(outputs_taps[i]) in (list,tuple)):
outputs_taps[i] = [outputs_taps[i]]
# create theano inputs for the recursive function
args = []
for (i,seq) in enumerate(seqs):
......@@ -89,6 +89,9 @@ def scan(fn, sequences, initial_states, non_sequences, inplace_map={},
for (i,init_out) in enumerate(init_outs):
if outputs_taps.has_key(i):
for k in xrange(len(outputs_taps[i])):
if outputs_taps[i] == [-1]:
args += [init_out.type() ]
else:
args += [init_out[0].type() ]
args += non_seqs
......@@ -313,6 +316,9 @@ class Scan(theano.Op):
for i in xrange(self.n_seqs+1, \
self.n_seqs+self.n_outs+1):
if self.outs_taps.has_key(i-self.n_seqs-1):
if self.outs_taps[i-self.n_seqs-1] == [-1]:
args[i] = numpy.array([args[i]])
req_size = abs(min(self.outs_taps[i-self.n_seqs-1]))-1
if args[i].shape[0] < req_size:
warning(('Initial state for output %d has fewer values then '
......
......@@ -97,13 +97,13 @@ class T_Scan(unittest.TestCase):
def f_pow2(x_tm1):
return (2*x_tm1, {})
s = theano.tensor.dvector()
s = theano.tensor.dscalar()
n_steps = theano.tensor.dscalar()
Y = theano.sandbox.scan.scan(f_pow2, [],s, [],n_steps = n_steps)
f1 = theano.function([s,n_steps], Y)
assert(compareArrays(f1([1],3), [2,4,8]))
assert(compareArrays(f1(1,3), [2,4,8]))
# simple rnn, one input, one state, weights for each; input/state are
# vectors, weights are scalars
......@@ -112,7 +112,7 @@ class T_Scan(unittest.TestCase):
return (u_t*W_in+x_tm1*W, {})
u = theano.tensor.dvector()
x0 = theano.tensor.dvector()
x0 = theano.tensor.dscalar()
W_in = theano.tensor.dscalar()
W = theano.tensor.dscalar()
......@@ -120,7 +120,7 @@ class T_Scan(unittest.TestCase):
f2 = theano.function([u,x0,W_in,W], Y)
v_u = numpy.array([1.,2.,3.,4.])
v_x0 = numpy.array([1])
v_x0 = numpy.array(1)
v_out = numpy.array([1.1,1.3,1.6,2.])
assert(compareArrays( f2(v_u,v_x0,.1,1), v_out ) )
......@@ -129,7 +129,7 @@ class T_Scan(unittest.TestCase):
def test_3(self):
u = theano.tensor.dvector()
x0 = theano.tensor.dvector()
x0 = theano.tensor.dscalar()
W_in = theano.shared(.1, name = 'w_in')
W = theano.shared(1., name ='w')
......@@ -140,7 +140,7 @@ class T_Scan(unittest.TestCase):
f3 = theano.function([u,x0], Y)
v_u = numpy.array([1.,2.,3.,4.])
v_x0 = numpy.array([1.])
v_x0 = numpy.array(1.)
v_out = numpy.array([1.1,1.3,1.6,2.])
assert(compareArrays(f3(v_u,v_x0),v_out))
......@@ -155,8 +155,8 @@ class T_Scan(unittest.TestCase):
W_in1 = theano.tensor.dmatrix('win')
u1 = theano.tensor.dmatrix('u1')
u2 = theano.tensor.dvector('u2')
x0 = theano.tensor.dmatrix('x0')
y0 = theano.tensor.dvector('y0')
x0 = theano.tensor.dvector('x0')
y0 = theano.tensor.dscalar('y0')
def f_rnn_cmpl(u1_t, u2_t, x_tm1, y_tm1, W_in1):
return ({}, [theano.dot(u1_t,W_in1) + u2_t* W_in2 + \
......@@ -167,8 +167,8 @@ class T_Scan(unittest.TestCase):
f4 = theano.function([u1,u2,x0,y0,W_in1], Y)
v_u1 = numpy.array([[1.,2.],[1.,2.],[1.,2.]])
v_u2 = numpy.array([1.,2.,3.])
v_x0 = numpy.array([[0.,0.]])
v_y0 = numpy.array([1])
v_x0 = numpy.array([0.,0.])
v_y0 = numpy.array(1)
v_Win1 = numpy.array([[1.,1.],[1.,1.]])
v_x = numpy.array([[4.,5.],[18.,16.],[58.,43.]])
v_y = numpy.array([0.,7.,25.])
......@@ -186,7 +186,7 @@ class T_Scan(unittest.TestCase):
u = theano.tensor.dvector('u')
x = theano.shared(numpy.array([0.,0.]),'x')
y0 = theano.tensor.dvector('y0')
y0 = theano.tensor.dscalar('y0')
def f_ESN(u_t):
return ( theano.dot(x,W_out), \
......@@ -196,7 +196,7 @@ class T_Scan(unittest.TestCase):
f5 = theano.function([u,y0],Y)
v_u = numpy.array([1.,2.,3.])
v_y0 = numpy.array([0.])
v_y0 = numpy.array(0.)
v_out = numpy.array([0.,1.5,3.15])
out = f5( v_u, v_y0 )
assert( compareArrays(v_out, out))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论