提交 763586fa authored 作者: carriepl's avatar carriepl

Add method compute_all_gradients

上级 9bdef90d
......@@ -1552,7 +1552,7 @@ class Scan(PureOp):
return connection_pattern
def get_oinp_iinp_iout_oout_mappings(self):
"""
"""
Compute and return dictionary mappings between the inputs and
outputs of the inner function and the inputs and outputs of the Scan
node in the outer graph.
......@@ -1783,6 +1783,42 @@ class Scan(PureOp):
iidx -= len(taps)
return oidx + iidx
def compute_all_gradients(known_grads):
    """
    Compute, in a single call to ``gradient.grad``, the gradients of the
    inner-graph outputs listed in ``known_grads`` with respect to the
    differentiable inner inputs.

    Parameters
    ----------
    known_grads : dict
        Maps each inner output variable to the gradient variable already
        known for it. Every key must be an element of ``self_outputs``.

    Returns
    -------
    list
        One entry per element of ``diff_inputs``, in the same order. The
        entry is whatever ``gradient.grad`` returned for that input when
        the input was differentiated against, and ``None`` otherwise.

    Raises
    ------
    TypeError
        If the dtype of any supplied gradient contains ``'int'``.
    """
    outputs = known_grads.keys()

    # Refuse integer-typed gradients before building anything.
    for out_grad in known_grads.values():
        if 'int' in str(out_grad.dtype):
            raise TypeError("Gradients may never be integers but g_y "
                            "has type " + str(out_grad.type))

    # Translate every inner output into its index via get_out_idx.
    outer_out_positions = []
    for out in outputs:
        outer_out_positions.append(get_out_idx(self_outputs.index(out)))

    # Keep the positions of the scan node inputs that the connection
    # pattern links to at least one of those outputs.
    linked_inputs = []
    for inp_pos in range(len(scan_node.inputs)):
        if any(connection_pattern[inp_pos][out_pos]
               for out_pos in outer_out_positions):
            linked_inputs.append(inp_pos)

    # Differentiate only with respect to graph inputs of the outputs that
    # are differentiable and map to one of the linked positions.
    targets = []
    for var in theano.gof.graph.inputs(outputs):
        if var not in diff_inputs:
            continue
        if get_inp_idx(self_inputs.index(var)) in linked_inputs:
            targets.append(var)

    # Each key is replaced by `key * 1`, i.e. a new expression built from
    # the original variable.
    # NOTE(review): presumably this keeps the original variables free to
    # receive gradient propagated from other outputs - confirm.
    known_grads = dict((key * 1, value)
                       for key, value in known_grads.items())

    grads = gradient.grad(
        cost=None,
        known_grads=known_grads,
        wrt=targets,
        consider_constant=targets,
        disconnected_inputs='ignore',
        return_disconnected='None',
        null_gradients='return')

    # Index the results by the variable they were computed for.
    grad_by_input = OrderedDict()
    for pos, var in enumerate(targets):
        grad_by_input[var] = grads[pos]

    return [grad_by_input.get(inp) for inp in diff_inputs]
def compute_gradient(y, g_y):
if 'int' in str(g_y.dtype):
raise TypeError("Gradients may never be integers but g_y "
......@@ -1871,6 +1907,7 @@ class Scan(PureOp):
continue
dC_dXt = safe_new(dC_douts[idx][0])
dC_dXts.append(dC_dXt)
"""
_dC_dinps_t = compute_gradient(Xt, dC_dXt)
for jdx in xrange(len(_dC_dinps_t)):
if dC_dinps_t[jdx] is None:
......@@ -1885,6 +1922,31 @@ class Scan(PureOp):
dC_dinps_t[jdx] = _dC_dinps_t[jdx]
else:
dC_dinps_t[jdx] += _dC_dinps_t[jdx]
"""
# Build the mapping {inner output -> gradient placeholder} that is handed to
# compute_all_gradients. When the same variable appears several times in
# diff_outputs, its placeholders are summed.
known_grads = {}
dc_dxts_idx = 0  # position of the next unused entry of dC_dXts
for i in range(len(diff_outputs)):
    in_nitsot_range = idx_nitsot_start <= i < idx_nitsot_end
    if in_nitsot_range and isinstance(dC_douts[i].type, DisconnectedType):
        # Skip without consuming a placeholder. The previous code ran
        # `c_dxts_idx += 1` here, which raised NameError because that name
        # is never defined (the counter is `dc_dxts_idx`).
        # NOTE(review): the increment is dropped rather than renamed,
        # assuming dC_dXts holds no entry for these outputs - the loop
        # that fills dC_dXts appears to `continue` before appending in
        # this same situation. Confirm against that loop.
        continue

    out_var = diff_outputs[i]
    if out_var in known_grads:
        known_grads[out_var] += dC_dXts[dc_dxts_idx]
    else:
        known_grads[out_var] = dC_dXts[dc_dxts_idx]
    dc_dxts_idx += 1

dC_dinps_t = compute_all_gradients(known_grads)
# mask inputs that get no gradients
for dx in xrange(len(dC_dinps_t)):
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论