提交 03970991 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #6156 from botev/master

Added options for `disconnected_outputs` to Rop.
......@@ -160,7 +160,8 @@ disconnected_type = DisconnectedType()
########################
def Rop(f, wrt, eval_points):
def Rop(f, wrt, eval_points, disconnected_outputs="raise",
return_disconnected="zero"):
"""
Computes the R operation on `f` wrt to `wrt` evaluated at points given
in `eval_points`. Mathematically this stands for the jacobian of `f` wrt
......@@ -174,6 +175,22 @@ def Rop(f, wrt, eval_points):
described by `f`
:type eval_points: Variable or list of Variables
evaluation points for each of the variables in `wrt`
:type disconnected_outputs: str
Defines the behaviour if some of the variables in `f`
have no dependency on any of the variables in `wrt` (or if
all links are non-differentiable). The possible values are:
- 'ignore': considers that the gradient on these parameters is zero.
- 'warn': consider the gradient zero, and print a warning.
- 'raise': raise DisconnectedInputError.
:type return_disconnected : {'zero', 'None', 'Disconnected'}
- 'zero' : If f[i] is disconnected, return value i will be
f[i].zeros_like()
- 'None' : If f[i] is disconnected, return value i will be
None
- 'Disconnected' : returns variables of type DisconnectedType
:rtype: :class:`~theano.gof.Variable` or list/tuple of Variables depending on type of f
:return: symbolic expression such that
R_op[i] = sum_j ( d f[i] / d wrt[j]) eval_point[j]
......@@ -296,9 +313,33 @@ def Rop(f, wrt, eval_points):
for out in f:
if out in wrt:
rval.append(eval_points[wrt.index(out)])
elif seen_nodes[out.owner][out.owner.outputs.index(out)] is None:
raise ValueError(('The function is not differentiable with '
'respect to the provided inputs !'))
elif seen_nodes.get(out.owner, None) is None or \
seen_nodes[out.owner][out.owner.outputs.index(out)] is None:
message = ("Rop method was asked to compute the gradient "
"with respect to a variable that is not part of "
"the computational graph of variables in wrt, or is "
"used only by a non-differentiable operator: %s" % out)
if disconnected_outputs == 'ignore':
pass
elif disconnected_outputs == 'warn':
warnings.warn(message, stacklevel=2)
elif disconnected_outputs == 'raise':
message = utils.get_variable_trace_string(out)
raise DisconnectedInputError(message)
else:
raise ValueError("Invalid value for keyword "
"'disconnected_inputs', valid values are "
"'ignore', 'warn' and 'raise'.")
if return_disconnected.lower() == "zero":
rval.append(tensor.zeros_like(out))
elif return_disconnected.lower() == "none":
rval.append(None)
elif return_disconnected.lower() == "disconnected":
rval.append(disconnected_type())
else:
raise ValueError("Invalid value for keyword "
"'return_disconnected', valid values are "
"'zero', 'None' and 'Disconnected'.")
else:
rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论