提交 1a0d12d8 authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Jesse Grabowski

Implement `Blockwise` for `SVD`

上级 ee7c9469
......@@ -523,6 +523,7 @@ def qr(a, mode="reduced"):
class SVD(Op):
"""
Computes singular value decomposition of matrix A, into U, S, V such that A = U @ S @ V
Parameters
----------
......@@ -543,13 +544,23 @@ class SVD(Op):
def __init__(self, full_matrices: bool = True, compute_uv: bool = True):
self.full_matrices = bool(full_matrices)
self.compute_uv = bool(compute_uv)
if self.compute_uv:
if self.full_matrices:
self.gufunc_signature = "(m,n)->(m,m),(k),(n,n)"
else:
self.gufunc_signature = "(m,n)->(m,k),(k),(k,n)"
else:
self.gufunc_signature = "(m,n)->(k)"
def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2, "The input of svd function should be a matrix."
in_dtype = x.type.numpy_dtype
out_dtype = np.dtype(f"f{in_dtype.itemsize}")
if in_dtype.name.startswith("int"):
out_dtype = np.dtype(f"f{in_dtype.itemsize}")
else:
out_dtype = in_dtype
s = vector(dtype=out_dtype)
......@@ -603,7 +614,7 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True):
U, V, D : matrices
"""
return SVD(full_matrices, compute_uv)(a)
return Blockwise(SVD(full_matrices, compute_uv))(a)
class Lstsq(Op):
......
from functools import partial
import numpy as np
import numpy.linalg
import pytest
......@@ -34,6 +36,7 @@ from pytensor.tensor.type import (
lscalar,
matrix,
scalar,
tensor,
tensor3,
tensor4,
vector,
......@@ -150,29 +153,52 @@ def test_qr_modes():
class TestSvd(utt.InferShapeTester):
op_class = SVD
dtype = "float32"
def setup_method(self):
super().setup_method()
self.rng = np.random.default_rng(utt.fetch_seed())
self.A = matrix(dtype=self.dtype)
self.A = matrix(dtype=config.floatX)
self.op = svd
def test_svd(self):
A = matrix("A", dtype=self.dtype)
U, S, VT = svd(A)
fn = function([A], [U, S, VT])
a = self.rng.random((4, 4)).astype(self.dtype)
n_u, n_s, n_vt = np.linalg.svd(a)
t_u, t_s, t_vt = fn(a)
@pytest.mark.parametrize(
"core_shape", [(3, 3), (4, 3), (3, 4)], ids=["square", "tall", "wide"]
)
@pytest.mark.parametrize(
"full_matrix", [True, False], ids=["full=True", "full=False"]
)
@pytest.mark.parametrize(
"compute_uv", [True, False], ids=["compute_uv=True", "compute_uv=False"]
)
@pytest.mark.parametrize(
"batched", [True, False], ids=["batched=True", "batched=False"]
)
@pytest.mark.parametrize(
"test_imag", [True, False], ids=["test_imag=True", "test_imag=False"]
)
def test_svd(self, core_shape, full_matrix, compute_uv, batched, test_imag):
dtype = config.floatX
if test_imag:
dtype = "complex128" if dtype.endswith("64") else "complex64"
shape = core_shape if not batched else (10, *core_shape)
A = tensor("A", shape=shape, dtype=dtype)
a = self.rng.random(shape).astype(dtype)
outputs = svd(A, compute_uv=compute_uv, full_matrices=full_matrix)
outputs = outputs if isinstance(outputs, list) else [outputs]
fn = function(inputs=[A], outputs=outputs)
np_fn = np.vectorize(
partial(np.linalg.svd, compute_uv=compute_uv, full_matrices=full_matrix),
signature=outputs[0].owner.op.core_op.gufunc_signature,
)
np_outputs = np_fn(a)
pt_outputs = fn(a)
assert _allclose(n_u, t_u)
assert _allclose(n_s, t_s)
assert _allclose(n_vt, t_vt)
np_outputs = np_outputs if isinstance(np_outputs, tuple) else [np_outputs]
fn = function([A], svd(A, compute_uv=False))
t_s = fn(a)
assert _allclose(n_s, t_s)
for np_val, pt_val in zip(np_outputs, pt_outputs):
assert _allclose(np_val, pt_val)
def test_svd_infer_shape(self):
self.validate_shape((4, 4), full_matrices=True, compute_uv=True)
......@@ -183,7 +209,7 @@ class TestSvd(utt.InferShapeTester):
def validate_shape(self, shape, compute_uv=True, full_matrices=True):
A = self.A
A_v = self.rng.random(shape).astype(self.dtype)
A_v = self.rng.random(shape).astype(config.floatX)
outputs = self.op(A, full_matrices=full_matrices, compute_uv=compute_uv)
if not compute_uv:
outputs = [outputs]
......@@ -451,8 +477,8 @@ class TestNormTests:
norm(3, None)
def test_tensor_input(self):
with pytest.raises(NotImplementedError):
norm(np.random.random((3, 4, 5)), None)
res = norm(np.random.random((3, 4, 5)), None)
assert res.shape.eval() == (3,)
def test_numpy_compare(self):
rng = np.random.default_rng(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论