Unverified 提交 3876e73d authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Correct bad imports in optimize.py and expose it via `pytensor.tensor.__init__` (#1464)

* Expose `pt.optimize` * Clean up imports in optimize.py * Add example notebook for `optimize.root` * Small updates * Add doc file * Link to optimize docs in tensor index * Add docs to user-facing functions. * Move `scipy.optimize` imports into `perform` methods * Use global import strategy * Remove props and overload __str__ * rerun example notebook
上级 b56bff5b
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -31,3 +31,4 @@ symbolic expressions using calls that look just like numpy calls, such as
math_opt
basic_opt
functional
optimize
========================================================
:mod:`tensor.optimize` -- Symbolic Optimization Routines
========================================================
.. module:: tensor.optimize
:platform: Unix, Windows
:synopsis: Symbolic Optimization Routines
.. moduleauthor:: LISA, PyMC Developers, PyTensor Developers
.. automodule:: pytensor.tensor.optimize
:members:
......@@ -118,6 +118,7 @@ import pytensor.tensor._linalg
from pytensor.tensor import linalg
from pytensor.tensor import special
from pytensor.tensor import signal
from pytensor.tensor import optimize
# For backward compatibility
from pytensor.tensor import nlinalg
......
......@@ -4,17 +4,14 @@ from copy import copy
from typing import cast
import numpy as np
from scipy.optimize import minimize as scipy_minimize
from scipy.optimize import minimize_scalar as scipy_minimize_scalar
from scipy.optimize import root as scipy_root
from scipy.optimize import root_scalar as scipy_root_scalar
import pytensor.scalar as ps
from pytensor import Variable, function, graph_replace
from pytensor.compile.function import function
from pytensor.gradient import grad, hessian, jacobian
from pytensor.graph import Apply, Constant, FunctionGraph
from pytensor.graph.basic import ancestors, truncated_graph_inputs
from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
from pytensor.graph.replace import graph_replace
from pytensor.tensor.basic import (
atleast_2d,
concatenate,
......@@ -24,7 +21,12 @@ from pytensor.tensor.basic import (
)
from pytensor.tensor.math import dot
from pytensor.tensor.slinalg import solve
from pytensor.tensor.variable import TensorVariable
from pytensor.tensor.variable import TensorVariable, Variable
# scipy.optimize can be slow to import and will not be used by most users, so it is
# imported lazily inside the optimization Ops' `perform` methods instead of at module import time.
optimize = None
_log = logging.getLogger(__name__)
......@@ -352,8 +354,6 @@ def implict_optimization_grads(
class MinimizeScalarOp(ScipyScalarWrapperOp):
__props__ = ("method",)
def __init__(
self,
x: Variable,
......@@ -377,7 +377,14 @@ class MinimizeScalarOp(ScipyScalarWrapperOp):
self._fn = None
self._fn_wrapped = None
def __str__(self):
return f"{self.__class__.__name__}(method={self.method})"
def perform(self, node, inputs, outputs):
global optimize
if optimize is None:
import scipy.optimize as optimize
f = self.fn_wrapped
f.clear_cache()
......@@ -385,7 +392,7 @@ class MinimizeScalarOp(ScipyScalarWrapperOp):
# the args of the objective function), but it is not used in the optimization.
x0, *args = inputs
res = scipy_minimize_scalar(
res = optimize.minimize_scalar(
fun=f.value,
args=tuple(args),
method=self.method,
......@@ -426,6 +433,27 @@ def minimize_scalar(
):
"""
Minimize a scalar objective function using scipy.optimize.minimize_scalar.
Parameters
----------
objective : TensorVariable
The objective function to minimize. This should be a PyTensor variable representing a scalar value.
x : TensorVariable
The variable with respect to which the objective function is minimized. It must be a scalar and an
input to the computational graph of `objective`.
method : str, optional
The optimization method to use. Default is "brent". See `scipy.optimize.minimize_scalar` for other options.
optimizer_kwargs : dict, optional
Additional keyword arguments to pass to `scipy.optimize.minimize_scalar`.
Returns
-------
solution: TensorVariable
Value of `x` that minimizes `objective(x, *args)`. If the success flag is False, this will be the
final state returned by the minimization routine, not necessarily a minimum.
success : TensorVariable
Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
value, based on the requested convergence criteria.
"""
args = _find_optimization_parameters(objective, x)
......@@ -438,12 +466,14 @@ def minimize_scalar(
optimizer_kwargs=optimizer_kwargs,
)
return minimize_scalar_op(x, *args)
solution, success = cast(
tuple[TensorVariable, TensorVariable], minimize_scalar_op(x, *args)
)
return solution, success
class MinimizeOp(ScipyWrapperOp):
__props__ = ("method", "jac", "hess", "hessp")
class MinimizeOp(ScipyWrapperOp):
def __init__(
self,
x: Variable,
......@@ -487,11 +517,24 @@ class MinimizeOp(ScipyWrapperOp):
self._fn = None
self._fn_wrapped = None
def __str__(self):
str_args = ", ".join(
[
f"{arg}={getattr(self, arg)}"
for arg in ["method", "jac", "hess", "hessp"]
]
)
return f"{self.__class__.__name__}({str_args})"
def perform(self, node, inputs, outputs):
global optimize
if optimize is None:
import scipy.optimize as optimize
f = self.fn_wrapped
x0, *args = inputs
res = scipy_minimize(
res = optimize.minimize(
fun=f.value_and_grad if self.jac else f.value,
jac=self.jac,
x0=x0,
......@@ -538,7 +581,7 @@ def minimize(
jac: bool = True,
hess: bool = False,
optimizer_kwargs: dict | None = None,
):
) -> tuple[TensorVariable, TensorVariable]:
"""
Minimize a scalar objective function using scipy.optimize.minimize.
......@@ -563,9 +606,13 @@ def minimize(
Returns
-------
TensorVariable
The optimized value of x that minimizes the objective function.
solution: TensorVariable
The optimized value of the vector of inputs `x` that minimizes `objective(x, *args)`. If the success flag
is False, this will be the final state of the minimization routine, but not necessarily a minimum.
success: TensorVariable
Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
value, based on the requested convergence criteria.
"""
args = _find_optimization_parameters(objective, x)
......@@ -579,12 +626,14 @@ def minimize(
optimizer_kwargs=optimizer_kwargs,
)
return minimize_op(x, *args)
solution, success = cast(
tuple[TensorVariable, TensorVariable], minimize_op(x, *args)
)
return solution, success
class RootScalarOp(ScipyScalarWrapperOp):
__props__ = ("method", "jac", "hess")
def __init__(
self,
variables,
......@@ -633,14 +682,24 @@ class RootScalarOp(ScipyScalarWrapperOp):
self._fn = None
self._fn_wrapped = None
def __str__(self):
str_args = ", ".join(
[f"{arg}={getattr(self, arg)}" for arg in ["method", "jac", "hess"]]
)
return f"{self.__class__.__name__}({str_args})"
def perform(self, node, inputs, outputs):
global optimize
if optimize is None:
import scipy.optimize as optimize
f = self.fn_wrapped
f.clear_cache()
# f.copy_x = True
variables, *args = inputs
res = scipy_root_scalar(
res = optimize.root_scalar(
f=f.value,
fprime=f.grad if self.jac else None,
fprime2=f.hess if self.hess else None,
......@@ -676,19 +735,48 @@ class RootScalarOp(ScipyScalarWrapperOp):
def root_scalar(
equation: TensorVariable,
variables: TensorVariable,
variable: TensorVariable,
method: str = "secant",
jac: bool = False,
hess: bool = False,
optimizer_kwargs: dict | None = None,
):
) -> tuple[TensorVariable, TensorVariable]:
"""
Find roots of a scalar equation using scipy.optimize.root_scalar.
Parameters
----------
equation : TensorVariable
The equation for which to find roots. This should be a PyTensor variable representing a single equation in one
variable. The function will find `variable` such that `equation(variable, *args) = 0`.
variable : TensorVariable
The variable with respect to which the equation is solved. It must be a scalar and an input to the
computational graph of `equation`.
method : str, optional
The root-finding method to use. Default is "secant". See `scipy.optimize.root_scalar` for other options.
jac : bool, optional
Whether to compute and use the first derivative of the equation with respect to `variable`.
Default is False. Some methods require this.
hess : bool, optional
Whether to compute and use the second derivative of the equation with respect to `variable`.
Default is False. Some methods require this.
optimizer_kwargs : dict, optional
Additional keyword arguments to pass to `scipy.optimize.root_scalar`.
Returns
-------
solution: TensorVariable
The final state of the root-finding routine. When `success` is True, this is the value of `variable` that
causes `equation` to evaluate to zero. Otherwise it is the final state returned by the root-finding
routine, but not necessarily a root.
success: TensorVariable
Boolean indicating whether the root-finding was successful. If True, the solution is a root of the equation.
"""
args = _find_optimization_parameters(equation, variables)
args = _find_optimization_parameters(equation, variable)
root_scalar_op = RootScalarOp(
variables,
variable,
*args,
equation=equation,
method=method,
......@@ -697,7 +785,11 @@ def root_scalar(
optimizer_kwargs=optimizer_kwargs,
)
return root_scalar_op(variables, *args)
solution, success = cast(
tuple[TensorVariable, TensorVariable], root_scalar_op(variable, *args)
)
return solution, success
class RootOp(ScipyWrapperOp):
......@@ -734,6 +826,12 @@ class RootOp(ScipyWrapperOp):
self._fn = None
self._fn_wrapped = None
def __str__(self):
str_args = ", ".join(
[f"{arg}={getattr(self, arg)}" for arg in ["method", "jac"]]
)
return f"{self.__class__.__name__}({str_args})"
def build_fn(self):
outputs = self.inner_outputs
variables, *args = self.inner_inputs
......@@ -761,13 +859,17 @@ class RootOp(ScipyWrapperOp):
self._fn_wrapped = LRUCache1(fn)
def perform(self, node, inputs, outputs):
global optimize
if optimize is None:
import scipy.optimize as optimize
f = self.fn_wrapped
f.clear_cache()
f.copy_x = True
variables, *args = inputs
res = scipy_root(
res = optimize.root(
fun=f,
jac=self.jac,
x0=variables,
......@@ -815,8 +917,36 @@ def root(
method: str = "hybr",
jac: bool = True,
optimizer_kwargs: dict | None = None,
):
"""Find roots of a system of equations using scipy.optimize.root."""
) -> tuple[TensorVariable, TensorVariable]:
"""
Find roots of a system of equations using scipy.optimize.root.
Parameters
----------
equations : TensorVariable
The system of equations for which to find roots. This should be a PyTensor variable representing a
vector (or scalar) value. The function will find `variables` such that `equations(variables, *args) = 0`.
variables : TensorVariable
The variable(s) with respect to which the system of equations is solved. It must be an input to the
computational graph of `equations` and have the same number of dimensions as `equations`.
method : str, optional
The root-finding method to use. Default is "hybr". See `scipy.optimize.root` for other options.
jac : bool, optional
Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
Default is True. Most methods require this.
optimizer_kwargs : dict, optional
Additional keyword arguments to pass to `scipy.optimize.root`.
Returns
-------
solution: TensorVariable
The final state of the root-finding routine. When `success` is True, this is the value of `variables` that
causes all `equations` to evaluate to zero. Otherwise it is the final state returned by the root-finding
routine, but not necessarily a root.
success: TensorVariable
Boolean indicating whether the root-finding was successful. If True, the solution is a root of the system of equations.
"""
args = _find_optimization_parameters(equations, variables)
......@@ -829,7 +959,11 @@ def root(
optimizer_kwargs=optimizer_kwargs,
)
return root_op(variables, *args)
solution, success = cast(
tuple[TensorVariable, TensorVariable], root_op(variables, *args)
)
return solution, success
__all__ = ["minimize_scalar", "minimize", "root_scalar", "root"]
......@@ -58,6 +58,7 @@ folder_title_map = {
"introduction": "Introduction",
"rewrites": "Graph Rewriting",
"scan": "Looping in Pytensor",
"optimize": "Optimization in Pytensor",
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论