提交 f244d4f2 · 作者：carriepl

Add null_gradients arg to gradient.grad()

上级 763586fa
...@@ -359,7 +359,8 @@ def Lop(f, wrt, eval_points, consider_constant=None, ...@@ -359,7 +359,8 @@ def Lop(f, wrt, eval_points, consider_constant=None,
def grad(cost, wrt, consider_constant=None, def grad(cost, wrt, consider_constant=None,
disconnected_inputs='raise', add_names=True, disconnected_inputs='raise', add_names=True,
known_grads=None, return_disconnected='zero'): known_grads=None, return_disconnected='zero',
null_gradients='raise'):
""" """
Return symbolic gradients for one or more variables with respect to some Return symbolic gradients for one or more variables with respect to some
cost. cost.
...@@ -406,6 +407,12 @@ def grad(cost, wrt, consider_constant=None, ...@@ -406,6 +407,12 @@ def grad(cost, wrt, consider_constant=None,
None None
- 'Disconnected' : returns variables of type DisconnectedType - 'Disconnected' : returns variables of type DisconnectedType
:type null_gradients: string
:param null_gradients: Defines the behaviour if some of the variables in
``wrt`` have a null gradient. The possible values are:
- 'raise' : raise a NullTypeGradError exception
- 'return' : return the null gradients
:rtype: variable or list/tuple of Variables (matching `wrt`) :rtype: variable or list/tuple of Variables (matching `wrt`)
:return: symbolic expression of gradient of `cost` with respect to each :return: symbolic expression of gradient of `cost` with respect to each
...@@ -547,6 +554,12 @@ def grad(cost, wrt, consider_constant=None, ...@@ -547,6 +554,12 @@ def grad(cost, wrt, consider_constant=None,
grad_dict, wrt, cost_name) grad_dict, wrt, cost_name)
for i in xrange(len(rval)): for i in xrange(len(rval)):
if isinstance(rval[i].type, NullType):
if null_gradients == 'raise':
raise NullTypeGradError("tensor.grad encountered a NaN. " +
rval[i].type.why_null)
else:
assert null_gradients == 'return'
if isinstance(rval[i].type, DisconnectedType): if isinstance(rval[i].type, DisconnectedType):
handle_disconnected(rval[i]) handle_disconnected(rval[i])
if return_disconnected == 'zero': if return_disconnected == 'zero':
...@@ -1115,6 +1128,19 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1115,6 +1128,19 @@ def _populate_grad_dict(var_to_app_to_idx,
# we won't be able to post-process out the Nones if it does that # we won't be able to post-process out the Nones if it does that
input_grads = list(input_grads) input_grads = list(input_grads)
# Need to propagate the NullType gradients; if an input grad is
# not disconnected and the corresponding input is connected
# to at least one output whose gradient is NullType then the input
# grad should be NullType.
op_conn_pattern = _node_to_pattern(node)
for inp_idx in range(len(input_grads)):
for out_idx in range(len(ograd_is_nan)):
if (ograd_is_nan[out_idx] and
op_conn_pattern[inp_idx][out_idx] and
not isinstance(input_grads[inp_idx].type,
DisconnectedType)):
input_grads[inp_idx] = output_grads[out_idx]
# Do type checking on the result # Do type checking on the result
# List of bools indicating if each input only has integer outputs # List of bools indicating if each input only has integer outputs
...@@ -1238,6 +1264,7 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1238,6 +1264,7 @@ def _populate_grad_dict(var_to_app_to_idx,
if var not in grad_dict: if var not in grad_dict:
# If var is not in grad_dict already, we must compute it # If var is not in grad_dict already, we must compute it
if var in var_to_app_to_idx: if var in var_to_app_to_idx:
null_terms = []
terms = [] terms = []
node_to_idx = var_to_app_to_idx[var] node_to_idx = var_to_app_to_idx[var]
for node in node_to_idx: for node in node_to_idx:
...@@ -1252,9 +1279,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1252,9 +1279,8 @@ def _populate_grad_dict(var_to_app_to_idx,
type(term))) type(term)))
if isinstance(term.type, NullType): if isinstance(term.type, NullType):
raise NullTypeGradError("tensor.grad " null_terms.append(term)
"encountered a NaN. " + continue
term.type.why_null)
# Don't try to sum up DisconnectedType placeholders # Don't try to sum up DisconnectedType placeholders
if isinstance(term.type, DisconnectedType): if isinstance(term.type, DisconnectedType):
...@@ -1269,7 +1295,11 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1269,7 +1295,11 @@ def _populate_grad_dict(var_to_app_to_idx,
terms.append(term) terms.append(term)
# Add up the terms to get the total gradient on this variable # Add up the terms to get the total gradient on this variable
if len(terms) > 0: if len(null_terms) > 0:
# At least one term is a NullType : the total gradient
# will also be a NullType
grad_dict[var] = null_terms[0]
elif len(terms) > 0:
# the next line is like sum(terms) but doesn't add an # the next line is like sum(terms) but doesn't add an
# extraneous TensorConstant(0) # extraneous TensorConstant(0)
grad_dict[var] = reduce(lambda x, y: x + y, terms) grad_dict[var] = reduce(lambda x, y: x + y, terms)
......
...@@ -1936,7 +1936,7 @@ class Scan(PureOp): ...@@ -1936,7 +1936,7 @@ class Scan(PureOp):
dc_dxts_idx += 1 dc_dxts_idx += 1
else: else:
if isinstance(dC_douts[i].type, DisconnectedType): if isinstance(dC_douts[i].type, DisconnectedType):
c_dxts_idx += 1 dc_dxts_idx += 1
continue continue
else: else:
if diff_outputs[i] in known_grads: if diff_outputs[i] in known_grads:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论