提交 8d1809de authored 作者: carriepl's avatar carriepl

Speedup scan grad() method

上级 8a615a8d
...@@ -361,7 +361,8 @@ def Lop(f, wrt, eval_points, consider_constant=None, ...@@ -361,7 +361,8 @@ def Lop(f, wrt, eval_points, consider_constant=None,
def grad(cost, wrt, consider_constant=None, def grad(cost, wrt, consider_constant=None,
disconnected_inputs='raise', add_names=True, disconnected_inputs='raise', add_names=True,
known_grads=None, return_disconnected='zero'): known_grads=None, return_disconnected='zero',
null_gradients='raise'):
""" """
Return symbolic gradients for one or more variables with respect to some Return symbolic gradients for one or more variables with respect to some
cost. cost.
...@@ -408,6 +409,12 @@ def grad(cost, wrt, consider_constant=None, ...@@ -408,6 +409,12 @@ def grad(cost, wrt, consider_constant=None,
None None
- 'Disconnected' : returns variables of type DisconnectedType - 'Disconnected' : returns variables of type DisconnectedType
:type null_gradients: string
:param null_gradients: Defines the behaviour if some of the variables in
``wrt`` have a null gradient. The possible values are:
- 'raise' : raise a NullTypeGradError exception
- 'return' : return the null gradients
:rtype: variable or list/tuple of Variables (matching `wrt`) :rtype: variable or list/tuple of Variables (matching `wrt`)
:return: symbolic expression of gradient of `cost` with respect to each :return: symbolic expression of gradient of `cost` with respect to each
...@@ -560,6 +567,12 @@ def grad(cost, wrt, consider_constant=None, ...@@ -560,6 +567,12 @@ def grad(cost, wrt, consider_constant=None,
grad_dict, wrt, cost_name) grad_dict, wrt, cost_name)
for i in xrange(len(rval)): for i in xrange(len(rval)):
if isinstance(rval[i].type, NullType):
if null_gradients == 'raise':
raise NullTypeGradError("tensor.grad encountered a NaN. " +
rval[i].type.why_null)
else:
assert null_gradients == 'return'
if isinstance(rval[i].type, DisconnectedType): if isinstance(rval[i].type, DisconnectedType):
handle_disconnected(rval[i]) handle_disconnected(rval[i])
if return_disconnected == 'zero': if return_disconnected == 'zero':
...@@ -1128,6 +1141,19 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1128,6 +1141,19 @@ def _populate_grad_dict(var_to_app_to_idx,
# we won't be able to post-process out the Nones if it does that # we won't be able to post-process out the Nones if it does that
input_grads = list(input_grads) input_grads = list(input_grads)
# Need to propagate the NullType gradients; if an input grad is
# not disconnected and the corresponding input is connected
# to at least one output whose gradient is NullType then the input
# grad should be NullType.
op_conn_pattern = _node_to_pattern(node)
for inp_idx in range(len(input_grads)):
for out_idx in range(len(ograd_is_nan)):
if (ograd_is_nan[out_idx] and
op_conn_pattern[inp_idx][out_idx] and
not isinstance(input_grads[inp_idx].type,
DisconnectedType)):
input_grads[inp_idx] = output_grads[out_idx]
# Do type checking on the result # Do type checking on the result
# List of bools indicating if each input only has integer outputs # List of bools indicating if each input only has integer outputs
...@@ -1251,6 +1277,7 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1251,6 +1277,7 @@ def _populate_grad_dict(var_to_app_to_idx,
if var not in grad_dict: if var not in grad_dict:
# If var is not in grad_dict already, we must compute it # If var is not in grad_dict already, we must compute it
if var in var_to_app_to_idx: if var in var_to_app_to_idx:
null_terms = []
terms = [] terms = []
node_to_idx = var_to_app_to_idx[var] node_to_idx = var_to_app_to_idx[var]
for node in node_to_idx: for node in node_to_idx:
...@@ -1265,9 +1292,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1265,9 +1292,8 @@ def _populate_grad_dict(var_to_app_to_idx,
type(term))) type(term)))
if isinstance(term.type, NullType): if isinstance(term.type, NullType):
raise NullTypeGradError("tensor.grad " null_terms.append(term)
"encountered a NaN. " + continue
term.type.why_null)
# Don't try to sum up DisconnectedType placeholders # Don't try to sum up DisconnectedType placeholders
if isinstance(term.type, DisconnectedType): if isinstance(term.type, DisconnectedType):
...@@ -1282,7 +1308,11 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1282,7 +1308,11 @@ def _populate_grad_dict(var_to_app_to_idx,
terms.append(term) terms.append(term)
# Add up the terms to get the total gradient on this variable # Add up the terms to get the total gradient on this variable
if len(terms) > 0: if len(null_terms) > 0:
# At least one term is a NullType : the total gradient
# will also be a NullType
grad_dict[var] = null_terms[0]
elif len(terms) > 0:
# the next line is like sum(terms) but doesn't add an # the next line is like sum(terms) but doesn't add an
# extraneous TensorConstant(0) # extraneous TensorConstant(0)
grad_dict[var] = reduce(lambda x, y: x + y, terms) grad_dict[var] = reduce(lambda x, y: x + y, terms)
......
...@@ -1938,39 +1938,45 @@ class Scan(PureOp): ...@@ -1938,39 +1938,45 @@ class Scan(PureOp):
iidx -= len(taps) iidx -= len(taps)
return oidx + iidx return oidx + iidx
def compute_gradient(y, g_y): def compute_all_gradients(known_grads):
if 'int' in str(g_y.dtype): y_s = known_grads.keys()
raise TypeError("Gradients may never be integers but g_y " g_y_s = known_grads.values()
"has type " + str(g_y.type))
for g_y in g_y_s:
odx = get_out_idx(self_outputs.index(y)) if 'int' in str(g_y.dtype):
wrt = [x for x in theano.gof.graph.inputs([y]) raise TypeError("Gradients may never be integers but g_y "
if (x in diff_inputs) and "has type " + str(g_y.type))
(connection_pattern[
get_inp_idx(self_inputs.index(x))][odx])] out_indices = [get_out_idx(self_outputs.index(y)) for y in y_s]
connected_inputs = [i for i in range(len(scan_node.inputs)) if
any([connection_pattern[i][odx] for odx in out_indices])]
wrt = [x for x in theano.gof.graph.inputs(y_s) if
(x in diff_inputs) and
get_inp_idx(self_inputs.index(x)) in connected_inputs]
gmp = OrderedDict() gmp = OrderedDict()
for x in wrt: # Required in case there is a pair of variables X and Y, with X
try: # used to compute Y, for both of which there is an external
gmp[x] = gradient.grad( # gradient signal. Without this, the total gradient signal on X
# will be the external gradient signal known_grads[X]. With this,
# it will be the sum of the external gradient signal and the
# gradient obtained by propagating Y's external gradient signal
# to X.
known_grads = dict([(k.copy(), v) for (k, v) in known_grads.items()])
grads = gradient.grad(
cost=None, cost=None,
known_grads={y: g_y}, known_grads=known_grads,
wrt=x, wrt=wrt,
consider_constant=wrt, consider_constant=wrt,
disconnected_inputs='ignore', disconnected_inputs='ignore',
return_disconnected='None') return_disconnected='None',
except gradient.NullTypeGradError as e: null_gradients='return')
# The gradient wrt that particular input is undefined.
# This is not necessarily an issue, because maybe that for i in range(len(wrt)):
# particular input is not in the path between the gmp[wrt[i]] = grads[i]
# "cost" and "wrt" of the external, initial call to grad().
# We simply return a Null gradient, forwarding the message.
gmp[x] = NullType((
"This variable is Null because the grad method on the "
"inner graph of the Scan node %s returned Null for "
"the corresponding inner input variable. The original "
"message was: %s"
% (str(self), exc_message(e))))()
rval = [gmp.get(p, None) for p in diff_inputs] rval = [gmp.get(p, None) for p in diff_inputs]
return rval return rval
...@@ -2026,20 +2032,29 @@ class Scan(PureOp): ...@@ -2026,20 +2032,29 @@ class Scan(PureOp):
continue continue
dC_dXt = safe_new(dC_douts[idx][0]) dC_dXt = safe_new(dC_douts[idx][0])
dC_dXts.append(dC_dXt) dC_dXts.append(dC_dXt)
_dC_dinps_t = compute_gradient(Xt, dC_dXt)
for jdx in xrange(len(_dC_dinps_t)):
if dC_dinps_t[jdx] is None: known_grads = {}
dC_dinps_t[jdx] = _dC_dinps_t[jdx] dc_dxts_idx = 0
elif isinstance(dC_dinps_t[jdx].type, NullType): for i in range(len(diff_outputs)):
# The accumulated gradient is undefined if i < idx_nitsot_start or i >= idx_nitsot_end:
pass if diff_outputs[i] in known_grads:
elif _dC_dinps_t[jdx]: known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
if isinstance(_dC_dinps_t[jdx].type, NullType): else:
# The accumulated gradient is defined, but the new known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
# term is undefined. The whole thing has to be undefined. dc_dxts_idx += 1
dC_dinps_t[jdx] = _dC_dinps_t[jdx] else:
if isinstance(dC_douts[i].type, DisconnectedType):
dc_dxts_idx += 1
continue
else:
if diff_outputs[i] in known_grads:
known_grads[diff_outputs[i]] += dC_dXts[dc_dxts_idx]
else: else:
dC_dinps_t[jdx] += _dC_dinps_t[jdx] known_grads[diff_outputs[i]] = dC_dXts[dc_dxts_idx]
dc_dxts_idx += 1
dC_dinps_t = compute_all_gradients(known_grads)
# mask inputs that get no gradients # mask inputs that get no gradients
for dx in xrange(len(dC_dinps_t)): for dx in xrange(len(dC_dinps_t)):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论