提交 ceb19750 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fixed bug in debug mode

上级 8504fce2
......@@ -1469,7 +1469,7 @@ class ScanGrad(Op):
[i.type() for i in outputs_grad ])
def perform(self, node, args, storage):
print 'perform args'
# get scan inputs
n_steps = args[0]
......@@ -1488,7 +1488,12 @@ class ScanGrad(Op):
seqs = inputs[:self.n_seqs]
outInfo = inputs[self.n_seqs:self.n_seqs+self.n_outs]
non_seqs = inputs[self.n_outs+self.n_seqs:]
print 'seqs'
print seqs
print 'outInfo'
print outInfo
print 'non_Seqs'
print non_seqs
if (self.n_seqs == 0 ) and (not numpy.isfinite(n_steps) ):
raise ValueError('Scan does not know how many steps it '
'should iterate! Either provide some input sequences from '
......@@ -1546,10 +1551,16 @@ class ScanGrad(Op):
g_non_seqs = [numpy.zeros_like(k) for k in non_seqs]
# get gradient on the outputs
g_outs = args[1:self.n_outs_not_shared+1]
g_outs = [arg.copy() for arg in args[1:self.n_outs_not_shared+1]]
# get the output of the scan operation
outs = args[1+self.n_outs_not_shared:self.n_outs_not_shared+self.n_outs+1]
print 'g_outs'
print g_outs
print 'outs'
print outs
print 'steps'
print args[0]
......@@ -1623,8 +1634,6 @@ class ScanGrad(Op):
min_tap = seqs_mins[j]
for tap_value in ls_taps :
k = _i - min_tap + tap_value
print k, lower_limit, k-lower_limit
print g_seqs[j].shape
g_seqs[j][k-lower_limit] += grads[pos]
pos += 1
......@@ -1651,7 +1660,12 @@ class ScanGrad(Op):
# return the gradient
print 'g_seqs'
print g_seqs
print 'g_outInfo'
print g_outInfo
print 'g_non_seqs'
print g_non_seqs
for i,v in enumerate(g_seqs + g_outInfo+ g_non_seqs):
storage[i][0] = v
......
......@@ -894,7 +894,7 @@ class T_Scan(unittest.TestCase):
return x_t
cost, updates = scan_project_sum(f_rnn_cmpl,u,x0,W_in, n_steps = None,
truncate_gradient = 40, go_backwards = False)
truncate_gradient = 3, go_backwards = False)
vparams = [v_u, v_x0,vW_in]
params = [u,x0,W_in ]
gparams = theano.tensor.grad(cost, params)
......@@ -919,7 +919,7 @@ class T_Scan(unittest.TestCase):
analytic_grad = reset_rng_grad_fn(v_u, v_x0, vW_in)
assert len(analytic_grad[0]) == 40
assert len(analytic_grad[0]) == 3
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论