提交 8a615a8d authored 作者: carriepl's avatar carriepl

Merge pull request #3460 from Theano/revert-3435-scan_grad_speedup

Revert "Scan grad speedup"
......@@ -361,8 +361,7 @@ def Lop(f, wrt, eval_points, consider_constant=None,
def grad(cost, wrt, consider_constant=None,
disconnected_inputs='raise', add_names=True,
known_grads=None, return_disconnected='zero',
null_gradients='raise'):
known_grads=None, return_disconnected='zero'):
"""
Return symbolic gradients for one or more variables with respect to some
cost.
......@@ -409,12 +408,6 @@ def grad(cost, wrt, consider_constant=None,
None
- 'Disconnected' : returns variables of type DisconnectedType
:type null_gradients: string
:param null_gradients: Defines the behaviour if some of the variables in
``wrt`` have a null gradient. The possibles values are :
- 'raise' : raise a NullTypeGradError exception
- 'return' : return the null gradients
:rtype: variable or list/tuple of Variables (matching `wrt`)
:return: symbolic expression of gradient of `cost` with respect to each
......@@ -567,12 +560,6 @@ def grad(cost, wrt, consider_constant=None,
grad_dict, wrt, cost_name)
for i in xrange(len(rval)):
if isinstance(rval[i].type, NullType):
if null_gradients == 'raise':
raise NullTypeGradError("tensor.grad encountered a NaN. " +
rval[i].type.why_null)
else:
assert null_gradients == 'return'
if isinstance(rval[i].type, DisconnectedType):
handle_disconnected(rval[i])
if return_disconnected == 'zero':
......@@ -1141,19 +1128,6 @@ def _populate_grad_dict(var_to_app_to_idx,
# we won't be able to post-process out the Nones if it does that
input_grads = list(input_grads)
# Need to propagate the NullType gradients; if an input grad is
# not disconnected and the corresponding input is connected
# to at least one output whose gradient is NullType then the input
# grad should be NullType.
op_conn_pattern = _node_to_pattern(node)
for inp_idx in range(len(input_grads)):
for out_idx in range(len(ograd_is_nan)):
if (ograd_is_nan[out_idx] and
op_conn_pattern[inp_idx][out_idx] and
not isinstance(input_grads[inp_idx].type,
DisconnectedType)):
input_grads[inp_idx] = output_grads[out_idx]
# Do type checking on the result
# List of bools indicating if each input only has integer outputs
......@@ -1277,7 +1251,6 @@ def _populate_grad_dict(var_to_app_to_idx,
if var not in grad_dict:
# If var is not in grad_dict already, we must compute it
if var in var_to_app_to_idx:
null_terms = []
terms = []
node_to_idx = var_to_app_to_idx[var]
for node in node_to_idx:
......@@ -1292,8 +1265,9 @@ def _populate_grad_dict(var_to_app_to_idx,
type(term)))
if isinstance(term.type, NullType):
null_terms.append(term)
continue
raise NullTypeGradError("tensor.grad "
"encountered a NaN. " +
term.type.why_null)
# Don't try to sum up DisconnectedType placeholders
if isinstance(term.type, DisconnectedType):
......@@ -1308,11 +1282,7 @@ def _populate_grad_dict(var_to_app_to_idx,
terms.append(term)
# Add up the terms to get the total gradient on this variable
if len(null_terms) > 0:
# At least one term is a NullType : the total gradient
# will also be a NullType
grad_dict[var] = null_terms[0]
elif len(terms) > 0:
if len(terms) > 0:
# the next line is like sum(terms) but doesn't add an
# extraneous TensorConstant(0)
grad_dict[var] = reduce(lambda x, y: x + y, terms)
......
......@@ -1938,45 +1938,39 @@ class Scan(PureOp):
iidx -= len(taps)
return oidx + iidx
def compute_all_gradients(known_grads):
y_s = known_grads.keys()
g_y_s = known_grads.values()
for g_y in g_y_s:
if 'int' in str(g_y.dtype):
raise TypeError("Gradients may never be integers but g_y "
"has type " + str(g_y.type))
out_indices = [get_out_idx(self_outputs.index(y)) for y in y_s]
connected_inputs = [i for i in range(len(scan_node.inputs)) if
any([connection_pattern[i][odx] for odx in out_indices])]
wrt = [x for x in theano.gof.graph.inputs(y_s) if
(x in diff_inputs) and
get_inp_idx(self_inputs.index(x)) in connected_inputs]
def compute_gradient(y, g_y):
if 'int' in str(g_y.dtype):
raise TypeError("Gradients may never be integers but g_y "
"has type " + str(g_y.type))
odx = get_out_idx(self_outputs.index(y))
wrt = [x for x in theano.gof.graph.inputs([y])
if (x in diff_inputs) and
(connection_pattern[
get_inp_idx(self_inputs.index(x))][odx])]
gmp = OrderedDict()
# Required in case there is a pair of variables X and Y, with X
# used to compute Y, for both of which there is an external
# gradient signal. Without this, the total gradient signal on X
# will be the external gradient signalknown_grads[X]. With this,
# it will be the sum of the external gradient signal and the
# gradient obtained by propagating Y's external gradient signal
# to X.
known_grads = dict([(k.copy(), v) for (k, v) in known_grads.items()])
grads = gradient.grad(
for x in wrt:
try:
gmp[x] = gradient.grad(
cost=None,
known_grads=known_grads,
wrt=wrt,
known_grads={y: g_y},
wrt=x,
consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None',
null_gradients='return')
for i in range(len(wrt)):
gmp[wrt[i]] = grads[i]
return_disconnected='None')
except gradient.NullTypeGradError as e:
# The gradient wrt that particular input is undefined.
# This is not necessarily an issue, because maybe that
# particular input is not in the path between the
# "cost" and "wrt" of the external, initial call to grad().
# We simply return a Null gradient, forwarding the message.
gmp[x] = NullType((
"This variable is Null because the grad method on the "
"inner graph of the Scan node %s returned Null for "
"the corresponding inner input variable. The original "
"message was: %s"
% (str(self), exc_message(e))))()
rval = [gmp.get(p, None) for p in diff_inputs]
return rval
......@@ -2032,29 +2026,20 @@ class Scan(PureOp):
continue
dC_dXt = safe_new(dC_douts[idx][0])
dC_dXts.append(dC_dXt)
known_grads = {}
dc_dxts_idx = 0
for i in range(len(diff_outputs)):
if i < idx_nitsot_start or i >= idx_nitsot_end:
if diff_outputs[i] in known_grads:
known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
else:
known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
dc_dxts_idx += 1
else:
if isinstance(dC_douts[i].type, DisconnectedType):
dc_dxts_idx += 1
continue
else:
if diff_outputs[i] in known_grads:
known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
_dC_dinps_t = compute_gradient(Xt, dC_dXt)
for jdx in xrange(len(_dC_dinps_t)):
if dC_dinps_t[jdx] is None:
dC_dinps_t[jdx] = _dC_dinps_t[jdx]
elif isinstance(dC_dinps_t[jdx].type, NullType):
# The accumulated gradient is undefined
pass
elif _dC_dinps_t[jdx]:
if isinstance(_dC_dinps_t[jdx].type, NullType):
# The accumulated gradient is defined, but the new
# term is undefined. The whole thing has to be undefined.
dC_dinps_t[jdx] = _dC_dinps_t[jdx]
else:
known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
dc_dxts_idx += 1
dC_dinps_t = compute_all_gradients(known_grads)
dC_dinps_t[jdx] += _dC_dinps_t[jdx]
# mask inputs that get no gradients
for dx in xrange(len(dC_dinps_t)):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论