Commit 98cba59e authored by Razvan Pascanu

fix disconnected input case

Conflicts: theano/scan_module/scan_op.py
Parent a8358593
...@@ -1334,6 +1334,16 @@ class Scan(PureOp): ...@@ -1334,6 +1334,16 @@ class Scan(PureOp):
tmp = ils tmp = ils
if any([x is not None for x in tmp]): if any([x is not None for x in tmp]):
connection_pattern[iidx + 1][oidx] = True connection_pattern[iidx + 1][oidx] = True
# Propagate connectivity transitively through the recurrences: if inner
# output `jidx` is fed (as an inner input) from output `iidx`, then every
# outer input already connected to `iidx` is also connected to `jidx`.
# NOTE(review): `old_conn` is copied here but never read in this hunk —
# presumably a leftover snapshot; confirm against the full method.
old_conn = [ [v for v in cp] for cp in connection_pattern]
n_outs = len(node.outputs)
# Repeat n_outs times so connectivity can flow through chains of
# recurrences (a fixed-point iteration over the boolean matrix).
for steps in xrange(n_outs):
    for iidx in xrange(n_outs):
        for jidx in xrange(n_outs):
            # +1 matches the `iidx + 1` row offset used above when
            # filling connection_pattern (outer input 0 is skipped).
            j_inp_idx = self.get_input_pos(jidx) + 1
            if connection_pattern[j_inp_idx][iidx] == True:
                for k in xrange(len(connection_pattern)):
                    if connection_pattern[k][iidx]:
                        connection_pattern[k][jidx] = True
return connection_pattern return connection_pattern
### GRAD FUNCTION ### GRAD FUNCTION
...@@ -1371,17 +1381,53 @@ class Scan(PureOp): ...@@ -1371,17 +1381,53 @@ class Scan(PureOp):
self.inner_mitsot_outs(self_outputs) + self.inner_mitsot_outs(self_outputs) +
self.inner_sitsot_outs(self_outputs) + self.inner_sitsot_outs(self_outputs) +
self.inner_nitsot_outs(self_outputs)) self.inner_nitsot_outs(self_outputs))
# Recover the outer Apply node and its input/output connectivity matrix
# so the gradient computation below can skip disconnected pairs.
scan_node = outs[0].owner
connection_pattern = self.connection_pattern(scan_node)
def get_inp_idx(iidx):
    """Map an index into the inner-graph inputs (`self_inputs`) to the
    row of `connection_pattern` for the corresponding outer input.

    NOTE(review): the `1 +` offsets assume outer input 0 has no inner
    counterpart (consistent with the +1 row indexing used when the
    connection pattern is built) — confirm against connection_pattern.
    """
    if iidx < self.n_seqs:
        # Sequences come first and map one-to-one onto outer inputs.
        return 1 + iidx
    oidx = 1 + self.n_seqs
    iidx = iidx - self.n_seqs
    # A mitmot outer input expands into len(taps) inner inputs; all of
    # them collapse back onto the single outer position `oidx`.
    for taps in self.mitmot_taps():
        if len(taps) > iidx:
            return oidx
        else:
            oidx += 1
            iidx -= len(taps)
    # Same collapsing for mitsot inputs.
    for taps in self.mitsot_taps():
        if len(taps) > iidx:
            return oidx
        else:
            oidx += 1
            iidx -= len(taps)
    if iidx < self.info['n_sit_sot']:
        # sitsot inputs map one-to-one from here on.
        return oidx + iidx
    else:
        # Remaining inner inputs sit after the nitsot outer inputs,
        # which contribute no inner inputs of their own.
        return oidx + iidx + self.info['n_nit_sot']
def get_out_idx(iidx):
    """Map an index into the inner-graph outputs to the position of the
    corresponding outer output, collapsing the multiple taps of each
    mitmot output onto its single outer position."""
    out_pos = 0
    remaining = iidx
    for out_taps in self.mitmot_out_taps():
        if remaining < len(out_taps):
            # The inner output is one of this mitmot's taps.
            return out_pos
        out_pos += 1
        remaining -= len(out_taps)
    # Past the mitmots, inner and outer outputs line up one-to-one.
    return out_pos + remaining
def compute_gradient(y, g_y): def compute_gradient(y, g_y):
if 'int' in str(g_y.dtype): if 'int' in str(g_y.dtype):
raise TypeError("Gradients may never be integers but g_y " raise TypeError("Gradients may never be integers but g_y "
"has type " + str(g_y.type)) "has type " + str(g_y.type))
odx = get_out_idx(self_outputs.index(y))
wrt = [x for x in theano.gof.graph.inputs([y]) wrt = [x for x in theano.gof.graph.inputs([y])
if x in diff_inputs] if (x in diff_inputs) and
(connection_pattern[get_inp_idx(self_inputs.index(x))][odx])]
grads = gradient.grad( grads = gradient.grad(
cost=None, cost = None,
known_grads={y: g_y}, known_grads = {y : g_y },
wrt=wrt, consider_constant=wrt, wrt=wrt, consider_constant=wrt,
disconnected_inputs='ignore', disconnected_inputs='ignore',
return_disconnected='None') return_disconnected='None')
...@@ -1757,6 +1803,20 @@ class Scan(PureOp): ...@@ -1757,6 +1803,20 @@ class Scan(PureOp):
'Depends on a shared variable')) 'Depends on a shared variable'))
else: else:
gradients.append(x[-1]) gradients.append(x[-1])
# Mask disconnected gradients
# Ideally we would want to assert that the gradients we are
# replacing do indeed evaluate to 0, though that is not practical
# from a computational point of view
# The gradients of scan are computed replacing Disconnected with 0,
# because through the recurrence they can become nonzero
for idx in xrange(len(gradients)):
    # Outer input `idx` is disconnected iff every output it is
    # connected to received a disconnected output gradient.
    disconnected = True
    for kdx in xrange(len(node.outputs)):
        if connection_pattern[idx][kdx] and \
           not isinstance(dC_douts[kdx].type, DisconnectedType):
            disconnected = False
    if disconnected:
        gradients[idx] = DisconnectedType()()
return gradients return gradients
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
......
...@@ -3295,6 +3295,17 @@ class T_Scan(unittest.TestCase): ...@@ -3295,6 +3295,17 @@ class T_Scan(unittest.TestCase):
cost = x.sum() cost = x.sum()
self.assertRaises(ValueError, tensor.grad, cost, y0) self.assertRaises(ValueError, tensor.grad, cost, y0)
def test_disconnected_gradient(self):
    # Regression test: taking the gradient wrt an input that only some
    # of the scan outputs depend on must not raise.
    v = tensor.vector('v')
    m = tensor.matrix('m')
    u0 = tensor.zeros((7,))
    [u, m2], _ = theano.scan(lambda _, u: [u, v],
                             sequences=m,
                             outputs_info=[u0, None])
    # This used to raise an exception with older versions because for a
    # disconnected gradient a non-disconnected type was returned
    tensor.grad((m * m2).sum(), v)
def test_pregreedy_optimizer(self): def test_pregreedy_optimizer(self):
W = tensor.zeros((5, 4)) W = tensor.zeros((5, 4))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论