提交 0cc6314b 作者: Ricardo Vieira 提交者: Ricardo Vieira

Optimize: Handle gradient wrt scalar inputs and guard against unsupported types

上级 a032cfbe
......@@ -6,21 +6,24 @@ import numpy as np
import pytensor.scalar as ps
from pytensor.compile.function import function
from pytensor.gradient import grad, jacobian
from pytensor.gradient import grad, grad_not_implemented, jacobian
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
from pytensor.graph.replace import graph_replace
from pytensor.graph.traversal import ancestors, truncated_graph_inputs
from pytensor.scalar import ScalarType, ScalarVariable
from pytensor.tensor.basic import (
atleast_2d,
concatenate,
scalar_from_tensor,
tensor,
tensor_from_scalar,
zeros_like,
)
from pytensor.tensor.math import dot
from pytensor.tensor.slinalg import solve
from pytensor.tensor.type import DenseTensorType
from pytensor.tensor.variable import TensorVariable, Variable
......@@ -143,9 +146,9 @@ def _find_optimization_parameters(
def _get_parameter_grads_from_vector(
grad_wrt_args_vector: TensorVariable,
x_star: TensorVariable,
args: Sequence[Variable],
args: Sequence[TensorVariable | ScalarVariable],
output_grad: TensorVariable,
) -> list[TensorVariable]:
) -> list[TensorVariable | ScalarVariable]:
"""
Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
......@@ -160,7 +163,10 @@ def _get_parameter_grads_from_vector(
(*x_star.shape, *arg_shape)
)
grad_wrt_args.append(dot(output_grad, arg_grad))
grad_wrt_arg = dot(output_grad, arg_grad)
if isinstance(arg.type, ScalarType):
grad_wrt_arg = scalar_from_tensor(grad_wrt_arg)
grad_wrt_args.append(grad_wrt_arg)
cursor += arg_size
......@@ -267,12 +273,12 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
def scalar_implict_optimization_grads(
inner_fx: TensorVariable,
inner_x: TensorVariable,
inner_args: Sequence[Variable],
args: Sequence[Variable],
inner_args: Sequence[TensorVariable | ScalarVariable],
args: Sequence[TensorVariable | ScalarVariable],
x_star: TensorVariable,
output_grad: TensorVariable,
fgraph: FunctionGraph,
) -> list[Variable]:
) -> list[TensorVariable | ScalarVariable]:
df_dx, *df_dthetas = grad(
inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
)
......@@ -291,11 +297,11 @@ def scalar_implict_optimization_grads(
def implict_optimization_grads(
df_dx: TensorVariable,
df_dtheta_columns: Sequence[TensorVariable],
args: Sequence[Variable],
args: Sequence[TensorVariable | ScalarVariable],
x_star: TensorVariable,
output_grad: TensorVariable,
fgraph: FunctionGraph,
) -> list[TensorVariable]:
) -> list[TensorVariable | ScalarVariable]:
r"""
Compute gradients of an optimization problem with respect to its parameters.
......@@ -410,7 +416,19 @@ class MinimizeScalarOp(ScipyScalarWrapperOp):
outputs[1][0] = np.bool_(res.success)
def L_op(self, inputs, outputs, output_grads):
# TODO: Handle disconnected inputs
x, *args = inputs
if non_supported_types := tuple(
inp.type
for inp in inputs
if not isinstance(inp.type, DenseTensorType | ScalarType)
):
# TODO: Support SparseTensorTypes
# TODO: Remaining types are likely just disconnected anyway
msg = f"Minimize gradient not implemented due to inputs of type {non_supported_types}"
return [
grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
]
x_star, _ = outputs
output_grad, _ = output_grads
......@@ -560,7 +578,19 @@ class MinimizeOp(ScipyVectorWrapperOp):
outputs[1][0] = np.bool_(res.success)
def L_op(self, inputs, outputs, output_grads):
# TODO: Handle disconnected inputs
x, *args = inputs
if non_supported_types := tuple(
inp.type
for inp in inputs
if not isinstance(inp.type, DenseTensorType | ScalarType)
):
# TODO: Support SparseTensorTypes
# TODO: Remaining types are likely just disconnected anyway
msg = f"MinimizeOp gradient not implemented due to inputs of type {non_supported_types}"
return [
grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
]
x_star, _success = outputs
output_grad, _ = output_grads
......@@ -727,7 +757,19 @@ class RootScalarOp(ScipyScalarWrapperOp):
outputs[1][0] = np.bool_(res.converged)
def L_op(self, inputs, outputs, output_grads):
# TODO: Handle disconnected inputs
x, *args = inputs
if non_supported_types := tuple(
inp.type
for inp in inputs
if not isinstance(inp.type, DenseTensorType | ScalarType)
):
# TODO: Support SparseTensorTypes
# TODO: Remaining types are likely just disconnected anyway
msg = f"RootScalarOp gradient not implemented due to inputs of type {non_supported_types}"
return [
grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
]
x_star, _ = outputs
output_grad, _ = output_grads
......@@ -908,6 +950,17 @@ class RootOp(ScipyVectorWrapperOp):
def L_op(self, inputs, outputs, output_grads):
# TODO: Handle disconnected inputs
x, *args = inputs
if non_supported_types := tuple(
inp.type
for inp in inputs
if not isinstance(inp.type, DenseTensorType | ScalarType)
):
# TODO: Support SparseTensorTypes
# TODO: Remaining types are likely just disconnected anyway
msg = f"RootOp gradient not implemented due to inputs of type {non_supported_types}"
return [
grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
]
x_star, _ = outputs
output_grad, _ = output_grads
......
......@@ -3,9 +3,10 @@ import pytest
import pytensor
import pytensor.tensor as pt
from pytensor import config, function
from pytensor.graph import Apply, Op
from pytensor.tensor import scalar
from pytensor import Variable, config, function
from pytensor.gradient import NullTypeGradError, disconnected_type
from pytensor.graph import Apply, Op, Type
from pytensor.tensor import alloc, scalar, scalar_from_tensor, tensor_from_scalar
from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar
from tests import unittest_tools as utt
......@@ -224,7 +225,7 @@ def test_root_system_of_equations():
@pytest.mark.parametrize("optimize_op", (minimize, root))
def test_minimize_0d(optimize_op):
def test_optimize_0d(optimize_op):
# Scipy vector minimizers upcast 0d x to 1d. We need to work-around this
class AssertScalar(Op):
......@@ -248,3 +249,106 @@ def test_minimize_0d(optimize_op):
np.testing.assert_allclose(
opt_x_res, 0, atol=1e-15 if floatX == "float64" else 1e-6
)
@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
def test_optimize_grad_scalar_arg(optimize_op):
    """Differentiate an optimization result with respect to a scalar-typed argument.

    Regression test for https://github.com/pymc-devs/pytensor/pull/1744

    The objective is ``(x + theta) ** 2`` where ``theta`` enters the graph via
    ``scalar_from_tensor``. Both the minimum and the root sit at
    ``x = -theta``, so the expected gradient of the solution with respect to
    ``theta`` is ``-1``.
    """
    x = scalar("x")
    theta = scalar("theta")
    theta_scalar = scalar_from_tensor(theta)

    squared_sum = (scalar_from_tensor(x) + theta_scalar) ** 2
    obj = tensor_from_scalar(squared_sum)
    x0, _ = optimize_op(obj, x)

    # The scalar-typed theta (not the tensor one) must be a direct input of the node
    assert x0.owner.inputs[1] is theta_scalar

    grad_wrt_theta = pt.grad(x0, theta)
    grad_value = grad_wrt_theta.eval({x: np.pi, theta: np.e})
    np.testing.assert_allclose(grad_value, -1)
@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
def test_optimize_grad_disconnected_numerical_inp(optimize_op):
    """Request the gradient with respect to an integer argument of the optimize node.

    ``theta`` is an int64 scalar that is only used as the length passed to
    ``alloc``. The gradient of the solution with respect to it is expected to
    evaluate to ``0`` under both ``disconnected_inputs`` modes.
    """
    x = scalar("x", dtype="float64")
    theta = scalar("theta", dtype="int64")

    # Repeat x**2 theta times and sum
    obj = alloc(x**2, theta).sum()
    x0, _ = optimize_op(obj, x)

    # Confirm theta is a direct input to the node
    assert x0.owner.inputs[1] is theta

    # "raise" should technically raise, but does not right now.
    # "ignore" should work even if the "raise" case raised.
    for disconnected_mode in ("raise", "ignore"):
        grad_wrt_theta = pt.grad(x0, theta, disconnected_inputs=disconnected_mode)
        np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: 5}), 0)
@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
def test_optimize_grad_disconnected_non_numerical_inp(optimize_op):
class StrType(Type):
def filter(self, x, **kwargs):
if isinstance(x, str):
return x
raise TypeError
class SmileOrFrown(Op):
def make_node(self, x, str_emoji):
return Apply(self, [x, str_emoji], [x.type()])
def perform(self, node, inputs, output_storage):
[x, str_emoji] = inputs
match str_emoji:
case ":)":
out = np.array(x)
case ":(":
out = np.array(-x)
case _:
ValueError("str_emoji must be a smile or a frown")
output_storage[0][0] = out
def connection_pattern(self, node):
# Gradient connected only to first input
return [[True], [False]]
def L_op(self, inputs, outputs, output_gradients):
[_x, str_emoji] = inputs
[g] = output_gradients
return [
self(g, str_emoji),
disconnected_type(),
]
# We could try to use real types like NoneTypeT or SliceType, but this is more robust to future API changes
str_type = StrType()
smile_or_frown = SmileOrFrown()
x = scalar("x", dtype="float64")
num_theta = pt.scalar("num_theta", dtype="float64")
str_theta = Variable(str_type, None, None, name="str_theta")
obj = (smile_or_frown(x, str_theta) + num_theta) ** 2
x_star, _ = optimize_op(obj, x)
# Confirm thetas are direct inputs to the node
assert set(x_star.owner.inputs[1:]) == {num_theta, str_theta}
# Confirm forward pass works, no point in worrying about gradient otherwise
np.testing.assert_allclose(
x_star.eval({x: np.pi, num_theta: np.e, str_theta: ":)"}),
-np.e,
)
np.testing.assert_allclose(
x_star.eval({x: np.pi, num_theta: np.e, str_theta: ":("}),
np.e,
)
with pytest.raises(NullTypeGradError):
pt.grad(x_star, str_theta, disconnected_inputs="raise")
# This could be supported, but it is not right now.
with pytest.raises(NullTypeGradError):
_grad_wrt_num_theta = pt.grad(x_star, num_theta, disconnected_inputs="raise")
# np.testing.assert_allclose(grad_wrt_num_theta.eval({x: np.pi, num_theta: np.e, str_theta: ":)"}), -1)
# np.testing.assert_allclose(grad_wrt_num_theta.eval({x: np.pi, num_theta: np.e, str_theta: ":("}), 1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论