Commit e2cb1e2d authored by Ricardo Vieira, committed by Ricardo Vieira

Optimize: Enforce input types, and stop appeasing mypy

Parent 10d225dd
 import logging
 from collections.abc import Sequence
 from copy import copy
-from typing import cast
 import numpy as np
@@ -126,7 +125,9 @@ class LRUCache1:
         self.hess_calls = 0
-def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):
+def _find_optimization_parameters(
+    objective: TensorVariable, x: TensorVariable
+) -> list[Variable]:
     """
     Find the parameters of the optimization problem that are not the variable `x`.
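
The helper's job is to collect every root variable that feeds the objective other than `x` itself. A minimal sketch of that idea, assuming PyTensor's `graph_inputs` utility and a hypothetical name `find_params_sketch` (the real helper may treat constants and shared variables differently):

```python
from pytensor.graph.basic import Constant, graph_inputs


def find_params_sketch(objective, x):
    # Root variables of the objective's graph, minus `x` itself and constants.
    return [
        var
        for var in graph_inputs([objective])
        if var is not x and not isinstance(var, Constant)
    ]
```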
@@ -140,23 +141,19 @@ def _find_optimization_parameters(objective: TensorVariable, x: TensorVariable):
 def _get_parameter_grads_from_vector(
-    grad_wrt_args_vector: Variable,
-    x_star: Variable,
+    grad_wrt_args_vector: TensorVariable,
+    x_star: TensorVariable,
     args: Sequence[Variable],
-    output_grad: Variable,
-):
+    output_grad: TensorVariable,
+) -> list[TensorVariable]:
     """
     Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
     returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
     """
-    grad_wrt_args_vector = cast(TensorVariable, grad_wrt_args_vector)
-    x_star = cast(TensorVariable, x_star)
     cursor = 0
     grad_wrt_args = []
     for arg in args:
-        arg = cast(TensorVariable, arg)
         arg_shape = arg.shape
         arg_size = arg_shape.prod()
         arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
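
The cursor-based slicing in this loop maps one flat row of gradients back onto each parameter's own shape. The same bookkeeping in plain NumPy, with hypothetical shapes chosen only for illustration:

```python
import numpy as np

# Flat gradient row for two parameters: a (2, 3) matrix and a (4,) vector.
a_shape, b_shape = (2, 3), (4,)
flat = np.arange(1.0, 11.0)[None, :]  # shape (1, 6 + 4)

cursor = 0
grads = []
for shape in (a_shape, b_shape):
    size = int(np.prod(shape))
    # Slice out this parameter's chunk and restore its unraveled shape.
    grads.append(flat[:, cursor : cursor + size].reshape((-1, *shape)))
    cursor += size

assert grads[0].shape == (1, 2, 3) and grads[1].shape == (1, 4)
```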
@@ -268,17 +265,16 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
 def scalar_implict_optimization_grads(
-    inner_fx: Variable,
-    inner_x: Variable,
+    inner_fx: TensorVariable,
+    inner_x: TensorVariable,
     inner_args: Sequence[Variable],
     args: Sequence[Variable],
-    x_star: Variable,
-    output_grad: Variable,
+    x_star: TensorVariable,
+    output_grad: TensorVariable,
     fgraph: FunctionGraph,
 ) -> list[Variable]:
-    df_dx, *df_dthetas = cast(
-        list[Variable],
-        grad(inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"),
+    df_dx, *df_dthetas = grad(
+        inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
     )
     replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
@@ -286,20 +282,20 @@ def scalar_implict_optimization_grads(
     grad_wrt_args = [
         (-df_dtheta_star / df_dx_star) * output_grad
-        for df_dtheta_star in cast(list[TensorVariable], df_dthetas_stars)
+        for df_dtheta_star in df_dthetas_stars
     ]
     return grad_wrt_args
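
The expression `(-df_dtheta_star / df_dx_star) * output_grad` is the scalar implicit function theorem: if `f(x*(θ), θ) = 0` at the solution, then `dx*/dθ = -(∂f/∂θ) / (∂f/∂x)`. A quick numerical check with SciPy, using an example equation chosen for illustration:

```python
import numpy as np
from scipy.optimize import brentq

# f(x, theta) = x**2 - theta has root x*(theta) = sqrt(theta),
# so dx*/dtheta = 1 / (2 * sqrt(theta)).
theta = 2.0
x_star = brentq(lambda x: x**2 - theta, 0.0, 10.0)

df_dx = 2 * x_star  # partial f / partial x at the root
df_dtheta = -1.0    # partial f / partial theta at the root
implicit_grad = -df_dtheta / df_dx

assert np.isclose(implicit_grad, 1 / (2 * np.sqrt(theta)))
```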
 def implict_optimization_grads(
-    df_dx: Variable,
-    df_dtheta_columns: Sequence[Variable],
+    df_dx: TensorVariable,
+    df_dtheta_columns: Sequence[TensorVariable],
     args: Sequence[Variable],
-    x_star: Variable,
-    output_grad: Variable,
+    x_star: TensorVariable,
+    output_grad: TensorVariable,
     fgraph: FunctionGraph,
-):
+) -> list[TensorVariable]:
     r"""
     Compute gradients of an optimization problem with respect to its parameters.
@@ -341,21 +337,15 @@ def implict_optimization_grads(
     fgraph : FunctionGraph
         The function graph that contains the inputs and outputs of the optimization problem.
     """
-    df_dx = cast(TensorVariable, df_dx)
     df_dtheta = concatenate(
-        [
-            atleast_2d(jac_col, left=False)
-            for jac_col in cast(list[TensorVariable], df_dtheta_columns)
-        ],
+        [atleast_2d(jac_col, left=False) for jac_col in df_dtheta_columns],
         axis=-1,
     )
     replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
-    df_dx_star, df_dtheta_star = cast(
-        list[TensorVariable],
-        graph_replace([atleast_2d(df_dx), df_dtheta], replace=replace),
+    df_dx_star, df_dtheta_star = graph_replace(
+        [atleast_2d(df_dx), df_dtheta], replace=replace
     )
     grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
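
`solve(-df_dx_star, df_dtheta_star)` is the vector form of the same theorem: `dx*/dθ = -(∂f/∂x)⁻¹ (∂f/∂θ)`, evaluated at `x*`. A NumPy sketch on a small linear system where the answer is known in closed form:

```python
import numpy as np

# f(x, theta) = A @ x - theta = 0, so x*(theta) = inv(A) @ theta and
# dx*/dtheta = inv(A). Here df/dx = A and df/dtheta = -I.
A = np.array([[3.0, 1.0], [1.0, 2.0]])
df_dx = A
df_dtheta = -np.eye(2)

# solve(-df_dx, df_dtheta) = inv(-A) @ (-I) = inv(A)
dx_dtheta = np.linalg.solve(-df_dx, df_dtheta)
assert np.allclose(dx_dtheta, np.linalg.inv(A))
```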
@@ -369,20 +359,24 @@ def implict_optimization_grads(
 class MinimizeScalarOp(ScipyScalarWrapperOp):
     def __init__(
         self,
-        x: Variable,
+        x: TensorVariable,
         *args: Variable,
-        objective: Variable,
-        method: str = "brent",
+        objective: TensorVariable,
+        method: str,
         optimizer_kwargs: dict | None = None,
     ):
-        if not cast(TensorVariable, x).ndim == 0:
+        if not (isinstance(x, TensorVariable) and x.ndim == 0):
             raise ValueError(
                 "The variable `x` must be a scalar (0-dimensional) tensor for minimize_scalar."
             )
-        if not cast(TensorVariable, objective).ndim == 0:
+        if not (isinstance(objective, TensorVariable) and objective.ndim == 0):
             raise ValueError(
                 "The objective function must be a scalar (0-dimensional) tensor for minimize_scalar."
             )
+        if x not in ancestors([objective]):
+            raise ValueError(
+                "The variable `x` must be an input to the computational graph of the objective function."
+            )
         self.fgraph = FunctionGraph([x, *args], [objective])
         self.method = method
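
The `ancestors` guard added here catches objectives that do not actually depend on `x`. A small demonstration of the condition it tests:

```python
import pytensor.tensor as pt
from pytensor.graph.basic import ancestors

x = pt.scalar("x")
y = pt.scalar("y")
objective = (y - 2.0) ** 2  # note: does not depend on x

# This is exactly the pairing the new guard rejects:
assert x not in ancestors([objective])
```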
@@ -468,7 +462,6 @@ def minimize_scalar(
         Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
         value, based on the requested convergence criteria.
     """
     args = _find_optimization_parameters(objective, x)
     minimize_scalar_op = MinimizeScalarOp(
@@ -479,9 +472,7 @@ def minimize_scalar(
         optimizer_kwargs=optimizer_kwargs,
     )
-    solution, success = cast(
-        tuple[TensorVariable, TensorVariable], minimize_scalar_op(x, *args)
-    )
+    solution, success = minimize_scalar_op(x, *args)
     return solution, success
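
A hypothetical end-to-end usage sketch, assuming the public argument order `minimize_scalar(objective, x, ...)` implied by the call `minimize_scalar_op(x, *args)` above:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize_scalar

x = pt.scalar("x")
a = pt.scalar("a")  # free parameter, discovered by _find_optimization_parameters
objective = (x - a) ** 2

x_star, success = minimize_scalar(objective, x)

# `x` supplies the initial value; `a` is a required input of the compiled function.
fn = pytensor.function([x, a], [x_star, success])
x_opt, converged = fn(0.0, 3.0)  # expect x_opt close to 3.0
```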
@@ -489,17 +480,21 @@ def minimize_scalar(
 class MinimizeOp(ScipyVectorWrapperOp):
     def __init__(
         self,
-        x: Variable,
+        x: TensorVariable,
         *args: Variable,
-        objective: Variable,
-        method: str = "BFGS",
+        objective: TensorVariable,
+        method: str,
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
         use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
-        if not cast(TensorVariable, objective).ndim == 0:
+        if not (isinstance(x, TensorVariable) and x.ndim in (0, 1)):
+            raise ValueError(
+                "The variable `x` must be a scalar or vector (0-or-1-dimensional) tensor for minimize."
+            )
+        if not (isinstance(objective, TensorVariable) and objective.ndim == 0):
             raise ValueError(
                 "The objective function must be a scalar (0-dimensional) tensor for minimize."
             )
@@ -512,19 +507,14 @@ class MinimizeOp(ScipyVectorWrapperOp):
         self.use_vectorized_jac = use_vectorized_jac
         if jac:
-            grad_wrt_x = cast(
-                Variable, grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
-            )
+            grad_wrt_x = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
             self.fgraph.add_output(grad_wrt_x)
         if hess:
-            hess_wrt_x = cast(
-                Variable,
-                jacobian(
-                    self.fgraph.outputs[-1],
-                    self.fgraph.inputs[0],
-                    vectorize=use_vectorized_jac,
-                ),
+            hess_wrt_x = jacobian(
+                self.fgraph.outputs[-1],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
             )
             self.fgraph.add_output(hess_wrt_x)
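
Appending the gradient as an extra `fgraph` output and then taking the `jacobian` of that output is the usual symbolic route to a Hessian: the Jacobian of a gradient is the Hessian. A standalone sketch of the same construction:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import jacobian

x = pt.vector("x")
objective = (x**2).sum()

g = pt.grad(objective, x)  # gradient of the scalar objective: 2 * x
H = jacobian(g, x)         # Jacobian of the gradient = Hessian: 2 * I

fn = pytensor.function([x], [g, H])
grad_val, hess_val = fn([1.0, 2.0])  # grad_val ~ [2., 4.], hess_val ~ 2 * eye(2)
```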
@@ -654,9 +644,7 @@ def minimize(
         optimizer_kwargs=optimizer_kwargs,
     )
-    solution, success = cast(
-        tuple[TensorVariable, TensorVariable], minimize_op(x, *args)
-    )
+    solution, success = minimize_op(x, *args)
     return solution, success
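
As with `minimize_scalar`, a hypothetical usage sketch for the vector case, assuming the argument order `minimize(objective, x, ...)` implied by `minimize_op(x, *args)`:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.vector("x")
b = pt.vector("b")  # free parameter
objective = ((x - b) ** 2).sum()

x_star, success = minimize(objective, x)
fn = pytensor.function([x, b], [x_star, success])

x_opt, converged = fn([0.0, 0.0], [1.0, -2.0])  # expect x_opt close to [1., -2.]
```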
@@ -664,21 +652,23 @@ def minimize(
 class RootScalarOp(ScipyScalarWrapperOp):
     def __init__(
         self,
-        variables,
-        *args,
-        equation,
-        method,
+        variables: TensorVariable,
+        *args: Variable,
+        equation: TensorVariable,
+        method: str,
         jac: bool = False,
         hess: bool = False,
         optimizer_kwargs=None,
     ):
-        if not equation.ndim == 0:
+        if not (isinstance(variables, TensorVariable) and variables.ndim == 0):
+            raise ValueError(
+                "The variable `x` must be a scalar (0-dimensional) tensor for root_scalar."
+            )
+        if not (isinstance(equation, TensorVariable) and equation.ndim == 0):
             raise ValueError(
                 "The equation must be a scalar (0-dimensional) tensor for root_scalar."
             )
-        if not isinstance(variables, Variable) or variables not in ancestors(
-            [equation]
-        ):
+        if variables not in ancestors([equation]):
             raise ValueError(
                 "The variable `variables` must be an input to the computational graph of the equation."
             )
@@ -686,9 +676,7 @@ class RootScalarOp(ScipyScalarWrapperOp):
         self.fgraph = FunctionGraph([variables, *args], [equation])
         if jac:
-            f_prime = cast(
-                Variable, grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
-            )
+            f_prime = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
             self.fgraph.add_output(f_prime)
         if hess:
@@ -697,9 +685,7 @@ class RootScalarOp(ScipyScalarWrapperOp):
                     "Cannot set `hess=True` without `jac=True`. No methods use second derivatives without also"
                     " using first derivatives."
                 )
-            f_double_prime = cast(
-                Variable, grad(self.fgraph.outputs[-1], self.fgraph.inputs[0])
-            )
+            f_double_prime = grad(self.fgraph.outputs[-1], self.fgraph.inputs[0])
             self.fgraph.add_output(f_double_prime)
         self.method = method
@@ -813,9 +799,7 @@ def root_scalar(
         optimizer_kwargs=optimizer_kwargs,
     )
-    solution, success = cast(
-        tuple[TensorVariable, TensorVariable], root_scalar_op(variable, *args)
-    )
+    solution, success = root_scalar_op(variable, *args)
     return solution, success
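
A hypothetical `root_scalar` sketch; the argument order is assumed from `root_scalar_op(variable, *args)`, and the method name is assumed to pass through to `scipy.optimize.root_scalar` (where `"newton"` uses the symbolic `f_prime` added when `jac=True`):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import root_scalar

x = pt.scalar("x")
theta = pt.scalar("theta")
equation = x**2 - theta  # root at x = sqrt(theta)

# Method name assumed for illustration only.
x_star, success = root_scalar(equation, x, method="newton")
fn = pytensor.function([x, theta], [x_star, success])

root_val, converged = fn(1.0, 2.0)  # expect root_val close to sqrt(2)
```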
@@ -825,15 +809,19 @@ class RootOp(ScipyVectorWrapperOp):
     def __init__(
         self,
-        variables: Variable,
+        variables: TensorVariable,
         *args: Variable,
-        equations: Variable,
-        method: str = "hybr",
+        equations: TensorVariable,
+        method: str,
         jac: bool = True,
         optimizer_kwargs: dict | None = None,
         use_vectorized_jac: bool = False,
     ):
-        if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
+        if not isinstance(variables, TensorVariable):
+            raise ValueError("The variable `variables` must be a tensor for root.")
+        if not isinstance(equations, TensorVariable):
+            raise ValueError("The equations must be a tensor for root.")
+        if variables.ndim != equations.ndim:
             raise ValueError(
                 "The variable `variables` must have the same number of dimensions as the equations."
             )
@@ -916,12 +904,8 @@ class RootOp(ScipyVectorWrapperOp):
         outputs[0][0] = res.x.reshape(variables.shape).astype(variables.dtype)
         outputs[1][0] = np.bool_(res.success)
-    def L_op(
-        self,
-        inputs: Sequence[Variable],
-        outputs: Sequence[Variable],
-        output_grads: Sequence[Variable],
-    ) -> list[Variable]:
+    def L_op(self, inputs, outputs, output_grads):
+        # TODO: Handle disconnected inputs
         x, *args = inputs
         x_star, _ = outputs
         output_grad, _ = output_grads
@@ -1004,9 +988,7 @@ def root(
         use_vectorized_jac=use_vectorized_jac,
     )
-    solution, success = cast(
-        tuple[TensorVariable, TensorVariable], root_op(variables, *args)
-    )
+    solution, success = root_op(variables, *args)
     return solution, success
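
And the vector counterpart, again with the argument order assumed from `root_op(variables, *args)`:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import root

x = pt.vector("x")
# A line and an ellipse; the exact roots are [0., 1.] and [2., 0.].
equations = pt.stack([x[0] + 2 * x[1] - 2, x[0] ** 2 + 4 * x[1] ** 2 - 4])

x_star, success = root(equations, x)
fn = pytensor.function([x], [x_star, success])

sol, converged = fn([1.0, 1.0])  # converges to one of the two roots
```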
......
@@ -15,6 +15,7 @@ pytensor/tensor/blas_headers.py
 pytensor/tensor/elemwise.py
 pytensor/tensor/extra_ops.py
 pytensor/tensor/math.py
+pytensor/tensor/optimize.py
 pytensor/tensor/random/basic.py
 pytensor/tensor/random/op.py
 pytensor/tensor/random/utils.py
......