提交 a6975da3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Thomas Wiecki

CholeskySolve inherits from BaseSolve

上级 d6f0185b
......@@ -49,7 +49,7 @@ class Cholesky(Op):
__props__ = ("lower", "destructive", "on_error")
def __init__(self, lower=True, on_error="raise"):
def __init__(self, *, lower=True, on_error="raise"):
self.lower = lower
self.destructive = False
if on_error not in ("raise", "nan"):
......@@ -127,77 +127,8 @@ class Cholesky(Op):
return [grad]
cholesky = Cholesky()
class CholeskySolve(Op):
__props__ = ("lower", "check_finite")
def __init__(
self,
lower=True,
check_finite=True,
):
self.lower = lower
self.check_finite = check_finite
def __repr__(self):
return "CholeskySolve{%s}" % str(self._props())
def make_node(self, C, b):
C = as_tensor_variable(C)
b = as_tensor_variable(b)
assert C.ndim == 2
assert b.ndim in (1, 2)
# infer dtype by solving the most simple
# case with (1, 1) matrices
o_dtype = scipy.linalg.solve(
np.eye(1).astype(C.dtype), np.eye(1).astype(b.dtype)
).dtype
x = tensor(dtype=o_dtype, shape=b.type.shape)
return Apply(self, [C, b], [x])
def perform(self, node, inputs, output_storage):
C, b = inputs
rval = scipy.linalg.cho_solve(
(C, self.lower),
b,
check_finite=self.check_finite,
)
output_storage[0][0] = rval
def infer_shape(self, fgraph, node, shapes):
Cshape, Bshape = shapes
rows = Cshape[1]
if len(Bshape) == 1: # b is a Vector
return [(rows,)]
else:
cols = Bshape[1] # b is a Matrix
return [(rows, cols)]
cho_solve = CholeskySolve()
def cho_solve(c_and_lower, b, check_finite=True):
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
Parameters
----------
(c, lower) : tuple, (array, bool)
Cholesky factorization of a, as given by cho_factor
b : array
Right-hand side
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
"""
A, lower = c_and_lower
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
def cholesky(x, lower=True, on_error="raise"):
return Cholesky(lower=lower, on_error=on_error)(x)
class SolveBase(Op):
......@@ -210,6 +141,7 @@ class SolveBase(Op):
def __init__(
self,
*,
lower=False,
check_finite=True,
):
......@@ -276,28 +208,56 @@ class SolveBase(Op):
return [A_bar, b_bar]
def __repr__(self):
return f"{type(self).__name__}{self._props()}"
class CholeskySolve(SolveBase):
def __init__(self, **kwargs):
kwargs.setdefault("lower", True)
super().__init__(**kwargs)
def perform(self, node, inputs, output_storage):
C, b = inputs
rval = scipy.linalg.cho_solve(
(C, self.lower),
b,
check_finite=self.check_finite,
)
output_storage[0][0] = rval
def L_op(self, *args, **kwargs):
raise NotImplementedError()
def cho_solve(c_and_lower, b, *, check_finite=True):
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
Parameters
----------
(c, lower) : tuple, (array, bool)
Cholesky factorization of a, as given by cho_factor
b : array
Right-hand side
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
"""
A, lower = c_and_lower
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
class SolveTriangular(SolveBase):
"""Solve a system of linear equations."""
__props__ = (
"lower",
"trans",
"unit_diagonal",
"lower",
"check_finite",
)
def __init__(
self,
trans=0,
lower=False,
unit_diagonal=False,
check_finite=True,
):
super().__init__(lower=lower, check_finite=check_finite)
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
super().__init__(**kwargs)
self.trans = trans
self.unit_diagonal = unit_diagonal
......@@ -326,6 +286,7 @@ class SolveTriangular(SolveBase):
def solve_triangular(
a: TensorVariable,
b: TensorVariable,
*,
trans: Union[int, str] = 0,
lower: bool = False,
unit_diagonal: bool = False,
......@@ -373,16 +334,11 @@ class Solve(SolveBase):
"check_finite",
)
def __init__(
self,
assume_a="gen",
lower=False,
check_finite=True,
):
def __init__(self, *, assume_a="gen", **kwargs):
if assume_a not in ("gen", "sym", "her", "pos"):
raise ValueError(f"{assume_a} is not a recognized matrix structure")
super().__init__(lower=lower, check_finite=check_finite)
super().__init__(**kwargs)
self.assume_a = assume_a
def perform(self, node, inputs, outputs):
......@@ -396,7 +352,7 @@ class Solve(SolveBase):
)
def solve(a, b, assume_a="gen", lower=False, check_finite=True):
def solve(a, b, *, assume_a="gen", lower=False, check_finite=True):
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
If the data matrix is known to be a particular type then supplying the
......
......@@ -46,7 +46,7 @@ rng = np.random.default_rng(42849)
],
)
def test_Cholesky(x, lower, exc):
g = slinalg.Cholesky(lower)(x)
g = slinalg.Cholesky(lower=lower)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
......@@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc):
],
)
def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower)(A, x)
g = slinalg.Solve(lower=lower)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
......@@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc):
],
)
def test_SolveTriangular(A, x, lower, exc):
g = slinalg.SolveTriangular(lower)(A, x)
g = slinalg.SolveTriangular(lower=lower)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
......
......@@ -361,7 +361,7 @@ class TestCholeskySolve(utt.InferShapeTester):
super().setup_method()
def test_repr(self):
assert repr(CholeskySolve()) == "CholeskySolve{(True, True)}"
assert repr(CholeskySolve()) == "CholeskySolve(lower=True,check_finite=True)"
def test_infer_shape(self):
rng = np.random.default_rng(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论