Unverified commit 1f9a67bc, authored by Carlos Trujillo, committed by GitHub

Add linalg Ops to MLX backend (#1700)

上级 f83c05ba
......@@ -12,4 +12,6 @@ import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
import pytensor.link.mlx.dispatch.extra_ops
import pytensor.link.mlx.dispatch.sort
import pytensor.link.mlx.dispatch.slinalg
import pytensor.link.mlx.dispatch.nlinalg
# isort: on
import mlx.core as mx
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.nlinalg import SVD, KroneckerProduct, MatrixInverse, MatrixPinv
@mlx_funcify.register(SVD)
def mlx_funcify_SVD(op, node, **kwargs):
    """Return an MLX implementation of the SVD Op.

    Only ``full_matrices=True`` is supported; the decomposition is
    forced onto the CPU stream (mx.linalg.svd is CPU-only — see usage
    below).
    """
    in_dtype = getattr(mx, node.inputs[0].dtype)

    if not op.full_matrices:
        raise TypeError("full_matrices=False is not supported in the mlx backend.")

    if op.compute_uv:

        def svd(x):
            # Full decomposition: returns (U, S, Vt).
            x_cpu = x.astype(dtype=in_dtype, stream=mx.cpu)
            return mx.linalg.svd(x_cpu, compute_uv=True, stream=mx.cpu)

    else:

        def svd(x):
            # Singular values only.
            x_cpu = x.astype(dtype=in_dtype, stream=mx.cpu)
            return mx.linalg.svd(x_cpu, compute_uv=False, stream=mx.cpu)

    return svd
@mlx_funcify.register(KroneckerProduct)
def mlx_funcify_KroneckerProduct(op, node, **kwargs):
    """Return an MLX implementation of the Kronecker product.

    The stream is chosen from the output dtype — presumably because
    float64 kernels are CPU-only in MLX (TODO confirm); everything else
    runs on the GPU stream.
    """
    device = mx.cpu if node.outputs[0].dtype == "float64" else mx.gpu
    a_dtype = getattr(mx, node.inputs[0].dtype)
    b_dtype = getattr(mx, node.inputs[1].dtype)

    def kron(a, b):
        a_cast = a.astype(dtype=a_dtype, stream=device)
        b_cast = b.astype(dtype=b_dtype, stream=device)
        return mx.kron(a_cast, b_cast, stream=device)

    return kron
@mlx_funcify.register(MatrixInverse)
def mlx_funcify_MatrixInverse(op, node, **kwargs):
    """Return an MLX implementation of matrix inversion (run on the CPU stream)."""
    in_dtype = getattr(mx, node.inputs[0].dtype)

    def inv(x):
        x_cpu = x.astype(dtype=in_dtype, stream=mx.cpu)
        return mx.linalg.inv(x_cpu, stream=mx.cpu)

    return inv
@mlx_funcify.register(MatrixPinv)
def mlx_funcify_MatrixPinv(op, node, **kwargs):
    """Return an MLX implementation of the Moore-Penrose pseudo-inverse (CPU stream)."""
    in_dtype = getattr(mx, node.inputs[0].dtype)

    def pinv(x):
        x_cpu = x.astype(dtype=in_dtype, stream=mx.cpu)
        return mx.linalg.pinv(x_cpu, stream=mx.cpu)

    return pinv
import warnings
import mlx.core as mx
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.slinalg import LU, Cholesky, Solve, SolveTriangular
@mlx_funcify.register(Cholesky)
def mlx_funcify_Cholesky(op, node, **kwargs):
    """Return an MLX Cholesky factorization (run on the CPU stream).

    MLX exposes an ``upper`` flag while the PyTensor Op carries ``lower``,
    hence the negation at dispatch time.
    """
    upper = not op.lower
    in_dtype = getattr(mx, node.inputs[0].dtype)

    def cholesky(a):
        a_cpu = a.astype(dtype=in_dtype, stream=mx.cpu)
        return mx.linalg.cholesky(a_cpu, upper=upper, stream=mx.cpu)

    return cholesky
@mlx_funcify.register(Solve)
def mlx_funcify_Solve(op, node, **kwargs):
    """Return an MLX implementation of Solve.

    MLX only provides a general solver, so any other ``assume_a`` value
    falls back to the generic routine with a UserWarning.
    """
    a_dtype = getattr(mx, node.inputs[0].dtype)
    b_dtype = getattr(mx, node.inputs[1].dtype)

    if op.assume_a != "gen":
        warnings.warn(
            f"MLX solve does not support assume_a={op.assume_a}. Defaulting to assume_a='gen'.",
            UserWarning,
        )

    def solve(a, b):
        # MLX only supports solve on CPU
        a_cpu = a.astype(stream=mx.cpu, dtype=a_dtype)
        b_cpu = b.astype(stream=mx.cpu, dtype=b_dtype)
        return mx.linalg.solve(a_cpu, b_cpu, stream=mx.cpu)

    return solve
@mlx_funcify.register(SolveTriangular)
def mlx_funcify_SolveTriangular(op, node, **kwargs):
    """Return an MLX triangular solve (run on the CPU stream).

    MLX takes an ``upper`` flag while the PyTensor Op carries ``lower``,
    hence the negation at dispatch time.
    """
    upper = not op.lower
    a_dtype = getattr(mx, node.inputs[0].dtype)
    b_dtype = getattr(mx, node.inputs[1].dtype)

    def solve_triangular(A, b):
        # MLX only supports solve_triangular on CPU
        a_cpu = A.astype(stream=mx.cpu, dtype=a_dtype)
        b_cpu = b.astype(stream=mx.cpu, dtype=b_dtype)
        return mx.linalg.solve_triangular(a_cpu, b_cpu, upper=upper, stream=mx.cpu)

    return solve_triangular
@mlx_funcify.register(LU)
def mlx_funcify_LU(op, node, **kwargs):
    """Return an MLX LU factorization (run on the CPU stream).

    Only the ``permute_l=False, p_indices=True`` configuration is
    supported: MLX hands back permutation indices, not a permutation
    matrix, so the other variants are rejected at dispatch time.
    """
    a_dtype = getattr(mx, node.inputs[0].dtype)

    if op.permute_l:
        raise ValueError("permute_l=True is not supported in the mlx backend.")
    if not op.p_indices:
        raise ValueError("p_indices=False is not supported in the mlx backend.")

    def lu(a):
        a_cpu = a.astype(dtype=a_dtype, stream=mx.cpu)
        perm, lower_tri, upper_tri = mx.linalg.lu(a_cpu, stream=mx.cpu)
        # Cast permutation indices to int32 to match the Op's output signature.
        return perm.astype(mx.int32, stream=mx.cpu), lower_tri, upper_tri

    return lu
from functools import partial
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor import config
from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode
@pytest.mark.parametrize("compute_uv", [True, False])
def test_mlx_svd(compute_uv):
    """SVD agrees with the Python backend, with and without U/V."""
    rng = np.random.default_rng(15)
    X = pt.matrix(name="X")
    x_np = rng.normal(size=(3, 3)).astype(config.floatX)
    x_np = x_np @ x_np.T  # symmetric input
    outputs = pt.linalg.svd(X, compute_uv=compute_uv)
    compare_mlx_and_py(
        [X],
        outputs,
        [x_np],
        mlx_mode=mlx_mode,
        assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
    )
def test_mlx_kron():
    """Kronecker product agrees with the Python backend."""
    rng = np.random.default_rng(15)
    A = pt.matrix(name="A")
    B = pt.matrix(name="B")
    a_np, b_np = rng.normal(scale=0.1, size=(2, 3, 3)).astype(config.floatX)
    result = pt.linalg.kron(A, B)
    compare_mlx_and_py(
        [A, B],
        [result],
        [a_np, b_np],
        mlx_mode=mlx_mode,
        assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
    )
@pytest.mark.parametrize("op", [pt.linalg.inv, pt.linalg.pinv], ids=["inv", "pinv"])
def test_mlx_inv(op):
    """inv and pinv agree with the Python backend on an SPD matrix."""
    rng = np.random.default_rng(15)
    dim = 3
    A = pt.matrix(name="A")
    a_np = rng.normal(size=(dim, dim))
    a_np = (a_np @ a_np.T).astype(config.floatX)  # SPD, hence invertible
    compare_mlx_and_py(
        [A],
        [op(A)],
        [a_np],
        mlx_mode=mlx_mode,
        assert_fn=partial(
            np.testing.assert_allclose, atol=1e-6, rtol=1e-6, strict=True
        ),
    )
import contextlib
from functools import partial
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor import config
from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode
@pytest.mark.parametrize("lower", [True, False])
def test_mlx_cholesky(lower):
    """Cholesky agrees with the Python backend for both triangles."""
    rng = np.random.default_rng(15)
    dim = 3
    A = pt.tensor("A", shape=(dim, dim))
    a_np = rng.normal(size=(dim, dim))
    a_np = (a_np @ a_np.T).astype(config.floatX)  # SPD so the factorization exists
    factor = pt.linalg.cholesky(A, lower=lower)
    compare_mlx_and_py(
        [A],
        [factor],
        [a_np],
        mlx_mode=mlx_mode,
        assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
    )
@pytest.mark.parametrize("assume_a", ["gen", "pos"])
def test_mlx_solve(assume_a):
    """Solve agrees with the Python backend; non-"gen" assume_a warns and falls back.

    The MLX dispatcher only implements a general solver, so assume_a="pos"
    must emit a UserWarning while still producing correct results.
    """
    rng = np.random.default_rng(15)
    n = 3
    A = pt.tensor("A", shape=(n, n))
    b = pt.tensor("B", shape=(n, n))
    out = pt.linalg.solve(A, b, b_ndim=2, assume_a=assume_a)

    A_val = rng.normal(size=(n, n)).astype(config.floatX)
    A_val = A_val @ A_val.T  # SPD, so assume_a="pos" is a valid hint
    b_val = rng.normal(size=(n, n)).astype(config.floatX)

    # nullcontext() is the idiomatic no-op context manager; suppress() with no
    # exception types happens to work but obscures the intent.
    context = (
        contextlib.nullcontext()
        if assume_a == "gen"
        else pytest.warns(
            UserWarning, match=f"MLX solve does not support assume_a={assume_a}"
        )
    )
    with context:
        compare_mlx_and_py(
            [A, b],
            [out],
            [A_val, b_val],
            mlx_mode=mlx_mode,
            assert_fn=partial(
                np.testing.assert_allclose, atol=1e-6, rtol=1e-6, strict=True
            ),
        )
@pytest.mark.parametrize("lower, trans", [(False, False), (True, True)])
def test_mlx_SolveTriangular(lower, trans):
    """solve_triangular agrees with the Python backend for both triangle/transpose combos.

    Bug fix: the parametrized ``trans`` value was previously ignored —
    ``trans=0`` was hard-coded in the call — so the transposed case was
    never actually exercised. Pass the parameter through.
    """
    rng = np.random.default_rng(15)
    A = pt.tensor("A", shape=(5, 5))
    b = pt.tensor("B", shape=(5, 5))
    A_val = rng.normal(size=(5, 5)).astype(config.floatX)
    b_val = rng.normal(size=(5, 5)).astype(config.floatX)
    out = pt.linalg.solve_triangular(
        A,
        b,
        trans=trans,  # was hard-coded to 0, leaving the parametrization unused
        lower=lower,
        unit_diagonal=False,
    )
    compare_mlx_and_py(
        [A, b],
        [out],
        [A_val, b_val],
        mlx_mode=mlx_mode,
        assert_fn=partial(
            np.testing.assert_allclose, atol=1e-6, rtol=1e-6, strict=True
        ),
    )
def test_mlx_LU():
    """LU factorization agrees with the Python backend in the supported config."""
    rng = np.random.default_rng(15)
    A = pt.tensor("A", shape=(5, 5))
    # The MLX backend only supports permute_l=False with p_indices=True.
    factors = pt.linalg.lu(A, permute_l=False, p_indices=True)
    a_np = rng.normal(size=(5, 5)).astype(config.floatX)
    compare_mlx_and_py(
        [A],
        factors,
        [a_np],
        mlx_mode=mlx_mode,
        assert_fn=partial(np.testing.assert_allclose, atol=1e-6, strict=True),
    )
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment