提交 13bb75c1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unnecessary aesara.tensor.nlinalg.QRIncomplete

上级 f2ecc2e5
...@@ -2887,23 +2887,6 @@ def local_gpu_magma_qr(fgraph, op, context_name, inputs, outputs): ...@@ -2887,23 +2887,6 @@ def local_gpu_magma_qr(fgraph, op, context_name, inputs, outputs):
return out return out
@register_opt("magma", "fast_compile")
@op_lifter([nlinalg.QRIncomplete])
@register_opt2([nlinalg.QRIncomplete], "magma", "fast_compile")
def local_gpu_magma_qr_incomplete(fgraph, op, context_name, inputs, outputs):
if not config.magma__enabled:
return
if inputs[0].dtype not in ["float16", "float32"]:
return
x = inputs[0]
if inputs[0].dtype == "float16":
x = inputs[0].astype("float32")
out = gpu_qr(x, complete=False)
if inputs[0].dtype == "float16":
return [out.astype("float16")]
return out
# Matrix inverse # Matrix inverse
@register_opt("magma", "fast_compile") @register_opt("magma", "fast_compile")
@op_lifter([nlinalg.MatrixInverse]) @op_lifter([nlinalg.MatrixInverse])
......
...@@ -44,15 +44,7 @@ from aesara.tensor.extra_ops import ( ...@@ -44,15 +44,7 @@ from aesara.tensor.extra_ops import (
UnravelIndex, UnravelIndex,
) )
from aesara.tensor.math import Dot, MaxAndArgmax from aesara.tensor.math import Dot, MaxAndArgmax
from aesara.tensor.nlinalg import ( from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
SVD,
Det,
Eig,
Eigh,
MatrixInverse,
QRFull,
QRIncomplete,
)
from aesara.tensor.nnet.basic import LogSoftmax, Softmax from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
...@@ -835,16 +827,6 @@ def jax_funcify_QRFull(op, **kwargs): ...@@ -835,16 +827,6 @@ def jax_funcify_QRFull(op, **kwargs):
return qr_full return qr_full
@jax_funcify.register(QRIncomplete)
def jax_funcify_QRIncomplete(op, **kwargs):
mode = op.mode
def qr_incomplete(x, mode=mode):
return jnp.linalg.qr(x, mode=mode)
return qr_incomplete
@jax_funcify.register(SVD) @jax_funcify.register(SVD)
def jax_funcify_SVD(op, **kwargs): def jax_funcify_SVD(op, **kwargs):
full_matrices = op.full_matrices full_matrices = op.full_matrices
......
...@@ -408,47 +408,31 @@ class QRFull(Op): ...@@ -408,47 +408,31 @@ class QRFull(Op):
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
assert x.ndim == 2, "The input of qr function should be a matrix." assert x.ndim == 2, "The input of qr function should be a matrix."
q = matrix(dtype=x.dtype)
in_dtype = x.type.numpy_dtype
out_dtype = np.dtype(f"f{in_dtype.itemsize}")
if self.mode != "raw": if self.mode != "raw":
r = matrix(dtype=x.dtype) r = matrix(dtype=x.dtype)
else: else:
r = vector(dtype=x.dtype) r = vector(dtype=x.dtype)
return Apply(self, [x], [q, r]) if self.mode != "r":
q = matrix(dtype=out_dtype)
def perform(self, node, inputs, outputs): outputs = [q, r]
(x,) = inputs else:
(q, r) = outputs outputs = [r]
assert x.ndim == 2, "The input of qr function should be a matrix."
q[0], r[0] = self._numop(x, self.mode)
class QRIncomplete(Op):
"""
Incomplete QR Decomposition.
Computes the QR decomposition of a matrix.
Factor the matrix a as qr and return a single matrix R.
"""
_numop = staticmethod(np.linalg.qr)
__props__ = ("mode",)
def __init__(self, mode):
self.mode = mode
def make_node(self, x): return Apply(self, [x], outputs)
x = as_tensor_variable(x)
assert x.ndim == 2, "The input of qr function should be a matrix."
r = matrix(dtype=x.dtype)
return Apply(self, [x], [r])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x,) = inputs (x,) = inputs
(r,) = outputs
assert x.ndim == 2, "The input of qr function should be a matrix." assert x.ndim == 2, "The input of qr function should be a matrix."
r[0] = self._numop(x, self.mode) res = self._numop(x, self.mode)
if self.mode != "r":
outputs[0][0], outputs[1][0] = res
else:
outputs[0][0] = res
def qr(a, mode="reduced"): def qr(a, mode="reduced"):
...@@ -493,12 +477,7 @@ def qr(a, mode="reduced"): ...@@ -493,12 +477,7 @@ def qr(a, mode="reduced"):
The upper-triangular matrix. The upper-triangular matrix.
""" """
return QRFull(mode)(a)
x = [[2, 1], [3, 4]]
if isinstance(np.linalg.qr(x, mode), tuple):
return QRFull(mode)(a)
else:
return QRIncomplete(mode)(a)
class SVD(Op): class SVD(Op):
......
...@@ -22,15 +22,7 @@ from aesara.gpuarray.linalg import ( ...@@ -22,15 +22,7 @@ from aesara.gpuarray.linalg import (
gpu_solve_lower_triangular, gpu_solve_lower_triangular,
gpu_svd, gpu_svd,
) )
from aesara.tensor.nlinalg import ( from aesara.tensor.nlinalg import SVD, MatrixInverse, QRFull, eigh, matrix_inverse, qr
SVD,
MatrixInverse,
QRFull,
QRIncomplete,
eigh,
matrix_inverse,
qr,
)
from aesara.tensor.slinalg import Cholesky, cholesky, imported_scipy from aesara.tensor.slinalg import Cholesky, cholesky, imported_scipy
from aesara.tensor.type import fmatrix, matrix, tensor3, vector from aesara.tensor.type import fmatrix, matrix, tensor3, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -380,7 +372,6 @@ class TestMagma: ...@@ -380,7 +372,6 @@ class TestMagma:
(MatrixInverse(), GpuMagmaMatrixInverse), (MatrixInverse(), GpuMagmaMatrixInverse),
(SVD(), GpuMagmaSVD), (SVD(), GpuMagmaSVD),
(QRFull(mode="reduced"), GpuMagmaQR), (QRFull(mode="reduced"), GpuMagmaQR),
(QRIncomplete(mode="r"), GpuMagmaQR),
# TODO: add support for float16 to Eigh numpy # TODO: add support for float16 to Eigh numpy
# (Eigh(), GpuMagmaEigh), # (Eigh(), GpuMagmaEigh),
(Cholesky(), GpuMagmaCholesky), (Cholesky(), GpuMagmaCholesky),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论