提交 94f34a8e authored 作者: Razvan Pascanu's avatar Razvan Pascanu

function to compute hessian

上级 2a658f43
...@@ -770,3 +770,59 @@ def jacobian(expression, wrt, consider_constant=None, warn_type=False, ...@@ -770,3 +770,59 @@ def jacobian(expression, wrt, consider_constant=None, warn_type=False,
if not use_list: if not use_list:
jacobs = jacobs[0] jacobs = jacobs[0]
return jacobs return jacobs
def hessian(cost, wrt, consider_constant=None, warn_type=False,
disconnected_inputs='raise'):
"""
:type cost: Scalar (0-dimensional) `Variable`
:type wrt: 'Variable' or list of `Variables`s
:param consider_constant: a list of expressions not to backpropagate
through
:param warn_type: a value of True will cause warnings to be logged for any
Op that emits a gradient that does not match its input type.
:type disconnected_inputs: string
:param disconnected_inputs: Defines the behaviour if some of the variables
in ``wrt`` are not part of the computational graph computing ``cost``
(or if all links are non-differentiable). The possible values are:
- 'ignore': considers that the gradient on these parameters is zero.
- 'warn': consider the gradient zero, and print a warning.
- 'raise': raise an exception.
:return: either a instance of `Variable` or list/tuple of `Variable`s
(depending upon `wrt`). If an element of `wrt` is not
differentiable with respect to the output, then a zero
variable is returned. The return value is of same type
as `wrt`: a list/tuple or TensorVariable in all cases.
"""
# Check inputs have the right format
assert isisntance(cost, TensorVariable), \
"tensor.hessian expects a Tensor Variable as `cost`"
assert cost.ndim == 0, \
"tensor.hessian expects a 0 dimensional variable as `cost`"
if isintance(wrt, (list, tuple)):
use_list = True
wrt = list(wrt)
else:
use_list = False
wrt = [wrt]
hessians = []
for input in wrt:
assert isisntance(cost, TensorVariable), \
"tensor.hessian expects a (list of) Tensor Variable as `wrt`"
assert cost.ndim == 0, \
"tensor.hessian expects a (list of) 1 dimensional variable"\
"as `wrt`"
expr = grad(cost, input)
hess, _ = scan(lambda i, y, x: grad(y[i], x),
sequences=arange(expr.shape[0]),
non_sequences=[expr, input])
hessians.append(hess)
if not use_list:
hessians = hessians[0]
return hessians
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论