Commit 763586fa authored by carriepl

Add method compute_all_gradients

Parent 9bdef90d
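The new helper computes, in a single call, the gradients of all differentiable inner inputs from the gradients already known for the inner outputs, by handing them to `theano.gradient.grad` through its `known_grads` argument. As a rough illustration of that mechanism (not part of this commit; the variable names below are made up), a minimal standalone sketch:

```python
import numpy as np
import theano
import theano.tensor as T

# Minimal sketch of the known_grads mechanism that compute_all_gradients
# relies on: instead of differentiating a scalar cost, we give theano.grad a
# pre-computed gradient for an intermediate variable y and ask for the
# gradient with respect to x.
x = T.vector('x')
y = x ** 2                    # stand-in for an inner output
g_y = T.ones_like(y)          # pretend this gradient arrived from elsewhere
g_x, = theano.grad(cost=None, known_grads={y: g_y}, wrt=[x])

f = theano.function([x], g_x)
print(f(np.asarray([1., 2., 3.], dtype=theano.config.floatX)))  # -> [2. 4. 6.]
```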
@@ -1552,7 +1552,7 @@ class Scan(PureOp):
        return connection_pattern

    def get_oinp_iinp_iout_oout_mappings(self):
        """
        Compute and return dictionary mappings between the inputs and
        outputs of the inner function and the inputs and outputs of the Scan
        node in the outer graph.
@@ -1783,6 +1783,42 @@ class Scan(PureOp):
                    iidx -= len(taps)
            return oidx + iidx
        def compute_all_gradients(known_grads):
            # Nested helper: closure over variables defined earlier in the
            # enclosing grad() method (self_inputs, self_outputs, diff_inputs,
            # connection_pattern, scan_node, get_inp_idx, get_out_idx).
            y_s = known_grads.keys()
            g_y_s = known_grads.values()

            for g_y in g_y_s:
                if 'int' in str(g_y.dtype):
                    raise TypeError("Gradients may never be integers but g_y "
                                    "has type " + str(g_y.type))

            out_indices = [get_out_idx(self_outputs.index(y)) for y in y_s]

            connected_inputs = [i for i in range(len(scan_node.inputs)) if
                                any([connection_pattern[i][odx]
                                     for odx in out_indices])]

            wrt = [x for x in theano.gof.graph.inputs(y_s) if
                   (x in diff_inputs) and
                   get_inp_idx(self_inputs.index(x)) in connected_inputs]

            gmp = OrderedDict()

            # Rebuild known_grads with fresh variable nodes (k * 1) as keys,
            # each equivalent to the original output.
            known_grads = dict([(k * 1, v) for (k, v) in known_grads.items()])

            grads = gradient.grad(
                cost=None,
                known_grads=known_grads,
                wrt=wrt,
                consider_constant=wrt,
                disconnected_inputs='ignore',
                return_disconnected='None',
                null_gradients='return')

            for i in range(len(wrt)):
                gmp[wrt[i]] = grads[i]

            # One entry per differentiable inner input; None where no
            # gradient was computed for that input.
            rval = [gmp.get(p, None) for p in diff_inputs]
            return rval
        def compute_gradient(y, g_y):
            if 'int' in str(g_y.dtype):
                raise TypeError("Gradients may never be integers but g_y "
@@ -1871,6 +1907,7 @@ class Scan(PureOp):
                    continue
            dC_dXt = safe_new(dC_douts[idx][0])
            dC_dXts.append(dC_dXt)
            # The per-output gradient accumulation below is disabled (turned
            # into a bare string literal) and replaced by a single call to
            # compute_all_gradients further down.
            """
            _dC_dinps_t = compute_gradient(Xt, dC_dXt)
            for jdx in xrange(len(_dC_dinps_t)):
                if dC_dinps_t[jdx] is None:
@@ -1885,6 +1922,31 @@ class Scan(PureOp):
                    dC_dinps_t[jdx] = _dC_dinps_t[jdx]
                else:
                    dC_dinps_t[jdx] += _dC_dinps_t[jdx]
"""
known_grads = {}
dc_dxts_idx = 0
for i in range(len(diff_outputs)):
if i < idx_nitsot_start or i >= idx_nitsot_end:
if diff_outputs[i] in known_grads:
known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
else:
known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
dc_dxts_idx += 1
else:
if isinstance(dC_douts[i].type, DisconnectedType):
c_dxts_idx += 1
continue
else:
if diff_outputs[i] in known_grads:
known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
else:
known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
dc_dxts_idx += 1
dC_dinps_t = compute_all_gradients(known_grads)
        # mask inputs that get no gradients
        for dx in xrange(len(dC_dinps_t)):
...
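For context (not part of the commit), the code path above is exercised whenever a gradient is taken through scan. A minimal illustrative example, with made-up variable names:

```python
import numpy as np
import theano
import theano.tensor as T

# Differentiating through scan goes through Scan's grad() method, which is
# where compute_all_gradients is now used internally.
x = T.vector('x')
result, updates = theano.scan(
    fn=lambda x_t, acc: acc + x_t ** 2,      # running sum of squares
    sequences=x,
    outputs_info=np.asarray(0., dtype=theano.config.floatX))
cost = result[-1]
g = T.grad(cost, x)                          # d(sum of squares)/dx = 2 * x

f = theano.function([x], g)
print(f(np.asarray([1., 2., 3.], dtype=theano.config.floatX)))  # -> [2. 4. 6.]
```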