提交 b0f5fb53 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

handling disconnected and undiff cases

上级 e8f5ae18
......@@ -1278,7 +1278,7 @@ class Scan(PureOp):
def compute_gradient(y, g_y):
gmp = gradient.grad_sources_inputs(
[(y, g_y)], diff_inputs)
[(y, g_y)], theano.gof.graph.inputs([y]))
return [gmp.get(p, None) for p in diff_inputs]
dC_dinps_t = [None for inp in diff_inputs]
dC_dXts = []
......@@ -1311,10 +1311,13 @@ class Scan(PureOp):
elif _dC_dinps_t[jdx]:
dC_dinps_t[jdx] += _dC_dinps_t[jdx]
# mask inputs that get no gradients
disconnected_dC_dinps_t = []
for dx in xrange(len(dC_dinps_t)):
if not dC_dinps_t[dx]:
dC_dinps_t[dx] = tensor.zeros_like(diff_inputs[dx])
disconnected_dC_dinps_t.append(True)
else:
disconnected_dC_dinps_t.append(False)
for Xt, Xt_placeholder in zip(
diff_outputs[self.n_mit_mot_outs:],
Xts):
......@@ -1371,6 +1374,7 @@ class Scan(PureOp):
inner_out_mitmot = []
mitmot_inp_taps = []
mitmot_out_taps = []
type_outs = []
out_pos = 0
ins_pos = self.n_seqs
n_mitmot_outs = 0
......@@ -1380,6 +1384,8 @@ class Scan(PureOp):
outer_inp_mitmot.append(dC_douts[idx][::-1])
mitmot_inp_taps.append([])
mitmot_out_taps.append([])
undefined = False
disconnected = True
for jdx in xrange(len(self.mit_mot_out_slices[idx])):
inner_inp_mitmot.append(dC_dXts[out_pos])
mitmot_inp_taps[idx].append(-self.mit_mot_out_slices[idx][jdx])
......@@ -1388,12 +1394,23 @@ class Scan(PureOp):
for jdx in xrange(len(self.tap_array[idx])):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
inner_out_mitmot.append(dC_inps[ins_pos])
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if not disconnected_dC_dinps_t[ins_pos]:
disconnected=False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True
n_mitmot_inps_ += 1
ins_pos += 1
n_mitmot_outs += 1
mitmot_inp_taps[idx].append(-self.tap_array[idx][jdx])
mitmot_out_taps[idx].append(-self.tap_array[idx][jdx])
if undefined:
type_outs.append('undefined')
elif disconnected:
type_outs.append('disconnected')
else:
type_outs.append('connected')
offset = self.n_mit_mot
for idx in xrange(self.n_mit_sot):
......@@ -1404,6 +1421,8 @@ class Scan(PureOp):
inner_inp_mitmot.append(dC_dXts[out_pos])
out_pos += 1
n_mitmot_inps += 1
undefined = False
disconnected = True
mitmot_inp_taps[idx + offset].append(0)
for jdx in xrange(len(self.tap_array[idx_tap])):
inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
......@@ -1412,15 +1431,26 @@ class Scan(PureOp):
mitmot_out_taps[idx].append(
-self.tap_array[idx_tap][jdx])
inner_out_mitmot.append(dC_dinps_t[ins_pos])
if not disconnected_dC_dinps_t[ins_pos]:
disconnected = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True
n_mitmot_inps += 1
ins_pos += 1
n_mitmot_outs += 1
if undefined:
type_outs.append('undefined')
elif disconnected:
type_outs.append('disconnected')
else:
type_outs.append('connected')
offset += self.n_mit_sot
for idx in xrange(self.n_sit_sot):
mitmot_inp_taps.append([0, 1])
mitmot_out_taps.append([1])
undefined = False
if not isinstance(dC_douts[idx + offset].type, DisconnectedType):
outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
else:
......@@ -1428,6 +1458,16 @@ class Scan(PureOp):
tensor.zeros(outs[idx + offset].shape,
dtype = dC_dinps_t[ins_pos].dtype))
inner_out_mitmot.append(dC_dinps_t[ins_pos])
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([dC_dinps_t[ins_pos]]):
undefined = True
if undefined:
type_outs.append('undefined')
elif disconnected_dC_dinps_t[ins_pos]:
type_outs.append('disconnected')
else:
type_outs.append('connected')
inner_inp_mitmot += [dC_dXts[out_pos],
dC_dXtm1s[ins_pos - self.n_seqs]]
n_mitmot_outs += 1
......@@ -1441,6 +1481,29 @@ class Scan(PureOp):
n_nit_sot = self.n_seqs
inner_out_nitsot = dC_dinps_t[:self.n_seqs]
inner_out_sitsot = dC_dinps_t[ins_pos:]
for _p, vl in enumerate(inner_out_sitsot):
undefined = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([vl]):
undefined = True
if undefined:
type_outs.append('undefined')
elif disconnected_dC_dinps_t[_p + ins_pos]:
type_outs.append('disconnected')
else:
type_outs.append('connected')
for _p, vl in enumerate(inner_out_nitsot):
undefined = False
for _sh in self.inner_shared(self_inputs):
if _sh in gof.graph.inputs([vl]):
undefined = True
if undefined:
type_outs.append('undefined')
elif disconnected_dC_dinps_t[_p]:
type_outs.append('disconnected')
else:
type_outs.append('connected')
inner_inp_sitsot = dC_dXtm1s[ins_pos - self.n_seqs:]
outer_inp_sitsot = [
tensor.zeros([grad_steps + 1] +
......@@ -1502,10 +1565,33 @@ class Scan(PureOp):
self.n_mit_sot +
self.n_sit_sot +
n_sitsot_outs)
gradients += [x[::-1] for x in outputs[offset:offset + self.n_seqs]]
for p, (x, t) in enumerate(
zip(outputs[offset:offset + self.n_seqs],
type_outs[offset:offset + self.n_seqs])):
if t == 'undefined':
gradients.append(
grad_undefined(self,
p + 1,
inputs[p + 1],
'Depends on a shared variable'))
elif t == 'disconnected':
gradients.append(DisconnectedType()())
else:
gradients.append(x[::-1])
end = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
gradients += [x[::-1] for x in outputs[:end]]
for p, (x, t) in enumerate(
zip(outputs[:end], type_outs[:end])):
if t == 'undefined':
gradients.append(
grad_undefined(self,
p + 1 + self.n_seqs,
inputs[p + 1 + self.n_seqs],
'Depends on a shared variable'))
elif t == 'disconnected':
gradients.append(DisconnectedType()())
else:
gradients.append(x[::-1])
start = len(gradients)
gradients += [
grad_undefined(self, x + start, inputs[x + start],
......@@ -1519,7 +1605,18 @@ class Scan(PureOp):
begin = end
end = begin + n_sitsot_outs
gradients += [x[-1] for x in outputs[begin:end]]
for p, (x, t) in enumerate(
zip(outputs[begin:end], type_outs[begin:end])):
if t == 'undefined':
gradients.append(
grad_undefined(self,
p + begin + 1,
inputs[p + begin + 1],
'Depends on a shared variable'))
elif t == 'disconnected':
gradients.append(DisconnectedType()())
else:
gradients.append(x[-1])
return gradients
def R_op(self, inputs, eval_points):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论