提交 0cc6314b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Optimize: Handle gradient wrt scalar inputs and guard against unsupported types

上级 a032cfbe
...@@ -6,21 +6,24 @@ import numpy as np ...@@ -6,21 +6,24 @@ import numpy as np
import pytensor.scalar as ps import pytensor.scalar as ps
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.gradient import grad, jacobian from pytensor.gradient import grad, grad_not_implemented, jacobian
from pytensor.graph.basic import Apply, Constant from pytensor.graph.basic import Apply, Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
from pytensor.graph.replace import graph_replace from pytensor.graph.replace import graph_replace
from pytensor.graph.traversal import ancestors, truncated_graph_inputs from pytensor.graph.traversal import ancestors, truncated_graph_inputs
from pytensor.scalar import ScalarType, ScalarVariable
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
atleast_2d, atleast_2d,
concatenate, concatenate,
scalar_from_tensor,
tensor, tensor,
tensor_from_scalar, tensor_from_scalar,
zeros_like, zeros_like,
) )
from pytensor.tensor.math import dot from pytensor.tensor.math import dot
from pytensor.tensor.slinalg import solve from pytensor.tensor.slinalg import solve
from pytensor.tensor.type import DenseTensorType
from pytensor.tensor.variable import TensorVariable, Variable from pytensor.tensor.variable import TensorVariable, Variable
...@@ -143,9 +146,9 @@ def _find_optimization_parameters( ...@@ -143,9 +146,9 @@ def _find_optimization_parameters(
def _get_parameter_grads_from_vector( def _get_parameter_grads_from_vector(
grad_wrt_args_vector: TensorVariable, grad_wrt_args_vector: TensorVariable,
x_star: TensorVariable, x_star: TensorVariable,
args: Sequence[Variable], args: Sequence[TensorVariable | ScalarVariable],
output_grad: TensorVariable, output_grad: TensorVariable,
) -> list[TensorVariable]: ) -> list[TensorVariable | ScalarVariable]:
""" """
Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters, Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter. returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
...@@ -160,7 +163,10 @@ def _get_parameter_grads_from_vector( ...@@ -160,7 +163,10 @@ def _get_parameter_grads_from_vector(
(*x_star.shape, *arg_shape) (*x_star.shape, *arg_shape)
) )
grad_wrt_args.append(dot(output_grad, arg_grad)) grad_wrt_arg = dot(output_grad, arg_grad)
if isinstance(arg.type, ScalarType):
grad_wrt_arg = scalar_from_tensor(grad_wrt_arg)
grad_wrt_args.append(grad_wrt_arg)
cursor += arg_size cursor += arg_size
...@@ -267,12 +273,12 @@ class ScipyVectorWrapperOp(ScipyWrapperOp): ...@@ -267,12 +273,12 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
def scalar_implict_optimization_grads( def scalar_implict_optimization_grads(
inner_fx: TensorVariable, inner_fx: TensorVariable,
inner_x: TensorVariable, inner_x: TensorVariable,
inner_args: Sequence[Variable], inner_args: Sequence[TensorVariable | ScalarVariable],
args: Sequence[Variable], args: Sequence[TensorVariable | ScalarVariable],
x_star: TensorVariable, x_star: TensorVariable,
output_grad: TensorVariable, output_grad: TensorVariable,
fgraph: FunctionGraph, fgraph: FunctionGraph,
) -> list[Variable]: ) -> list[TensorVariable | ScalarVariable]:
df_dx, *df_dthetas = grad( df_dx, *df_dthetas = grad(
inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore" inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
) )
...@@ -291,11 +297,11 @@ def scalar_implict_optimization_grads( ...@@ -291,11 +297,11 @@ def scalar_implict_optimization_grads(
def implict_optimization_grads( def implict_optimization_grads(
df_dx: TensorVariable, df_dx: TensorVariable,
df_dtheta_columns: Sequence[TensorVariable], df_dtheta_columns: Sequence[TensorVariable],
args: Sequence[Variable], args: Sequence[TensorVariable | ScalarVariable],
x_star: TensorVariable, x_star: TensorVariable,
output_grad: TensorVariable, output_grad: TensorVariable,
fgraph: FunctionGraph, fgraph: FunctionGraph,
) -> list[TensorVariable]: ) -> list[TensorVariable | ScalarVariable]:
r""" r"""
Compute gradients of an optimization problem with respect to its parameters. Compute gradients of an optimization problem with respect to its parameters.
...@@ -410,7 +416,19 @@ class MinimizeScalarOp(ScipyScalarWrapperOp): ...@@ -410,7 +416,19 @@ class MinimizeScalarOp(ScipyScalarWrapperOp):
outputs[1][0] = np.bool_(res.success) outputs[1][0] = np.bool_(res.success)
def L_op(self, inputs, outputs, output_grads): def L_op(self, inputs, outputs, output_grads):
# TODO: Handle disconnected inputs
x, *args = inputs x, *args = inputs
if non_supported_types := tuple(
inp.type
for inp in inputs
if not isinstance(inp.type, DenseTensorType | ScalarType)
):
# TODO: Support SparseTensorTypes
# TODO: Remaining types are likely just disconnected anyway
msg = f"Minimize gradient not implemented due to inputs of type {non_supported_types}"
return [
grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
]
x_star, _ = outputs x_star, _ = outputs
output_grad, _ = output_grads output_grad, _ = output_grads
...@@ -560,7 +578,19 @@ class MinimizeOp(ScipyVectorWrapperOp): ...@@ -560,7 +578,19 @@ class MinimizeOp(ScipyVectorWrapperOp):
outputs[1][0] = np.bool_(res.success) outputs[1][0] = np.bool_(res.success)
def L_op(self, inputs, outputs, output_grads): def L_op(self, inputs, outputs, output_grads):
# TODO: Handle disconnected inputs
x, *args = inputs x, *args = inputs
if non_supported_types := tuple(
inp.type
for inp in inputs
if not isinstance(inp.type, DenseTensorType | ScalarType)
):
# TODO: Support SparseTensorTypes
# TODO: Remaining types are likely just disconnected anyway
msg = f"MinimizeOp gradient not implemented due to inputs of type {non_supported_types}"
return [
grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
]
x_star, _success = outputs x_star, _success = outputs
output_grad, _ = output_grads output_grad, _ = output_grads
...@@ -727,7 +757,19 @@ class RootScalarOp(ScipyScalarWrapperOp): ...@@ -727,7 +757,19 @@ class RootScalarOp(ScipyScalarWrapperOp):
outputs[1][0] = np.bool_(res.converged) outputs[1][0] = np.bool_(res.converged)
def L_op(self, inputs, outputs, output_grads): def L_op(self, inputs, outputs, output_grads):
# TODO: Handle disconnected inputs
x, *args = inputs x, *args = inputs
if non_supported_types := tuple(
inp.type
for inp in inputs
if not isinstance(inp.type, DenseTensorType | ScalarType)
):
# TODO: Support SparseTensorTypes
# TODO: Remaining types are likely just disconnected anyway
msg = f"RootScalarOp gradient not implemented due to inputs of type {non_supported_types}"
return [
grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
]
x_star, _ = outputs x_star, _ = outputs
output_grad, _ = output_grads output_grad, _ = output_grads
...@@ -908,6 +950,17 @@ class RootOp(ScipyVectorWrapperOp): ...@@ -908,6 +950,17 @@ class RootOp(ScipyVectorWrapperOp):
def L_op(self, inputs, outputs, output_grads): def L_op(self, inputs, outputs, output_grads):
# TODO: Handle disconnected inputs # TODO: Handle disconnected inputs
x, *args = inputs x, *args = inputs
if non_supported_types := tuple(
inp.type
for inp in inputs
if not isinstance(inp.type, DenseTensorType | ScalarType)
):
# TODO: Support SparseTensorTypes
# TODO: Remaining types are likely just disconnected anyway
msg = f"RootOp gradient not implemented due to inputs of type {non_supported_types}"
return [
grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
]
x_star, _ = outputs x_star, _ = outputs
output_grad, _ = output_grads output_grad, _ = output_grads
......
...@@ -3,9 +3,10 @@ import pytest ...@@ -3,9 +3,10 @@ import pytest
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function from pytensor import Variable, config, function
from pytensor.graph import Apply, Op from pytensor.gradient import NullTypeGradError, disconnected_type
from pytensor.tensor import scalar from pytensor.graph import Apply, Op, Type
from pytensor.tensor import alloc, scalar, scalar_from_tensor, tensor_from_scalar
from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -224,7 +225,7 @@ def test_root_system_of_equations(): ...@@ -224,7 +225,7 @@ def test_root_system_of_equations():
@pytest.mark.parametrize("optimize_op", (minimize, root)) @pytest.mark.parametrize("optimize_op", (minimize, root))
def test_minimize_0d(optimize_op): def test_optimize_0d(optimize_op):
# Scipy vector minimizers upcast 0d x to 1d. We need to work-around this # Scipy vector minimizers upcast 0d x to 1d. We need to work-around this
class AssertScalar(Op): class AssertScalar(Op):
...@@ -248,3 +249,106 @@ def test_minimize_0d(optimize_op): ...@@ -248,3 +249,106 @@ def test_minimize_0d(optimize_op):
np.testing.assert_allclose( np.testing.assert_allclose(
opt_x_res, 0, atol=1e-15 if floatX == "float64" else 1e-6 opt_x_res, 0, atol=1e-15 if floatX == "float64" else 1e-6
) )
@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
def test_optimize_grad_scalar_arg(optimize_op):
# Regression test for https://github.com/pymc-devs/pytensor/pull/1744
x = scalar("x")
theta = scalar("theta")
theta_scalar = scalar_from_tensor(theta)
obj = tensor_from_scalar((scalar_from_tensor(x) + theta_scalar) ** 2)
x0, _ = optimize_op(obj, x)
# Confirm theta is a direct input to the node
assert x0.owner.inputs[1] is theta_scalar
grad_wrt_theta = pt.grad(x0, theta)
np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: np.e}), -1)
@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
def test_optimize_grad_disconnected_numerical_inp(optimize_op):
x = scalar("x", dtype="float64")
theta = scalar("theta", dtype="int64")
obj = alloc(x**2, theta).sum() # repeat theta times and sum
x0, _ = optimize_op(obj, x)
# Confirm theta is a direct input to the node
assert x0.owner.inputs[1] is theta
# This should technically raise, but does not right now
grad_wrt_theta = pt.grad(x0, theta, disconnected_inputs="raise")
np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: 5}), 0)
# This should work even if the previous one raised
grad_wrt_theta = pt.grad(x0, theta, disconnected_inputs="ignore")
np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: 5}), 0)
@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
def test_optimize_grad_disconnected_non_numerical_inp(optimize_op):
class StrType(Type):
def filter(self, x, **kwargs):
if isinstance(x, str):
return x
raise TypeError
class SmileOrFrown(Op):
def make_node(self, x, str_emoji):
return Apply(self, [x, str_emoji], [x.type()])
def perform(self, node, inputs, output_storage):
[x, str_emoji] = inputs
match str_emoji:
case ":)":
out = np.array(x)
case ":(":
out = np.array(-x)
case _:
ValueError("str_emoji must be a smile or a frown")
output_storage[0][0] = out
def connection_pattern(self, node):
# Gradient connected only to first input
return [[True], [False]]
def L_op(self, inputs, outputs, output_gradients):
[_x, str_emoji] = inputs
[g] = output_gradients
return [
self(g, str_emoji),
disconnected_type(),
]
# We could try to use real types like NoneTypeT or SliceType, but this is more robust to future API changes
str_type = StrType()
smile_or_frown = SmileOrFrown()
x = scalar("x", dtype="float64")
num_theta = pt.scalar("num_theta", dtype="float64")
str_theta = Variable(str_type, None, None, name="str_theta")
obj = (smile_or_frown(x, str_theta) + num_theta) ** 2
x_star, _ = optimize_op(obj, x)
# Confirm thetas are direct inputs to the node
assert set(x_star.owner.inputs[1:]) == {num_theta, str_theta}
# Confirm forward pass works, no point in worrying about gradient otherwise
np.testing.assert_allclose(
x_star.eval({x: np.pi, num_theta: np.e, str_theta: ":)"}),
-np.e,
)
np.testing.assert_allclose(
x_star.eval({x: np.pi, num_theta: np.e, str_theta: ":("}),
np.e,
)
with pytest.raises(NullTypeGradError):
pt.grad(x_star, str_theta, disconnected_inputs="raise")
# This could be supported, but it is not right now.
with pytest.raises(NullTypeGradError):
_grad_wrt_num_theta = pt.grad(x_star, num_theta, disconnected_inputs="raise")
# np.testing.assert_allclose(grad_wrt_num_theta.eval({x: np.pi, num_theta: np.e, str_theta: ":)"}), -1)
# np.testing.assert_allclose(grad_wrt_num_theta.eval({x: np.pi, num_theta: np.e, str_theta: ":("}), 1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论