Commit 1a0d12d8, authored by jessegrabowski, committed by Jesse Grabowski

Implement `Blockwise` for `SVD`

Parent commit: ee7c9469
...@@ -523,6 +523,7 @@ def qr(a, mode="reduced"): ...@@ -523,6 +523,7 @@ def qr(a, mode="reduced"):
class SVD(Op): class SVD(Op):
""" """
Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V
Parameters Parameters
---------- ----------
...@@ -543,13 +544,23 @@ class SVD(Op): ...@@ -543,13 +544,23 @@ class SVD(Op):
def __init__(self, full_matrices: bool = True, compute_uv: bool = True): def __init__(self, full_matrices: bool = True, compute_uv: bool = True):
self.full_matrices = bool(full_matrices) self.full_matrices = bool(full_matrices)
self.compute_uv = bool(compute_uv) self.compute_uv = bool(compute_uv)
if self.compute_uv:
if self.full_matrices:
self.gufunc_signature = "(m,n)->(m,m),(k),(n,n)"
else:
self.gufunc_signature = "(m,n)->(m,k),(k),(k,n)"
else:
self.gufunc_signature = "(m,n)->(k)"
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
assert x.ndim == 2, "The input of svd function should be a matrix." assert x.ndim == 2, "The input of svd function should be a matrix."
in_dtype = x.type.numpy_dtype in_dtype = x.type.numpy_dtype
out_dtype = np.dtype(f"f{in_dtype.itemsize}") if in_dtype.name.startswith("int"):
out_dtype = np.dtype(f"f{in_dtype.itemsize}")
else:
out_dtype = in_dtype
s = vector(dtype=out_dtype) s = vector(dtype=out_dtype)
...@@ -603,7 +614,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True): ...@@ -603,7 +614,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True):
U, V, D : matrices U, V, D : matrices
""" """
return SVD(full_matrices, compute_uv)(a) return Blockwise(SVD(full_matrices, compute_uv))(a)
class Lstsq(Op): class Lstsq(Op):
......
from functools import partial
import numpy as np import numpy as np
import numpy.linalg import numpy.linalg
import pytest import pytest
...@@ -34,6 +36,7 @@ from pytensor.tensor.type import ( ...@@ -34,6 +36,7 @@ from pytensor.tensor.type import (
lscalar, lscalar,
matrix, matrix,
scalar, scalar,
tensor,
tensor3, tensor3,
tensor4, tensor4,
vector, vector,
...@@ -150,29 +153,52 @@ def test_qr_modes(): ...@@ -150,29 +153,52 @@ def test_qr_modes():
class TestSvd(utt.InferShapeTester): class TestSvd(utt.InferShapeTester):
op_class = SVD op_class = SVD
dtype = "float32"
def setup_method(self): def setup_method(self):
super().setup_method() super().setup_method()
self.rng = np.random.default_rng(utt.fetch_seed()) self.rng = np.random.default_rng(utt.fetch_seed())
self.A = matrix(dtype=self.dtype) self.A = matrix(dtype=config.floatX)
self.op = svd self.op = svd
def test_svd(self): @pytest.mark.parametrize(
A = matrix("A", dtype=self.dtype) "core_shape", [(3, 3), (4, 3), (3, 4)], ids=["square", "tall", "wide"]
U, S, VT = svd(A) )
fn = function([A], [U, S, VT]) @pytest.mark.parametrize(
a = self.rng.random((4, 4)).astype(self.dtype) "full_matrix", [True, False], ids=["full=True", "full=False"]
n_u, n_s, n_vt = np.linalg.svd(a) )
t_u, t_s, t_vt = fn(a) @pytest.mark.parametrize(
"compute_uv", [True, False], ids=["compute_uv=True", "compute_uv=False"]
)
@pytest.mark.parametrize(
"batched", [True, False], ids=["batched=True", "batched=False"]
)
@pytest.mark.parametrize(
"test_imag", [True, False], ids=["test_imag=True", "test_imag=False"]
)
def test_svd(self, core_shape, full_matrix, compute_uv, batched, test_imag):
dtype = config.floatX
if test_imag:
dtype = "complex128" if dtype.endswith("64") else "complex64"
shape = core_shape if not batched else (10, *core_shape)
A = tensor("A", shape=shape, dtype=dtype)
a = self.rng.random(shape).astype(dtype)
outputs = svd(A, compute_uv=compute_uv, full_matrices=full_matrix)
outputs = outputs if isinstance(outputs, list) else [outputs]
fn = function(inputs=[A], outputs=outputs)
np_fn = np.vectorize(
partial(np.linalg.svd, compute_uv=compute_uv, full_matrices=full_matrix),
signature=outputs[0].owner.op.core_op.gufunc_signature,
)
np_outputs = np_fn(a)
pt_outputs = fn(a)
assert _allclose(n_u, t_u) np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs]
assert _allclose(n_s, t_s)
assert _allclose(n_vt, t_vt)
fn = function([A], svd(A, compute_uv=False)) for np_val, pt_val in zip(np_outputs, pt_outputs):
t_s = fn(a) assert _allclose(np_val, pt_val)
assert _allclose(n_s, t_s)
def test_svd_infer_shape(self): def test_svd_infer_shape(self):
self.validate_shape((4, 4), full_matrices=True, compute_uv=True) self.validate_shape((4, 4), full_matrices=True, compute_uv=True)
...@@ -183,7 +209,7 @@ class TestSvd(utt.InferShapeTester): ...@@ -183,7 +209,7 @@ class TestSvd(utt.InferShapeTester):
def validate_shape(self, shape, compute_uv=True, full_matrices=True): def validate_shape(self, shape, compute_uv=True, full_matrices=True):
A = self.A A = self.A
A_v = self.rng.random(shape).astype(self.dtype) A_v = self.rng.random(shape).astype(config.floatX)
outputs = self.op(A, full_matrices=full_matrices, compute_uv=compute_uv) outputs = self.op(A, full_matrices=full_matrices, compute_uv=compute_uv)
if not compute_uv: if not compute_uv:
outputs = [outputs] outputs = [outputs]
...@@ -451,8 +477,8 @@ class TestNormTests: ...@@ -451,8 +477,8 @@ class TestNormTests:
norm(3, None) norm(3, None)
def test_tensor_input(self): def test_tensor_input(self):
with pytest.raises(NotImplementedError): res = norm(np.random.random((3, 4, 5)), None)
norm(np.random.random((3, 4, 5)), None) assert res.shape.eval() == (3,)
def test_numpy_compare(self): def test_numpy_compare(self):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
......
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment