提交 d6f0185b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Thomas Wiecki

Remove duplicated `Inv` Op

上级 1492d3f7
...@@ -14,7 +14,6 @@ from pytensor.tensor.nlinalg import ( ...@@ -14,7 +14,6 @@ from pytensor.tensor.nlinalg import (
Det, Det,
Eig, Eig,
Eigh, Eigh,
Inv,
MatrixInverse, MatrixInverse,
MatrixPinv, MatrixPinv,
QRFull, QRFull,
...@@ -125,18 +124,6 @@ def numba_funcify_Eigh(op, node, **kwargs): ...@@ -125,18 +124,6 @@ def numba_funcify_Eigh(op, node, **kwargs):
return eigh return eigh
@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba_basic.numba_njit(inline="always")
def inv(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
return inv
@numba_funcify.register(MatrixInverse) @numba_funcify.register(MatrixInverse)
def numba_funcify_MatrixInverse(op, node, **kwargs): def numba_funcify_MatrixInverse(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
......
...@@ -78,25 +78,6 @@ def pinv(x, hermitian=False): ...@@ -78,25 +78,6 @@ def pinv(x, hermitian=False):
return MatrixPinv(hermitian=hermitian)(x) return MatrixPinv(hermitian=hermitian)(x)
class Inv(Op):
"""Computes the inverse of one or more matrices."""
def make_node(self, x):
x = as_tensor_variable(x)
return Apply(self, [x], [x.type()])
def perform(self, node, inputs, outputs):
(x,) = inputs
(z,) = outputs
z[0] = np.linalg.inv(x).astype(x.dtype)
def infer_shape(self, fgraph, node, shapes):
return shapes
inv = Inv()
class MatrixInverse(Op): class MatrixInverse(Op):
r"""Computes the inverse of a matrix :math:`A`. r"""Computes the inverse of a matrix :math:`A`.
...@@ -169,7 +150,7 @@ class MatrixInverse(Op): ...@@ -169,7 +150,7 @@ class MatrixInverse(Op):
return shapes return shapes
matrix_inverse = MatrixInverse() inv = matrix_inverse = MatrixInverse()
def matrix_dot(*args): def matrix_dot(*args):
......
...@@ -352,26 +352,6 @@ def test_Eigh(x, uplo, exc): ...@@ -352,26 +352,6 @@ def test_Eigh(x, uplo, exc):
None, None,
(), (),
), ),
(
nlinalg.Inv,
set_test_value(
at.dmatrix(),
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
),
None,
(),
),
(
nlinalg.Inv,
set_test_value(
at.lmatrix(),
(lambda x: x.T.dot(x))(
rng.integers(1, 10, size=(3, 3)).astype("int64")
),
),
None,
(),
),
( (
nlinalg.MatrixPinv, nlinalg.MatrixPinv,
set_test_value( set_test_value(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论