提交 5fd729d0 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add helper to build hessian vector product

上级 db1c161e
......@@ -267,6 +267,16 @@ or, making use of the R-operator:
>>> f([4, 4], [2, 2])
array([ 4., 4.])
There is a built-in helper that uses the first method:
>>> x = pt.dvector('x')
>>> v = pt.dvector('v')
>>> y = pt.sum(x ** 2)
>>> Hv = pytensor.gradient.hessian_vector_product(y, x, v)
>>> f = pytensor.function([x, v], Hv)
>>> f([4, 4], [2, 2])
array([ 4., 4.])
Final Pointers
==============
......
......@@ -2050,6 +2050,85 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
return as_list_or_tuple(using_list, using_tuple, hessians)
def hessian_vector_product(cost, wrt, p, **grad_kwargs):
    """Return the expression of the Hessian times a vector p.

    Parameters
    ----------
    cost: Scalar (0-dimensional) variable.
    wrt: Vector (1-dimensional tensor) 'Variable' or list of Vectors
    p: Vector (1-dimensional tensor) 'Variable' or list of Vectors
        Each vector will be used for the hessp with respect to each input variable
    **grad_kwargs:
        Keyword arguments passed to `grad` function.

    Returns
    -------
    Vector or list of Vectors
        The Hessian times p of the `cost` with respect to (elements of) `wrt`.
        A single expression is returned when `wrt` is a single `Variable`.

    Raises
    ------
    ValueError
        If `wrt` and `p` do not contain the same number of entries
        (they are paired with ``zip(..., strict=True)``).

    Notes
    -----
    This function uses backward autodiff twice to obtain the desired expression.
    You may want to manually build the equivalent expression by combining backward
    followed by forward (if all Ops support it) autodiff.
    See {ref}`docs/_tutcomputinggrads#Hessian-times-a-Vector` for how to do this.

    Examples
    --------

    .. testcode::

        import numpy as np
        from scipy.optimize import minimize
        from pytensor import function
        from pytensor.tensor import vector
        from pytensor.gradient import grad, hessian_vector_product

        x = vector('x')
        p = vector('p')

        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
        rosen_jac = grad(rosen, x)
        rosen_hessp = hessian_vector_product(rosen, x, p)

        rosen_fn = function([x], rosen)
        rosen_jac_fn = function([x], rosen_jac)
        rosen_hessp_fn = function([x, p], rosen_hessp)
        x0 = np.array([1.3, 0.7, 0.8, 1.9, 1.2])
        res = minimize(
            rosen_fn,
            x0,
            method="Newton-CG",
            jac=rosen_jac_fn,
            hessp=rosen_hessp_fn,
            options={"xtol": 1e-8},
        )
        print(res.x)

    .. testoutput::

        [1.         1.         1.         0.99999999 0.99999999]

    """
    # Normalize both arguments to sequences so single and multiple inputs
    # share one code path.
    wrt_list = wrt if isinstance(wrt, Sequence) else [wrt]
    p_list = p if isinstance(p, Sequence) else [p]

    # First backward pass: gradient of the cost w.r.t. every input.
    grad_wrt_list = grad(cost, wrt=wrt_list, **grad_kwargs)

    # Scalar sum of the gradient/vector products; differentiating it once
    # more w.r.t. the inputs gives the Hessian-vector products.
    hessian_cost = pytensor.tensor.add(
        *[
            (grad_wrt * p_i).sum()
            for grad_wrt, p_i in zip(grad_wrt_list, p_list, strict=True)
        ]
    )

    # Second backward pass.
    Hp_list = grad(hessian_cost, wrt=wrt_list, **grad_kwargs)

    # Mirror the input container: a lone Variable yields a lone expression.
    if isinstance(wrt, Variable):
        return Hp_list[0]
    return Hp_list
def _is_zero(x):
"""
Returns 'yes', 'no', or 'maybe' indicating whether x
......
import numpy as np
import pytest
from scipy.optimize import rosen_hess_prod
import pytensor
import pytensor.tensor.basic as ptb
......@@ -20,6 +21,7 @@ from pytensor.gradient import (
grad_scale,
grad_undefined,
hessian,
hessian_vector_product,
jacobian,
subgraph_grad,
zero_grad,
......@@ -1079,3 +1081,40 @@ def test_jacobian_disconnected_inputs():
func_s = pytensor.function([s2], jacobian_s)
val = np.array(1.0).astype(pytensor.config.floatX)
assert np.allclose(func_s(val), np.zeros(1))
class TestHessianVectorProdudoct:
    """Tests for `hessian_vector_product`."""

    def test_rosen(self):
        """Compare against SciPy's `rosen_hess_prod` on the Rosenbrock function."""
        x = vector("x", dtype="float64")
        p = vector("p", dtype="float64")

        rosen = (100 * (x[1:] - x[:-1] ** 2) ** 2 + (1 - x[:-1]) ** 2).sum()
        rosen_hess_prod_pt = hessian_vector_product(rosen, wrt=x, p=p)

        x_test = 0.1 * np.arange(9)
        p_test = 0.5 * np.arange(9)

        actual = rosen_hess_prod_pt.eval({x: x_test, p: p_test})
        expected = rosen_hess_prod(x_test, p_test)
        np.testing.assert_allclose(actual, expected)

    def test_multiple_wrt(self):
        """Check that each `wrt` entry is paired with its own vector in `p`."""
        x = vector("x", dtype="float64")
        y = vector("y", dtype="float64")
        p_x = vector("p_x", dtype="float64")
        p_y = vector("p_y", dtype="float64")

        cost = (x**2 - y**2).sum()
        hessp_x, hessp_y = hessian_vector_product(cost, wrt=[x, y], p=[p_x, p_y])
        hessp_fn = pytensor.function([x, y, p_x, p_y], [hessp_x, hessp_y])

        # x, y don't matter
        unused_point = np.full((3,), np.nan)
        hessp_x_eval, hessp_y_eval = hessp_fn(
            x=unused_point,
            y=np.full((3,), np.nan),
            p_x=[1, 2, 3],
            p_y=[3, 2, 1],
        )

        np.testing.assert_allclose(hessp_x_eval, [2, 4, 6])
        np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论