提交 e352e04f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add Numba conversions for aesara.tensor.nlinalg.pinv and inv

上级 7655b470
......@@ -69,7 +69,16 @@ from aesara.tensor.extra_ops import (
UnravelIndex,
)
from aesara.tensor.math import Dot, MaxAndArgmax
from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
from aesara.tensor.nlinalg import (
SVD,
Det,
Eig,
Eigh,
Inv,
MatrixInverse,
MatrixPinv,
QRFull,
)
from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve
......@@ -1811,6 +1820,32 @@ def numba_funcify_MatrixInverse(op, node, **kwargs):
return matrix_inverse
@numba_funcify.register(MatrixPinv)
def numba_funcify_MatrixPinv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def matrixpinv(x):
return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)
return matrixpinv
@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def inv(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
return inv
@numba_funcify.register(QRFull)
def numba_funcify_QRFull(op, node, **kwargs):
mode = op.mode
......
......@@ -2196,16 +2196,59 @@ def test_Eigh(x, uplo, exc):
@pytest.mark.parametrize(
"x, exc",
"op, x, exc, op_args",
[
(
nlinalg.MatrixInverse,
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
(),
),
(
nlinalg.MatrixInverse,
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
None,
(),
),
(
nlinalg.Inv,
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
(),
),
(
nlinalg.Inv,
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
None,
(),
),
(
nlinalg.MatrixPinv,
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
(True,),
),
(
nlinalg.MatrixPinv,
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
......@@ -2213,11 +2256,12 @@ def test_Eigh(x, uplo, exc):
),
),
None,
(False,),
),
],
)
def test_MatrixInverse(x, exc):
g = nlinalg.MatrixInverse()(x)
def test_matrix_inverses(op, x, exc, op_args):
g = op(*op_args)(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论