提交 ee763488 作者: Razvan Pascanu

Rename the gradient variables to more descriptive names

上级 98c1a8b8
...@@ -1270,8 +1270,8 @@ class Scan(PureOp): ...@@ -1270,8 +1270,8 @@ class Scan(PureOp):
gmp = gradient.grad_sources_inputs( gmp = gradient.grad_sources_inputs(
[(y, g_y)], diff_inputs) [(y, g_y)], diff_inputs)
return [gmp.get(p, None) for p in diff_inputs] return [gmp.get(p, None) for p in diff_inputs]
dXt_inps = [None for inp in diff_inputs] dC_dinps_t = [None for inp in diff_inputs]
dXtp1_dXts = [] dC_dXts = []
Xts = [] Xts = []
for idx, Xt in enumerate(diff_outputs): for idx, Xt in enumerate(diff_outputs):
# We are looking for x[t-1] for a given x[t] # We are looking for x[t-1] for a given x[t]
...@@ -1288,36 +1288,36 @@ class Scan(PureOp): ...@@ -1288,36 +1288,36 @@ class Scan(PureOp):
tmp = tensor.grad(Xt.sum(), Xtm1) tmp = tensor.grad(Xt.sum(), Xtm1)
except ValueError: except ValueError:
tmp = Xt tmp = Xt
dXtp1_dXt = safe_new(tmp) dC_dXt = safe_new(tmp)
else: else:
if isinstance(dC_douts[idx].type, DisconnectedType): if isinstance(dC_douts[idx].type, DisconnectedType):
continue continue
dXtp1_dXt = safe_new(dC_douts[idx][0]) dC_dXt = safe_new(dC_douts[idx][0])
dXtp1_dXts.append(dXtp1_dXt) dC_dXts.append(dC_dXt)
_dXt_inps = compute_gradient(Xt, dXtp1_dXt) _dC_dinps_t = compute_gradient(Xt, dC_dXt)
for jdx in xrange(len(_dXt_inps)): for jdx in xrange(len(_dC_dinps_t)):
if dXt_inps[jdx] is None: if dC_dinps_t[jdx] is None:
dXt_inps[jdx] = _dXt_inps[jdx] dC_dinps_t[jdx] = _dC_dinps_t[jdx]
elif _dXt_inps[jdx]: elif _dC_dinps_t[jdx]:
dXt_inps[jdx] += _dXt_inps[jdx] dC_dinps_t[jdx] += _dC_dinps_t[jdx]
# mask inputs that get no gradients # mask inputs that get no gradients
for dx in xrange(len(dXt_inps)): for dx in xrange(len(dC_dinps_t)):
if not dXt_inps[dx]: if not dC_dinps_t[dx]:
dXt_inps[dx] = tensor.zeros_like(diff_inputs[dx]) dC_dinps_t[dx] = tensor.zeros_like(diff_inputs[dx])
else: else:
for Xt, Xt_placeholder in zip( for Xt, Xt_placeholder in zip(
diff_outputs[self.n_mit_mot_outs:], diff_outputs[self.n_mit_mot_outs:],
Xts): Xts):
tmp = forced_replace( tmp = forced_replace(
dXt_inps[dx], dC_dinps_t[dx],
Xt, Xt,
Xt_placeholder) Xt_placeholder)
dXt_inps[dx] = tmp dC_dinps_t[dx] = tmp
# construct dX_dtm1 # construct dX_dtm1
dXt_dXtm1s = [x.type() for x in dXt_inps[self.n_seqs:]] dC_dXtm1s = [x.type() for x in dC_dinps_t[self.n_seqs:]]
for dx, dXt_dXtm1 in enumerate(dXt_dXtm1s): for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
dXt_inps[dx+self.n_seqs] += dXt_dXtm1 dC_dinps_t[dx+self.n_seqs] += dC_dXtm1
# Construct scan op # Construct scan op
# Seqs # Seqs
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]] outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
...@@ -1352,7 +1352,7 @@ class Scan(PureOp): ...@@ -1352,7 +1352,7 @@ class Scan(PureOp):
inner_inp_seqs += self.inner_mitmot(self_inputs) inner_inp_seqs += self.inner_mitmot(self_inputs)
inner_inp_seqs += self.inner_mitsot(self_inputs) inner_inp_seqs += self.inner_mitsot(self_inputs)
inner_inp_seqs += self.inner_sitsot(self_inputs) inner_inp_seqs += self.inner_sitsot(self_inputs)
inner_inp_seqs += self.inner_nitsot_outs(dXtp1_dXts) inner_inp_seqs += self.inner_nitsot_outs(dC_dXts)
inner_inp_seqs += Xts inner_inp_seqs += Xts
# mitmot # mitmot
outer_inp_mitmot = [] outer_inp_mitmot = []
...@@ -1371,14 +1371,14 @@ class Scan(PureOp): ...@@ -1371,14 +1371,14 @@ class Scan(PureOp):
mitmot_inp_taps.append([]) mitmot_inp_taps.append([])
mitmot_out_taps.append([]) mitmot_out_taps.append([])
for jdx in xrange(len(self.mit_mot_out_slices[idx])): for jdx in xrange(len(self.mit_mot_out_slices[idx])):
inner_inp_mitmot.append(dXtp1_dXts[out_pos]) inner_inp_mitmot.append(dC_dXts[out_pos])
mitmot_inp_taps[idx].append(-self.mit_mot_out_slices[idx][jdx]) mitmot_inp_taps[idx].append(-self.mit_mot_out_slices[idx][jdx])
n_mitmot_inps += 1 n_mitmot_inps += 1
out_pos += 1 out_pos += 1
for jdx in xrange(len(self.tap_array[idx])): for jdx in xrange(len(self.tap_array[idx])):
inner_inp_mitmot.append(dXt_dXtm1s[ins_pos - self.n_seqs]) inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
inner_out_mitmot.append(dXt_inps[ins_pos]) inner_out_mitmot.append(dC_inps[ins_pos])
n_mitmot_inps_ += 1 n_mitmot_inps_ += 1
ins_pos += 1 ins_pos += 1
n_mitmot_outs += 1 n_mitmot_outs += 1
...@@ -1391,17 +1391,17 @@ class Scan(PureOp): ...@@ -1391,17 +1391,17 @@ class Scan(PureOp):
mitmot_out_taps.append([]) mitmot_out_taps.append([])
outer_inp_mitmot.append(dC_douts[idx + offset][::-1]) outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
idx_tap = idx + self.n_mit_mot idx_tap = idx + self.n_mit_mot
inner_inp_mitmot.append(dXtp1_dXts[out_pos]) inner_inp_mitmot.append(dC_dXts[out_pos])
out_pos += 1 out_pos += 1
n_mitmot_inps += 1 n_mitmot_inps += 1
mitmot_inp_taps[idx + offset].append(0) mitmot_inp_taps[idx + offset].append(0)
for jdx in xrange(len(self.tap_array[idx_tap])): for jdx in xrange(len(self.tap_array[idx_tap])):
inner_inp_mitmot.append(dXt_dXtm1s[ins_pos - self.n_seqs]) inner_inp_mitmot.append(dC_dXtm1s[ins_pos - self.n_seqs])
mitmot_inp_taps[idx + offset].append( mitmot_inp_taps[idx + offset].append(
-self.tap_array[idx_tap][jdx]) -self.tap_array[idx_tap][jdx])
mitmot_out_taps[idx].append( mitmot_out_taps[idx].append(
-self.tap_array[idx_tap][jdx]) -self.tap_array[idx_tap][jdx])
inner_out_mitmot.append(dXt_inps[ins_pos]) inner_out_mitmot.append(dC_dinps_t[ins_pos])
n_mitmot_inps += 1 n_mitmot_inps += 1
ins_pos += 1 ins_pos += 1
n_mitmot_outs += 1 n_mitmot_outs += 1
...@@ -1416,10 +1416,10 @@ class Scan(PureOp): ...@@ -1416,10 +1416,10 @@ class Scan(PureOp):
else: else:
outer_inp_mitmot.append( outer_inp_mitmot.append(
tensor.zeros(outs[idx + offset].shape, tensor.zeros(outs[idx + offset].shape,
dtype = dXt_inps[ins_pos].dtype)) dtype = dC_dinps_t[ins_pos].dtype))
inner_out_mitmot.append(dXt_inps[ins_pos]) inner_out_mitmot.append(dC_dinps_t[ins_pos])
inner_inp_mitmot += [dXtp1_dXts[out_pos], inner_inp_mitmot += [dC_dXts[out_pos],
dXt_dXtm1s[ins_pos - self.n_seqs]] dC_dXtm1s[ins_pos - self.n_seqs]]
n_mitmot_outs += 1 n_mitmot_outs += 1
out_pos += 1 out_pos += 1
ins_pos += 1 ins_pos += 1
...@@ -1431,9 +1431,9 @@ class Scan(PureOp): ...@@ -1431,9 +1431,9 @@ class Scan(PureOp):
do_steps = inputs[0] do_steps = inputs[0]
n_nit_sot = self.n_seqs n_nit_sot = self.n_seqs
inner_out_nitsot = dXt_inps[:self.n_seqs] inner_out_nitsot = dC_dinps_t[:self.n_seqs]
inner_out_sitsot = dXt_inps[ins_pos:] inner_out_sitsot = dC_dinps_t[ins_pos:]
inner_inp_sitsot = dXt_dXtm1s[ins_pos - self.n_seqs:] inner_inp_sitsot = dC_dXtm1s[ins_pos - self.n_seqs:]
outer_inp_sitsot = [ outer_inp_sitsot = [
tensor.zeros([do_steps + 1] + tensor.zeros([do_steps + 1] +
[x.shape[i] for i in xrange(x.ndim)], [x.shape[i] for i in xrange(x.ndim)],
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论