Commit 22cda11a authored by Ricardo Vieira, committed by Ricardo Vieira

Handle upcasting of scalar to vector arrays by scipy vector optimizers

Parent e126020a
......@@ -233,14 +233,8 @@ class ScipyWrapperOp(Op, HasInnerGraph):
class ScipyScalarWrapperOp(ScipyWrapperOp):
def build_fn(self):
"""
This is overloaded because scipy converts scalar inputs to lists, changing the return type. The
wrapper function logic is there to handle this.
"""
# We have no control over the inputs to the scipy inner function for scalar_minimize. As a result,
# we need to adjust the graph to work with what scipy will be passing into the inner function --
# always scalar, and always float64
# We need to adjust the graph to work with what scipy will be passing into the inner function --
# always scalar array of float64 type
x, *args = self.inner_inputs
new_root_x = ps.float64(name="x_scalar")
new_x = tensor_from_scalar(new_root_x.astype(x.type.dtype))
......@@ -255,6 +249,24 @@ class ScipyScalarWrapperOp(ScipyWrapperOp):
self._fn_wrapped = LRUCache1(fn)
class ScipyVectorWrapperOp(ScipyWrapperOp):
    def build_fn(self):
        """Compile the inner graph, promoting a 0d ``x`` to a 1d input.

        Scipy's vector optimizers always pass the inner function a vector
        array with size of at least 1, even when the user-supplied ``x`` is
        a 0d array.  For that 0d case, rebuild the graph around a 1d root
        input and squeeze it back to 0d before feeding the original graph.
        """
        root_x, *extra_inputs = self.inner_inputs

        # Already vector-valued: the generic compilation path suffices.
        if root_x.type.shape != ():
            return super().build_fn()

        # 1d stand-in for x, squeezed back to the original 0d variable
        # inside the graph.
        vector_x = root_x[None].type()
        squeezed_x = vector_x.squeeze()
        replaced_outputs = graph_replace(self.inner_outputs, {root_x: squeezed_x})

        compiled = function(
            [vector_x, *extra_inputs], replaced_outputs, trust_input=True
        )
        self._fn = compiled
        # Reassign to see the compiled graph in the dprint:
        # self.fgraph = compiled.maker.fgraph
        self._fn_wrapped = LRUCache1(compiled)
def scalar_implict_optimization_grads(
inner_fx: Variable,
inner_x: Variable,
......@@ -474,7 +486,7 @@ def minimize_scalar(
return solution, success
class MinimizeOp(ScipyWrapperOp):
class MinimizeOp(ScipyVectorWrapperOp):
def __init__(
self,
x: Variable,
......@@ -808,7 +820,7 @@ def root_scalar(
return solution, success
class RootOp(ScipyWrapperOp):
class RootOp(ScipyVectorWrapperOp):
__props__ = ("method", "jac")
def __init__(
......
......@@ -4,6 +4,8 @@ import pytest
import pytensor
import pytensor.tensor as pt
from pytensor import config, function
from pytensor.graph import Apply, Op
from pytensor.tensor import scalar
from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar
from tests import unittest_tools as utt
......@@ -219,3 +221,30 @@ def test_root_system_of_equations():
utt.verify_grad(
root_fn, [x0, a_val, b_val], eps=1e-6 if floatX == "float64" else 1e-3
)
@pytest.mark.parametrize("optimize_op", (minimize, root))
def test_minimize_0d(optimize_op):
    # Scipy's vector optimizers upcast a 0d x to 1d; the wrapper Ops must
    # squeeze it back so the inner graph still sees a scalar array.
    class CheckNdimZero(Op):
        # The output aliases the input; no copy is made.
        view_map = {0: [0]}

        def make_node(self, x):
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, output_storage):
            [value] = inputs
            # Fail loudly if the 1d upcast leaked into the inner graph.
            assert value.ndim == 0
            output_storage[0][0] = value

        def L_op(self, inputs, outputs, out_grads):
            # Identity op, so gradients pass through unchanged.
            return out_grads

    x = scalar("x")
    checked_x = CheckNdimZero()(x)
    opt_x, _ = optimize_op(checked_x**2, x)
    result = opt_x.eval({x: np.array(5, dtype=x.type.dtype)})
    tolerance = 1e-15 if floatX == "float64" else 1e-6
    np.testing.assert_allclose(result, 0, atol=tolerance)
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment