提交 36e2e1a5 authored 作者: Jing Xie's avatar Jing Xie 提交者: Brandon T. Willard

Add CholeskySolve Op

上级 01a24b83
...@@ -188,6 +188,77 @@ class CholeskyGrad(Op): ...@@ -188,6 +188,77 @@ class CholeskyGrad(Op):
return [shapes[0]] return [shapes[0]]
class CholeskySolve(Op):
    """Op that solves a linear system from a Cholesky factor.

    The numerical work is delegated to `scipy.linalg.cho_solve` in `perform`.

    Parameters
    ----------
    lower : bool
        Passed to `scipy.linalg.cho_solve` as the second element of its
        ``(c, lower)`` argument.
    check_finite : bool
        Passed to `scipy.linalg.cho_solve` as ``check_finite``.
    """

    __props__ = ("lower", "check_finite")

    def __init__(self, lower=True, check_finite=True):
        self.check_finite = check_finite
        self.lower = lower

    def __repr__(self):
        props = str(self._props())
        return f"CholeskySolve{{{props}}}"

    def make_node(self, C, b):
        """Build the `Apply` node for factor ``C`` and right-hand side ``b``.

        ``C`` must be 2-dimensional and ``b`` 1- or 2-dimensional; the output
        reuses the broadcast pattern of ``b``.
        """
        C, b = as_tensor_variable(C), as_tensor_variable(b)
        assert C.ndim == 2
        assert b.ndim in (1, 2)

        # Let SciPy decide the result dtype: solve a trivial (1, 1) system
        # built from the two input dtypes and read the dtype of the answer.
        probe_lhs = np.eye(1).astype(C.dtype)
        probe_rhs = np.eye(1).astype(b.dtype)
        out_dtype = scipy.linalg.solve(probe_lhs, probe_rhs).dtype

        out = tensor(broadcastable=b.broadcastable, dtype=out_dtype)
        return Apply(self, [C, b], [out])

    def perform(self, node, inputs, output_storage):
        """Compute the solution with SciPy and store it in the output cell."""
        factor, rhs = inputs
        output_storage[0][0] = scipy.linalg.cho_solve(
            (factor, self.lower),
            rhs,
            check_finite=self.check_finite,
        )

    def infer_shape(self, fgraph, node, shapes):
        """Return the output shape derived from the shapes of ``C`` and ``b``."""
        factor_shape, rhs_shape = shapes
        n_rows = factor_shape[1]
        if len(rhs_shape) != 1:
            # b is a matrix: keep its column count.
            return [(n_rows, rhs_shape[1])]
        # b is a vector.
        return [(n_rows,)]
def cho_solve(c_and_lower, b, check_finite=True):
    """Solve the linear equations A x = b, given the Cholesky factorization of A.

    Parameters
    ----------
    c_and_lower : tuple, (array, bool)
        Cholesky factorization of A, as given by cho_factor, paired with the
        ``lower`` flag that is forwarded to `CholeskySolve`.
    b : array
        Right-hand side
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    The result of applying ``CholeskySolve(lower=lower,
    check_finite=check_finite)`` to the factor and `b`.
    """
    # A fresh Op is built per call so that `lower` and `check_finite` are
    # honoured; no module-level default instance is needed.
    c, lower = c_and_lower
    return CholeskySolve(lower=lower, check_finite=check_finite)(c, b)
class Solve(Op): class Solve(Op):
""" """
Solve a system of linear equations. Solve a system of linear equations.
......
...@@ -12,7 +12,9 @@ from aesara.configdefaults import config ...@@ -12,7 +12,9 @@ from aesara.configdefaults import config
from aesara.tensor.slinalg import ( from aesara.tensor.slinalg import (
Cholesky, Cholesky,
CholeskyGrad, CholeskyGrad,
CholeskySolve,
Solve, Solve,
cho_solve,
cholesky, cholesky,
eigvalsh, eigvalsh,
expm, expm,
...@@ -314,6 +316,117 @@ class TestSolve(utt.InferShapeTester): ...@@ -314,6 +316,117 @@ class TestSolve(utt.InferShapeTester):
self.verify_solve_grad(m, n, assume_a, lower, rng) self.verify_solve_grad(m, n, assume_a, lower, rng)
class TestCholeskySolve(utt.InferShapeTester):
    """Tests for the `CholeskySolve` Op: repr, shape inference, values, dtypes."""

    def setup_method(self):
        self.op_class = CholeskySolve
        self.op = CholeskySolve()
        self.op_upper = CholeskySolve(lower=False)
        super().setup_method()

    def test_repr(self):
        expected = "CholeskySolve{(True, True)}"
        assert repr(CholeskySolve()) == expected

    def test_infer_shape(self):
        # Exercise both a matrix and a vector right-hand side.  The generator
        # is re-created from the same seed for each case.
        rhs_cases = (
            (matrix, (5, 1)),
            (vector, (5,)),
        )
        for make_rhs, rhs_shape in rhs_cases:
            rng = np.random.default_rng(utt.fetch_seed())
            A = matrix()
            b = make_rhs()
            # A must be square
            A_val = np.asarray(rng.random((5, 5)), dtype=config.floatX)
            b_val = np.asarray(rng.random(rhs_shape), dtype=config.floatX)
            self._compile_and_check(
                [A, b],  # aesara.function inputs
                [self.op(A, b)],  # aesara.function outputs
                [A_val, b_val],
                self.op_class,
                warn=False,
            )

    def test_solve_correctness(self):
        rng = np.random.default_rng(utt.fetch_seed())
        A = matrix()
        b = matrix()
        cho_solve_lower_func = aesara.function([A, b], self.op(A, b))
        cho_solve_upper_func = aesara.function([A, b], self.op_upper(A, b))

        b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)

        # Lower-triangular factor against SciPy's reference result.
        A_val = np.tril(np.asarray(rng.random((5, 5)), dtype=config.floatX))
        expected = scipy.linalg.cho_solve((A_val, True), b_val)
        actual = cho_solve_lower_func(A_val, b_val)
        assert np.allclose(expected, actual)

        # Upper-triangular factor against SciPy's reference result.
        A_val = np.triu(np.asarray(rng.random((5, 5)), dtype=config.floatX))
        expected = scipy.linalg.cho_solve((A_val, False), b_val)
        actual = cho_solve_upper_func(A_val, b_val)
        assert np.allclose(expected, actual)

    def test_solve_dtype(self):
        dtypes = [
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "int8",
            "int16",
            "int32",
            "int64",
            "float16",
            "float32",
            "float64",
        ]

        A_val = np.eye(2)
        b_val = np.ones((2, 1))

        # The symbolic output dtype must match the computed one for every
        # pairing of input dtypes.
        for A_dtype, b_dtype in itertools.product(dtypes, repeat=2):
            A = matrix(dtype=A_dtype)
            b = matrix(dtype=b_dtype)
            x = self.op(A, b)
            fn = function([A, b], x)
            x_result = fn(A_val.astype(A_dtype), b_val.astype(b_dtype))

            assert x.dtype == x_result.dtype
def test_cho_solve():
    """Check the `cho_solve` helper against SciPy for a lower-triangular factor."""
    rng = np.random.default_rng(utt.fetch_seed())
    A = matrix()
    b = matrix()
    cho_solve_lower_func = aesara.function([A, b], cho_solve((A, True), b))

    b_val = np.asarray(rng.random((5, 1)), dtype=config.floatX)
    A_val = np.tril(np.asarray(rng.random((5, 5)), dtype=config.floatX))

    expected = scipy.linalg.cho_solve((A_val, True), b_val)
    actual = cho_solve_lower_func(A_val, b_val)
    assert np.allclose(expected, actual)
def test_expm(): def test_expm():
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
A = rng.standard_normal((5, 5)).astype(config.floatX) A = rng.standard_normal((5, 5)).astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论