提交 62ce0a4a authored 作者: Aleksandar Botev's avatar Aleksandar Botev

Added options for `disconnected_outputs` to Rop.

上级 83debd73
...@@ -160,7 +160,8 @@ disconnected_type = DisconnectedType() ...@@ -160,7 +160,8 @@ disconnected_type = DisconnectedType()
######################## ########################
def Rop(f, wrt, eval_points): def Rop(f, wrt, eval_points, disconnected_outputs="raise",
return_disconnected="zero"):
""" """
Computes the R operation on `f` wrt to `wrt` evaluated at points given Computes the R operation on `f` wrt to `wrt` evaluated at points given
in `eval_points`. Mathematically this stands for the jacobian of `f` wrt in `eval_points`. Mathematically this stands for the jacobian of `f` wrt
...@@ -174,6 +175,22 @@ def Rop(f, wrt, eval_points): ...@@ -174,6 +175,22 @@ def Rop(f, wrt, eval_points):
described by `f` described by `f`
:type eval_points: Variable or list of Variables :type eval_points: Variable or list of Variables
evaluation points for each of the variables in `wrt` evaluation points for each of the variables in `wrt`
:type disconnected_outputs: str
Defines the behaviour if some of the variables in `f`
have no dependency on any of the variables in `wrt` (or if
all links are non-differentiable). The possible values are:
- 'ignore': consider that the gradient on these outputs is zero.
- 'warn': consider the gradient zero, and print a warning.
- 'raise': raise DisconnectedInputError.
:type return_disconnected: {'zero', 'None', 'Disconnected'}
- 'zero' : If f[i] is disconnected, return value i will be
tensor.zeros_like(f[i])
- 'None' : If f[i] is disconnected, return value i will be
None
- 'Disconnected' : returns variables of type DisconnectedType
:rtype: :class:`~theano.gof.Variable` or list/tuple of Variables depending on type of f :rtype: :class:`~theano.gof.Variable` or list/tuple of Variables depending on type of f
:return: symbolic expression such that :return: symbolic expression such that
R_op[i] = sum_j ( d f[i] / d wrt[j]) eval_point[j] R_op[i] = sum_j ( d f[i] / d wrt[j]) eval_point[j]
...@@ -296,9 +313,33 @@ def Rop(f, wrt, eval_points): ...@@ -296,9 +313,33 @@ def Rop(f, wrt, eval_points):
for out in f: for out in f:
if out in wrt: if out in wrt:
rval.append(eval_points[wrt.index(out)]) rval.append(eval_points[wrt.index(out)])
elif seen_nodes[out.owner][out.owner.outputs.index(out)] is None: elif seen_nodes.get(out.owner, None) is None or \
raise ValueError(('The function is not differentiable with ' seen_nodes[out.owner][out.owner.outputs.index(out)] is None:
'respect to the provided inputs !')) message = ("Rop method was asked to compute the gradient "
"with respect to a variable that is not part of "
"the computational graph of variables in wrt, or is "
"used only by a non-differentiable operator: %s" % out)
if disconnected_outputs == 'ignore':
pass
elif disconnected_outputs == 'warn':
warnings.warn(message, stacklevel=2)
elif disconnected_outputs == 'raise':
message = utils.get_variable_trace_string(out)
raise DisconnectedInputError(message)
else:
raise ValueError("Invalid value for keyword "
"'disconnected_inputs', valid values are "
"'ignore', 'warn' and 'raise'.")
if return_disconnected.lower() == "zero":
rval.append(tensor.zeros_like(out))
elif return_disconnected.lower() == "none":
rval.append(None)
elif return_disconnected.lower() == "disconnected":
rval.append(disconnected_type())
else:
raise ValueError("Invalid value for keyword "
"'return_disconnected', valid values are "
"'zero', 'None' and 'Disconnected'.")
else: else:
rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)]) rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论