提交 22cda11a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Handle upcasting of scalar to vector arrays by scipy vector optimizers

上级 e126020a
...@@ -233,14 +233,8 @@ class ScipyWrapperOp(Op, HasInnerGraph): ...@@ -233,14 +233,8 @@ class ScipyWrapperOp(Op, HasInnerGraph):
class ScipyScalarWrapperOp(ScipyWrapperOp): class ScipyScalarWrapperOp(ScipyWrapperOp):
def build_fn(self): def build_fn(self):
""" # We need to adjust the graph to work with what scipy will be passing into the inner function --
This is overloaded because scipy converts scalar inputs to lists, changing the return type. The # always scalar array of float64 type
wrapper function logic is there to handle this.
"""
# We have no control over the inputs to the scipy inner function for scalar_minimize. As a result,
# we need to adjust the graph to work with what scipy will be passing into the inner function --
# always scalar, and always float64
x, *args = self.inner_inputs x, *args = self.inner_inputs
new_root_x = ps.float64(name="x_scalar") new_root_x = ps.float64(name="x_scalar")
new_x = tensor_from_scalar(new_root_x.astype(x.type.dtype)) new_x = tensor_from_scalar(new_root_x.astype(x.type.dtype))
...@@ -255,6 +249,24 @@ class ScipyScalarWrapperOp(ScipyWrapperOp): ...@@ -255,6 +249,24 @@ class ScipyScalarWrapperOp(ScipyWrapperOp):
self._fn_wrapped = LRUCache1(fn) self._fn_wrapped = LRUCache1(fn)
class ScipyVectorWrapperOp(ScipyWrapperOp):
    """Wrapper for scipy optimizers that call the inner function with a 1d array.

    Scipy's vector optimizers always hand the inner function a vector of size
    at least 1, even when the user's variable is 0d. When the root input is a
    scalar, the inner graph is rewritten to accept that 1d array and squeeze
    it back down before evaluation.
    """

    def build_fn(self):
        x, *args = self.inner_inputs

        # Non-scalar inputs already match what scipy passes in; use the
        # default construction from the parent class.
        if x.type.shape != ():
            return super().build_fn()

        # New root: a length-1 vector of the same dtype; squeeze recovers
        # the original 0d variable inside the graph.
        vector_root = x[None].type()
        squeezed = vector_root.squeeze()
        replaced_outputs = graph_replace(self.inner_outputs, {x: squeezed})

        fn = function([vector_root, *args], replaced_outputs, trust_input=True)
        self._fn = fn

        # Do this reassignment to see the compiled graph in the dprint
        # self.fgraph = fn.maker.fgraph

        self._fn_wrapped = LRUCache1(fn)
def scalar_implict_optimization_grads( def scalar_implict_optimization_grads(
inner_fx: Variable, inner_fx: Variable,
inner_x: Variable, inner_x: Variable,
...@@ -474,7 +486,7 @@ def minimize_scalar( ...@@ -474,7 +486,7 @@ def minimize_scalar(
return solution, success return solution, success
class MinimizeOp(ScipyWrapperOp): class MinimizeOp(ScipyVectorWrapperOp):
def __init__( def __init__(
self, self,
x: Variable, x: Variable,
...@@ -808,7 +820,7 @@ def root_scalar( ...@@ -808,7 +820,7 @@ def root_scalar(
return solution, success return solution, success
class RootOp(ScipyWrapperOp): class RootOp(ScipyVectorWrapperOp):
__props__ = ("method", "jac") __props__ = ("method", "jac")
def __init__( def __init__(
......
...@@ -4,6 +4,8 @@ import pytest ...@@ -4,6 +4,8 @@ import pytest
import pytensor import pytensor
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, function from pytensor import config, function
from pytensor.graph import Apply, Op
from pytensor.tensor import scalar
from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -219,3 +221,30 @@ def test_root_system_of_equations(): ...@@ -219,3 +221,30 @@ def test_root_system_of_equations():
utt.verify_grad( utt.verify_grad(
root_fn, [x0, a_val, b_val], eps=1e-6 if floatX == "float64" else 1e-3 root_fn, [x0, a_val, b_val], eps=1e-6 if floatX == "float64" else 1e-3
) )
@pytest.mark.parametrize("optimize_op", (minimize, root))
def test_minimize_0d(optimize_op):
    # Scipy vector minimizers upcast 0d x to 1d. We need to work-around this

    class AssertScalar(Op):
        # Identity Op that fails at runtime if its input is not 0d, so the
        # test detects any leaked 1d upcast inside the inner function.
        view_map = {0: [0]}

        def make_node(self, x):
            return Apply(self, [x], [x.type()])

        def perform(self, node, inputs, output_storage):
            [x_val] = inputs
            assert x_val.ndim == 0
            output_storage[0][0] = x_val

        def L_op(self, inputs, outputs, out_grads):
            return out_grads

    x = scalar("x")
    checked_x = AssertScalar()(x)
    opt_x, _ = optimize_op(checked_x**2, x)

    result = opt_x.eval({x: np.array(5, dtype=x.type.dtype)})
    tolerance = 1e-15 if floatX == "float64" else 1e-6
    np.testing.assert_allclose(result, 0, atol=tolerance)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论