Commit ac1e27e4 authored by Frédéric Bastien

Merge pull request #1614 from pascanur/new_fix_grad_scan

New fix grad scan
@@ -1561,19 +1561,16 @@ class Scan(PureOp):
         for idx in xrange(self.n_mit_mot + self.n_mit_sot):
             mintap = numpy.min(self.tap_array[idx])
             maxtap = numpy.max(self.tap_array[idx])
+            if idx < self.n_mit_mot:
+                outmaxtap = numpy.max(self.mitmot_out_taps()[idx])
+            else:
+                outmaxtap = 0
             seq = outs[idx]
             for k in self.tap_array[idx]:
-                if maxtap < 0:
-                    dim_offset = abs(maxtap)
-                else:
-                    dim_offset = 0
-                if maxtap == mintap and maxtap != 0:
-                    nw_seq = seq[:abs(maxtap)]
-                elif maxtap - k != 0:
-                    nw_seq = seq[dim_offset + k - mintap - 1:\
-                                 -(maxtap - k + 1)][::-1]
+                if outmaxtap - k != 0:
+                    nw_seq = seq[k - mintap: -(outmaxtap - k)][::-1]
                 else:
-                    nw_seq = seq[dim_offset + k - mintap - 1: -1][::-1]
+                    nw_seq = seq[k - mintap:][::-1]
                 outer_inp_seqs.append(nw_seq)
         outer_inp_seqs += [
             x[:-1][::-1] for x in self.outer_sitsot_outs(outs)]
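The rewritten slicing picks, for each input tap k, exactly the window of stored outputs that fed that tap during the forward pass, reversed so the backward scan visits them in reverse time order. A minimal numpy sketch of the indexing only (assumed taps [-2, -1] and outmaxtap == 0, as for a mit_sot output; this is not Theano's actual storage layout):

    import numpy

    seq = numpy.arange(10)     # stand-in for the stored outputs of one state
    mintap, outmaxtap = -2, 0  # assumed input taps [-2, -1], no output taps

    for k in (-2, -1):
        if outmaxtap - k != 0:
            nw_seq = seq[k - mintap: -(outmaxtap - k)][::-1]
        else:
            # only reached when k equals outmaxtap; keep the tail whole
            nw_seq = seq[k - mintap:][::-1]
        # k == -2 selects seq[0:-2] reversed, k == -1 selects seq[1:-1]
        # reversed: the two windows stay aligned step for step backwards.
        print(k, nw_seq)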
@@ -1627,7 +1624,11 @@ class Scan(PureOp):
         n_mitmot_inps = 0
         for idx in xrange(self.n_mit_mot):
-            outer_inp_mitmot.append(dC_douts[idx][::-1])
+            if isinstance(dC_douts[idx].type, DisconnectedType):
+                out = outs[idx]
+                outer_inp_mitmot.append(tensor.zeros_like(out))
+            else:
+                outer_inp_mitmot.append(dC_douts[idx][::-1])
             mitmot_inp_taps.append([])
             mitmot_out_taps.append([])
             undefined = False
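Reversing dC_douts[idx] unconditionally breaks when the cost does not depend on that output: theano.grad then delivers a variable of DisconnectedType, which cannot be sliced with [::-1]. Substituting zeros shaped like the forward output keeps the backward graph well typed. A standalone sketch of the pattern (the helper name is hypothetical):

    import theano.tensor as tensor
    from theano.gradient import DisconnectedType

    def upstream_grad_or_zeros(dC_dout, out):
        # Disconnected upstream gradient: no real dC/dout exists, so feed
        # zeros with the output's shape and dtype to the backward scan.
        if isinstance(dC_dout.type, DisconnectedType):
            return tensor.zeros_like(out)
        # Connected: reverse in time for the backward pass, as before.
        return dC_dout[::-1]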
@@ -1648,7 +1649,7 @@ class Scan(PureOp):
                 if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
                     undefined = True
-                n_mitmot_inps_ += 1
+                n_mitmot_inps += 1
                 ins_pos += 1
                 n_mitmot_outs += 1
                 mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx])
......
@@ -3577,10 +3577,8 @@ class T_Scan(unittest.TestCase):
         assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])

     def test_remove_constants_and_unused_inputs_scan_non_seqs(self):
-        """Test the opt remove_constants_and_unused_inputs_scan for
-        non sequences.
-        """
+        #Test the opt remove_constants_and_unused_inputs_scan for
+        #non sequences.
         W = theano.tensor.matrix(name='W')
         v = theano.tensor.ivector(name='v')
         y1, _ = theano.scan(lambda i, W: W[i], sequences=v,
@@ -3616,10 +3614,7 @@ class T_Scan(unittest.TestCase):
         assert (len(inp) == len(set(inp)))

     def test_remove_constants_and_unused_inputs_scan_seqs(self):
-        """
-        Test the opt remove_constants_and_unused_inputs_scan for sequences.
-        """
+        #Test the opt remove_constants_and_unused_inputs_scan for sequences.
         W = theano.tensor.matrix(name='W')
         v = theano.tensor.ivector(name='v')
         vv = theano.tensor.matrix(name='vv')
@@ -3661,6 +3656,39 @@ class T_Scan(unittest.TestCase):
         inp = scan_node.op.outer_non_seqs(scan_node)
         assert len(inp) == 1

+    def test_hessian_bug_grad_grad_two_scans(self):
+        # Bug reported by Bitton Tenessi
+        W_flat = tensor.fvector(name='W')
+        W_flat.tag.test_value = numpy.ones((8,), dtype=numpy.float32)
+        W = W_flat.reshape((2, 2, 2))
+
+        def loss_outer(i_outer, sum_outer, W):
+
+            def loss_inner(i_inner, sum_inner, W):
+                return sum_inner + (W ** 2).sum().sum().sum()
+
+            result_inner, _ = theano.scan(
+                fn=loss_inner,
+                outputs_info=tensor.as_tensor_variable(
+                    numpy.asarray(0, dtype=numpy.float32)),
+                sequences=tensor.arange(1, dtype='int32'),
+                non_sequences=[W])
+            return sum_outer + result_inner[-1]
+
+        result_outer, _ = theano.scan(
+            fn=loss_outer,
+            outputs_info=tensor.as_tensor_variable(
+                numpy.asarray(0, dtype=numpy.float32)),
+            sequences=tensor.arange(1, dtype='int32'),
+            non_sequences=[W])
+
+        cost = result_outer[-1]
+        H = theano.gradient.hessian(cost, W_flat)
+        f = theano.function([W_flat], H)
+        f(numpy.ones((8,), dtype='float32'))
+

 def test_speed():
     #
......
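As a sanity check on what the new test computes: with sequences=tensor.arange(1, dtype='int32') each scan runs exactly one step, so the inner scan returns (W ** 2).sum(), the outer scan adds that to 0, and cost reduces to (W_flat ** 2).sum(), whose Hessian is 2 * I. A standalone numpy sketch (not part of the commit) confirming that value by central differences:

    import numpy

    def cost(w):
        # Same value the nested scans compute after one step each:
        # the sum of squares of the flat weight vector.
        return (w ** 2).sum()

    w0 = numpy.ones(8)
    eps = 1e-4
    H = numpy.empty((8, 8))
    for i in range(8):
        for j in range(8):
            wpp = w0.copy(); wpp[i] += eps; wpp[j] += eps
            wpm = w0.copy(); wpm[i] += eps; wpm[j] -= eps
            wmp = w0.copy(); wmp[i] -= eps; wmp[j] += eps
            wmm = w0.copy(); wmm[i] -= eps; wmm[j] -= eps
            H[i, j] = (cost(wpp) - cost(wpm) - cost(wmp)
                       + cost(wmm)) / (4 * eps ** 2)

    # Hessian of sum(w**2) is 2 * I, which f(...) in the test should match.
    assert numpy.allclose(H, 2 * numpy.eye(8), atol=1e-3)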