提交 e352e04f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add Numba conversions for aesara.tensor.nlinalg.pinv and inv

上级 7655b470
...@@ -69,7 +69,16 @@ from aesara.tensor.extra_ops import ( ...@@ -69,7 +69,16 @@ from aesara.tensor.extra_ops import (
UnravelIndex, UnravelIndex,
) )
from aesara.tensor.math import Dot, MaxAndArgmax from aesara.tensor.math import Dot, MaxAndArgmax
from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull from aesara.tensor.nlinalg import (
SVD,
Det,
Eig,
Eigh,
Inv,
MatrixInverse,
MatrixPinv,
QRFull,
)
from aesara.tensor.nnet.basic import LogSoftmax, Softmax from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve from aesara.tensor.slinalg import Cholesky, Solve
...@@ -1811,6 +1820,32 @@ def numba_funcify_MatrixInverse(op, node, **kwargs): ...@@ -1811,6 +1820,32 @@ def numba_funcify_MatrixInverse(op, node, **kwargs):
return matrix_inverse return matrix_inverse
@numba_funcify.register(MatrixPinv)
def numba_funcify_MatrixPinv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def matrixpinv(x):
return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)
return matrixpinv
@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit(inline="always")
def inv(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
return inv
@numba_funcify.register(QRFull) @numba_funcify.register(QRFull)
def numba_funcify_QRFull(op, node, **kwargs): def numba_funcify_QRFull(op, node, **kwargs):
mode = op.mode mode = op.mode
......
...@@ -2196,16 +2196,59 @@ def test_Eigh(x, uplo, exc): ...@@ -2196,16 +2196,59 @@ def test_Eigh(x, uplo, exc):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, exc", "op, x, exc, op_args",
[ [
( (
nlinalg.MatrixInverse,
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
(),
),
(
nlinalg.MatrixInverse,
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
None,
(),
),
(
nlinalg.Inv,
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
(),
),
(
nlinalg.Inv,
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
None,
(),
),
(
nlinalg.MatrixPinv,
set_test_value( set_test_value(
aet.dmatrix(), aet.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
), ),
None, None,
(True,),
), ),
( (
nlinalg.MatrixPinv,
set_test_value( set_test_value(
aet.lmatrix(), aet.lmatrix(),
(lambda x: x.T.dot(x))( (lambda x: x.T.dot(x))(
...@@ -2213,11 +2256,12 @@ def test_Eigh(x, uplo, exc): ...@@ -2213,11 +2256,12 @@ def test_Eigh(x, uplo, exc):
), ),
), ),
None, None,
(False,),
), ),
], ],
) )
def test_MatrixInverse(x, exc): def test_matrix_inverses(op, x, exc, op_args):
g = nlinalg.MatrixInverse()(x) g = op(*op_args)(x)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc) cm = contextlib.suppress() if exc is None else pytest.warns(exc)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论