Unverified commit ee107cba, authored by Michal Novomestsky, committed by GitHub

Use vectorized jacobian in Minimize Op (#1582)

* added identity as alias for tensor_copy and defined No-Op for TensorFromScalar
* refactor: jacobian should use tensorize
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* removed redundant pprint
* refactor: added vectorize=True to all jacobians
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Add option to vectorize jacobian in minimize/root
* pre-commit

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: jessegrabowski <jessegrabowski@gmail.com>
Parent c932ffbd
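For context: the change threads a single boolean from `minimize`/`root` down to `pytensor.gradient.jacobian`, which already exposes a `vectorize` flag. A minimal sketch of the two jacobian modes being toggled (variable names are illustrative):

import pytensor.tensor as pt
from pytensor.gradient import jacobian

x = pt.vector("x")
y = pt.stack([x.sum() ** 2, (x**2).sum()])

# Default: jacobian rows are computed one at a time inside a Scan
# (sequential, low memory).
J_scan = jacobian(y, x)

# vectorize=True: the per-row graph is batched (vmap-style), typically
# faster but with a larger memory footprint.
J_vmap = jacobian(y, x, vectorize=True)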
@@ -664,6 +664,11 @@ class TensorFromScalar(COp):
 tensor_from_scalar = TensorFromScalar()

+@_vectorize_node.register(TensorFromScalar)
+def vectorize_tensor_from_scalar(op, node, batch_x):
+    return identity(batch_x).owner

 class ScalarFromTensor(COp):
     __props__ = ()
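For context on the new registration: when a graph containing `TensorFromScalar` is vectorized, the batched input is already a tensor, so the op reduces to a copy. A minimal sketch of what the handler enables (variable names are illustrative):

import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor.basic import tensor_from_scalar

s = ps.float64("s")
t = tensor_from_scalar(s)  # 0-d tensor built from a scalar variable

# Replacing the scalar input with a batched (vector) input: with the
# registration above, the vectorized node is simply identity(batch).
batch = pt.vector("batch")
t_batched = vectorize_graph(t, replace={s: batch})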
@@ -2046,6 +2051,7 @@ def register_transfer(fn):
     """Create a duplicate of `a` (with duplicated storage)"""

 tensor_copy = Elemwise(ps.identity)
 pprint.assign(tensor_copy, printing.IgnorePrinter())
+identity = tensor_copy

 class Default(Op):
@@ -4603,6 +4609,7 @@ __all__ = [
     "matrix_transpose",
     "default",
     "tensor_copy",
+    "identity",
     "transfer",
     "alloc",
     "identity_like",
@@ -7,7 +7,7 @@ import numpy as np
 import pytensor.scalar as ps
 from pytensor.compile.function import function
-from pytensor.gradient import grad, hessian, jacobian
+from pytensor.gradient import grad, jacobian
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
@@ -484,6 +484,7 @@ class MinimizeOp(ScipyWrapperOp):
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
+        use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
         if not cast(TensorVariable, objective).ndim == 0:
@@ -496,6 +497,7 @@ class MinimizeOp(ScipyWrapperOp):
             )

         self.fgraph = FunctionGraph([x, *args], [objective])
+        self.use_vectorized_jac = use_vectorized_jac

         if jac:
             grad_wrt_x = cast(
@@ -505,7 +507,12 @@ class MinimizeOp(ScipyWrapperOp):
         if hess:
             hess_wrt_x = cast(
-                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+                Variable,
+                jacobian(
+                    self.fgraph.outputs[-1],
+                    self.fgraph.inputs[0],
+                    vectorize=use_vectorized_jac,
+                ),
             )
             self.fgraph.add_output(hess_wrt_x)
@@ -561,7 +568,10 @@ class MinimizeOp(ScipyWrapperOp):
         implicit_f = grad(inner_fx, inner_x)

         df_dx, *df_dtheta_columns = jacobian(
-            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
+            implicit_f,
+            [inner_x, *inner_args],
+            disconnected_inputs="ignore",
+            vectorize=self.use_vectorized_jac,
         )
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
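Aside on the hunk above: `L_op` implements implicit-function-theorem gradients. At the optimum x*, implicit_f = df/dx = 0, so dx*/dtheta = -(d²f/dx²)⁻¹ · d²f/(dx dtheta). The single `jacobian` call over `[inner_x, *inner_args]` produces both blocks (df_dx and df_dtheta) at once, and it is exactly this jacobian that `use_vectorized_jac` switches between a Scan-based and a vectorized graph.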
@@ -581,6 +591,7 @@ def minimize(
     method: str = "BFGS",
     jac: bool = True,
     hess: bool = False,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -590,18 +601,21 @@ def minimize(
     ----------
     objective : TensorVariable
         The objective function to minimize. This should be a pytensor variable representing a scalar value.
-    x: TensorVariable
+    x : TensorVariable
         The variable with respect to which the objective function is minimized. It must be an input to the
         computational graph of `objective`.
-    method: str, optional
+    method : str, optional
         The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
-    jac: bool, optional
-        Whether to compute and use the gradient of teh objective function with respect to x for optimization.
+    jac : bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x for optimization.
         Default is True.
+    hess: bool, optional
+        Whether to compute and use the Hessian of the objective function with respect to x for optimization.
+        Default is False. Note that some methods require this, while others do not support it.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian (and/or hessian) matrix. If False, a
+        scan will be used instead. This comes down to a memory/compute trade-off. Vectorized graphs can be faster,
+        but use more memory. Default is False.
     optimizer_kwargs
         Additional keyword arguments to pass to scipy.optimize.minimize
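A usage sketch of the new flag (objective, shapes, and values are illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.vector("x")
a = pt.vector("a")
objective = ((x - a) ** 2).sum()

# The jacobian graph inside MinimizeOp is now built with vmap instead of scan.
x_star, success = minimize(objective, x, use_vectorized_jac=True)

fn = pytensor.function([x, a], x_star)
fn(np.zeros(3), np.array([1.0, 2.0, 3.0]))  # -> approximately [1., 2., 3.]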
@@ -624,6 +638,7 @@ def minimize(
         method=method,
         jac=jac,
         hess=hess,
+        use_vectorized_jac=use_vectorized_jac,
         optimizer_kwargs=optimizer_kwargs,
     )
@@ -804,6 +819,7 @@ class RootOp(ScipyWrapperOp):
         method: str = "hybr",
         jac: bool = True,
         optimizer_kwargs: dict | None = None,
+        use_vectorized_jac: bool = False,
     ):
         if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
             raise ValueError(
@@ -817,7 +833,11 @@ class RootOp(ScipyWrapperOp):
         self.fgraph = FunctionGraph([variables, *args], [equations])

         if jac:
-            jac_wrt_x = jacobian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            jac_wrt_x = jacobian(
+                self.fgraph.outputs[0],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
+            )
             self.fgraph.add_output(atleast_2d(jac_wrt_x))

         self.jac = jac
@@ -897,8 +917,14 @@ class RootOp(ScipyWrapperOp):
         inner_x, *inner_args = self.fgraph.inputs
         inner_fx = self.fgraph.outputs[0]

-        df_dx = jacobian(inner_fx, inner_x) if not self.jac else self.fgraph.outputs[1]
-        df_dtheta_columns = jacobian(inner_fx, inner_args, disconnected_inputs="ignore")
+        df_dx = (
+            jacobian(inner_fx, inner_x, vectorize=True)
+            if not self.jac
+            else self.fgraph.outputs[1]
+        )
+        df_dtheta_columns = jacobian(
+            inner_fx, inner_args, disconnected_inputs="ignore", vectorize=True
+        )
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
@@ -917,6 +943,7 @@ def root(
     variables: TensorVariable,
     method: str = "hybr",
     jac: bool = True,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -935,6 +962,10 @@ def root(
     jac : bool, optional
         Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
         Default is True. Most methods require this.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian matrix. If False, a scan will be used instead.
+        This comes down to a memory/compute trade-off. Vectorized graphs can be faster, but use more memory.
+        Default is False.
     optimizer_kwargs : dict, optional
         Additional keyword arguments to pass to `scipy.optimize.root`.
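And the corresponding sketch for `root` (the 2-equation system is illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import root

v = pt.vector("v")
# Simple system: v0^2 + v1^2 = 1 and v0 = v1
equations = pt.stack([v[0] ** 2 + v[1] ** 2 - 1, v[0] - v[1]])

solution, success = root(equations, v, use_vectorized_jac=True)

fn = pytensor.function([v], solution)
fn(np.array([0.5, 0.5]))  # -> approximately [0.7071, 0.7071]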
@@ -958,6 +989,7 @@ def root(
         method=method,
         jac=jac,
         optimizer_kwargs=optimizer_kwargs,
+        use_vectorized_jac=use_vectorized_jac,
     )

     solution, success = cast(