提交 8a615a8d authored 作者: carriepl's avatar carriepl

Merge pull request #3460 from Theano/revert-3435-scan_grad_speedup

Revert "Scan grad speedup"
...@@ -361,8 +361,7 @@ def Lop(f, wrt, eval_points, consider_constant=None, ...@@ -361,8 +361,7 @@ def Lop(f, wrt, eval_points, consider_constant=None,
def grad(cost, wrt, consider_constant=None, def grad(cost, wrt, consider_constant=None,
disconnected_inputs='raise', add_names=True, disconnected_inputs='raise', add_names=True,
known_grads=None, return_disconnected='zero', known_grads=None, return_disconnected='zero'):
null_gradients='raise'):
""" """
Return symbolic gradients for one or more variables with respect to some Return symbolic gradients for one or more variables with respect to some
cost. cost.
...@@ -409,12 +408,6 @@ def grad(cost, wrt, consider_constant=None, ...@@ -409,12 +408,6 @@ def grad(cost, wrt, consider_constant=None,
None None
- 'Disconnected' : returns variables of type DisconnectedType - 'Disconnected' : returns variables of type DisconnectedType
:type null_gradients: string
:param null_gradients: Defines the behaviour if some of the variables in
``wrt`` have a null gradient. The possibles values are :
- 'raise' : raise a NullTypeGradError exception
- 'return' : return the null gradients
:rtype: variable or list/tuple of Variables (matching `wrt`) :rtype: variable or list/tuple of Variables (matching `wrt`)
:return: symbolic expression of gradient of `cost` with respect to each :return: symbolic expression of gradient of `cost` with respect to each
...@@ -567,12 +560,6 @@ def grad(cost, wrt, consider_constant=None, ...@@ -567,12 +560,6 @@ def grad(cost, wrt, consider_constant=None,
grad_dict, wrt, cost_name) grad_dict, wrt, cost_name)
for i in xrange(len(rval)): for i in xrange(len(rval)):
if isinstance(rval[i].type, NullType):
if null_gradients == 'raise':
raise NullTypeGradError("tensor.grad encountered a NaN. " +
rval[i].type.why_null)
else:
assert null_gradients == 'return'
if isinstance(rval[i].type, DisconnectedType): if isinstance(rval[i].type, DisconnectedType):
handle_disconnected(rval[i]) handle_disconnected(rval[i])
if return_disconnected == 'zero': if return_disconnected == 'zero':
...@@ -1141,19 +1128,6 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1141,19 +1128,6 @@ def _populate_grad_dict(var_to_app_to_idx,
# we won't be able to post-process out the Nones if it does that # we won't be able to post-process out the Nones if it does that
input_grads = list(input_grads) input_grads = list(input_grads)
# Need to propagate the NullType gradients; if an input grad is
# not disconnected and the corresponding input is connected
# to at least one output whose gradient is NullType then the input
# grad should be NullType.
op_conn_pattern = _node_to_pattern(node)
for inp_idx in range(len(input_grads)):
for out_idx in range(len(ograd_is_nan)):
if (ograd_is_nan[out_idx] and
op_conn_pattern[inp_idx][out_idx] and
not isinstance(input_grads[inp_idx].type,
DisconnectedType)):
input_grads[inp_idx] = output_grads[out_idx]
# Do type checking on the result # Do type checking on the result
# List of bools indicating if each input only has integer outputs # List of bools indicating if each input only has integer outputs
...@@ -1277,7 +1251,6 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1277,7 +1251,6 @@ def _populate_grad_dict(var_to_app_to_idx,
if var not in grad_dict: if var not in grad_dict:
# If var is not in grad_dict already, we must compute it # If var is not in grad_dict already, we must compute it
if var in var_to_app_to_idx: if var in var_to_app_to_idx:
null_terms = []
terms = [] terms = []
node_to_idx = var_to_app_to_idx[var] node_to_idx = var_to_app_to_idx[var]
for node in node_to_idx: for node in node_to_idx:
...@@ -1292,8 +1265,9 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1292,8 +1265,9 @@ def _populate_grad_dict(var_to_app_to_idx,
type(term))) type(term)))
if isinstance(term.type, NullType): if isinstance(term.type, NullType):
null_terms.append(term) raise NullTypeGradError("tensor.grad "
continue "encountered a NaN. " +
term.type.why_null)
# Don't try to sum up DisconnectedType placeholders # Don't try to sum up DisconnectedType placeholders
if isinstance(term.type, DisconnectedType): if isinstance(term.type, DisconnectedType):
...@@ -1308,11 +1282,7 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1308,11 +1282,7 @@ def _populate_grad_dict(var_to_app_to_idx,
terms.append(term) terms.append(term)
# Add up the terms to get the total gradient on this variable # Add up the terms to get the total gradient on this variable
if len(null_terms) > 0: if len(terms) > 0:
# At least one term is a NullType : the total gradient
# will also be a NullType
grad_dict[var] = null_terms[0]
elif len(terms) > 0:
# the next line is like sum(terms) but doesn't add an # the next line is like sum(terms) but doesn't add an
# extraneous TensorConstant(0) # extraneous TensorConstant(0)
grad_dict[var] = reduce(lambda x, y: x + y, terms) grad_dict[var] = reduce(lambda x, y: x + y, terms)
......
...@@ -1938,45 +1938,39 @@ class Scan(PureOp): ...@@ -1938,45 +1938,39 @@ class Scan(PureOp):
iidx -= len(taps) iidx -= len(taps)
return oidx + iidx return oidx + iidx
def compute_all_gradients(known_grads): def compute_gradient(y, g_y):
y_s = known_grads.keys()
g_y_s = known_grads.values()
for g_y in g_y_s:
if 'int' in str(g_y.dtype): if 'int' in str(g_y.dtype):
raise TypeError("Gradients may never be integers but g_y " raise TypeError("Gradients may never be integers but g_y "
"has type " + str(g_y.type)) "has type " + str(g_y.type))
out_indices = [get_out_idx(self_outputs.index(y)) for y in y_s] odx = get_out_idx(self_outputs.index(y))
wrt = [x for x in theano.gof.graph.inputs([y])
connected_inputs = [i for i in range(len(scan_node.inputs)) if if (x in diff_inputs) and
any([connection_pattern[i][odx] for odx in out_indices])] (connection_pattern[
get_inp_idx(self_inputs.index(x))][odx])]
wrt = [x for x in theano.gof.graph.inputs(y_s) if
(x in diff_inputs) and
get_inp_idx(self_inputs.index(x)) in connected_inputs]
gmp = OrderedDict() gmp = OrderedDict()
# Required in case there is a pair of variables X and Y, with X for x in wrt:
# used to compute Y, for both of which there is an external try:
# gradient signal. Without this, the total gradient signal on X gmp[x] = gradient.grad(
# will be the external gradient signalknown_grads[X]. With this,
# it will be the sum of the external gradient signal and the
# gradient obtained by propagating Y's external gradient signal
# to X.
known_grads = dict([(k.copy(), v) for (k, v) in known_grads.items()])
grads = gradient.grad(
cost=None, cost=None,
known_grads=known_grads, known_grads={y: g_y},
wrt=wrt, wrt=x,
consider_constant=wrt, consider_constant=wrt,
disconnected_inputs='ignore', disconnected_inputs='ignore',
return_disconnected='None', return_disconnected='None')
null_gradients='return') except gradient.NullTypeGradError as e:
# The gradient wrt that particular input is undefined.
for i in range(len(wrt)): # This is not necessarily an issue, because maybe that
gmp[wrt[i]] = grads[i] # particular input is not in the path between the
# "cost" and "wrt" of the external, initial call to grad().
# We simply return a Null gradient, forwarding the message.
gmp[x] = NullType((
"This variable is Null because the grad method on the "
"inner graph of the Scan node %s returned Null for "
"the corresponding inner input variable. The original "
"message was: %s"
% (str(self), exc_message(e))))()
rval = [gmp.get(p, None) for p in diff_inputs] rval = [gmp.get(p, None) for p in diff_inputs]
return rval return rval
...@@ -2032,29 +2026,20 @@ class Scan(PureOp): ...@@ -2032,29 +2026,20 @@ class Scan(PureOp):
continue continue
dC_dXt = safe_new(dC_douts[idx][0]) dC_dXt = safe_new(dC_douts[idx][0])
dC_dXts.append(dC_dXt) dC_dXts.append(dC_dXt)
_dC_dinps_t = compute_gradient(Xt, dC_dXt)
for jdx in xrange(len(_dC_dinps_t)):
known_grads = {} if dC_dinps_t[jdx] is None:
dc_dxts_idx = 0 dC_dinps_t[jdx] = _dC_dinps_t[jdx]
for i in range(len(diff_outputs)): elif isinstance(dC_dinps_t[jdx].type, NullType):
if i < idx_nitsot_start or i >= idx_nitsot_end: # The accumulated gradient is undefined
if diff_outputs[i] in known_grads: pass
known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx] elif _dC_dinps_t[jdx]:
else: if isinstance(_dC_dinps_t[jdx].type, NullType):
known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] # The accumulated gradient is defined, but the new
dc_dxts_idx += 1 # term is undefined. The whole thing has to be undefined.
else: dC_dinps_t[jdx] = _dC_dinps_t[jdx]
if isinstance(dC_douts[i].type, DisconnectedType):
dc_dxts_idx += 1
continue
else:
if diff_outputs[i] in known_grads:
known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
else: else:
known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx] dC_dinps_t[jdx] += _dC_dinps_t[jdx]
dc_dxts_idx += 1
dC_dinps_t = compute_all_gradients(known_grads)
# mask inputs that get no gradients # mask inputs that get no gradients
for dx in xrange(len(dC_dinps_t)): for dx in xrange(len(dC_dinps_t)):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论