Unverified 提交 ab5037e9 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Add more MLX dispatches (#1684)

* Add mlx extra ops
* Add log_softmax dispatch for mlx
* Fix split bug, add tests
* Feedback
上级 17c675a2
......@@ -10,4 +10,6 @@ import pytensor.link.mlx.dispatch.core
import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
import pytensor.link.mlx.dispatch.extra_ops
import pytensor.link.mlx.dispatch.sort
# isort: on
......@@ -61,14 +61,17 @@ def mlx_funcify_Split(op: Split, node, **kwargs):
# Resolve constants for significant performance improvement (14x speedup)
if constant_axis is not None:
axis = int(constant_axis)
else:
raise ValueError(
"Symbolic axis is not supported in MLX Split implementation."
)
if constant_splits is not None:
splits = constant_splits
cumsum_splits = np.cumsum(splits[:-1])
splits_arr = mx.array(constant_splits)
else:
# Dynamic case - use MLX operations
splits_arr = mx.array(splits)
cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist()
cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist()
# Validation checks
if len(splits) != op.len_splits:
......
from functools import singledispatch
import mlx.core as mx
import mlx.nn as mlx_nn
import numpy as np
from pytensor.link.mlx.dispatch.basic import mlx_funcify
......@@ -40,7 +41,7 @@ from pytensor.scalar.basic import (
)
from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.special import Softmax, SoftmaxGrad
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@mlx_funcify.register(DimShuffle)
......@@ -142,6 +143,16 @@ def mlx_funcify_SoftmaxGrad(op, **kwargs):
return softmax_grad
@mlx_funcify.register(LogSoftmax)
def mlx_funcify_LogSoftmax(op, **kwargs):
    """Return the MLX implementation of a ``LogSoftmax`` op.

    The axis is read from ``op.axis`` once, at dispatch time, and captured
    by the returned closure.
    """
    softmax_axis = op.axis

    def log_softmax(x):
        # Delegate to ``mlx.nn.log_softmax`` along the captured axis.
        return mlx_nn.log_softmax(x, axis=softmax_axis)

    return log_softmax
@mlx_funcify.register(Softplus)
def mlx_funcify_Softplus(op, **kwargs):
def softplus(x):
......
import mlx.core as mx
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.extra_ops import CumOp, Repeat
@mlx_funcify.register(CumOp)
def mlx_funcify_CumOp(op, **kwargs):
    """Return the MLX implementation of a cumulative ``CumOp``.

    ``op.axis`` and ``op.mode`` are bound as default arguments of the
    returned function. Mode ``"add"`` maps to ``mx.cumsum`` and ``"mul"`` to
    ``mx.cumprod``; any other mode raises ``NotImplementedError`` when the
    returned function is called.
    """

    def cumop(x, axis=op.axis, mode=op.mode):
        if mode == "add":
            return mx.cumsum(x, axis=axis)
        if mode == "mul":
            return mx.cumprod(x, axis=axis)
        # Unsupported modes are rejected at call time, not at dispatch time.
        raise NotImplementedError(f"CumOp mode {mode} not implemented in MLX")

    return cumop
@mlx_funcify.register(Repeat)
def mlx_funcify_Repeat(op, **kwargs):
    """Return the MLX implementation of a ``Repeat`` op.

    The repetition axis is read from ``op.axis`` and bound as a default
    argument of the returned function.

    The returned function raises ``NotImplementedError`` when ``repeats`` is
    not a Python ``int``; only a single scalar repeat count is forwarded to
    ``mx.repeat``.
    """
    axis = op.axis

    def repeat(x, repeats, axis=axis):
        if not isinstance(repeats, int):
            raise NotImplementedError(
                "MLX repeat does not support sequence-valued repeat argument."
            )
        return mx.repeat(x, repeats, axis=axis)

    return repeat


# The dispatcher used to carry a copy-pasted ``jax_`` prefix; keep the old
# name importable so nothing that referenced it breaks.
jax_funcify_Repeat = mlx_funcify_Repeat
import warnings
import mlx.core as mx
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.sort import ArgSortOp, SortOp
@mlx_funcify.register(SortOp)
def mlx_funcify_Sort(op, **kwargs):
    """Return the MLX implementation of a ``SortOp``.

    ``op.kind`` is never passed on to ``mx.sort``. When it is anything other
    than ``"quicksort"``, a ``UserWarning`` is emitted at dispatch time to
    say the argument is ignored.
    """
    requested_kind = op.kind
    if requested_kind != "quicksort":
        warnings.warn(
            message=(
                "MLX sort does not support the kind argument "
                f"(got kind={requested_kind}). The argument will be ignored."
            ),
            category=UserWarning,
        )

    def sort(x, axis):
        return mx.sort(x, axis=axis)

    return sort
@mlx_funcify.register(ArgSortOp)
def mlx_funcify_ArgSort(op, **kwargs):
    """Return the MLX implementation of an ``ArgSortOp``.

    ``op.kind`` is never passed on to ``mx.argsort``. When it is anything
    other than ``"quicksort"``, a ``UserWarning`` is emitted at dispatch time
    to say the argument is ignored.
    """
    requested_kind = op.kind
    if requested_kind != "quicksort":
        warnings.warn(
            message=(
                "MLX argsort does not support the kind argument "
                f"(got kind={requested_kind}). The argument will be ignored."
            ),
            category=UserWarning,
        )

    def argsort(x, axis):
        return mx.argsort(x, axis=axis)

    return argsort
......@@ -2,9 +2,15 @@ import numpy as np
import pytest
import pytensor
from pytensor import config
from pytensor import tensor as pt
from pytensor.tensor.basic import Alloc
from tests.link.mlx.test_basic import compile_mode, mlx_mode_no_compile, mx
from tests.link.mlx.test_basic import (
compare_mlx_and_py,
compile_mode,
mlx_mode_no_compile,
mx,
)
def test_alloc_with_different_shape_types():
......@@ -137,3 +143,24 @@ def test_empty_dynamic_shape():
"used inside compiled functions",
):
f_compiled(3, 4)
def test_split_const_axis_const_splits_compiled():
    """Split a vector with a constant axis and constant split sizes."""
    vec = pt.vector("x")
    split_sizes = [2, 3]
    pieces = pt.split(vec, split_sizes, len(split_sizes), axis=0)

    data = np.arange(5, dtype="float32")
    compare_mlx_and_py([vec], pieces, [data])
def test_split_dynamic_axis_const_splits():
    """A symbolic split axis is expected to raise ``ValueError``."""
    mat = pt.matrix("x")
    symbolic_axis = pt.scalar("axis", dtype="int64")
    split_sizes = [1, 2, 3]
    pieces = pt.split(mat, split_sizes, len(split_sizes), axis=symbolic_axis)

    test_input = np.arange(12).astype(config.floatX).reshape(2, 6)
    expected_message = "Symbolic axis is not supported in MLX Split implementation"

    with pytest.raises(ValueError, match=expected_message):
        compare_mlx_and_py(
            [mat, symbolic_axis], pieces, [test_input, np.array(1)]
        )
......@@ -30,7 +30,7 @@ from pytensor.tensor.math import any as pt_any
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import min as pt_min
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.special import SoftmaxGrad, softmax
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, vector, vectors
from tests.link.mlx.test_basic import compare_mlx_and_py
......@@ -97,6 +97,15 @@ def test_softmax_grad(axis):
compare_mlx_and_py([dy, sm], [out], [dy_test_value, sm_test_value])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis):
    """Check ``log_softmax`` on a 2x3 matrix for each parametrized axis."""
    inp = matrix("x")
    graph = log_softmax(inp, axis=axis)

    values = np.arange(6, dtype=config.floatX).reshape(2, 3)
    compare_mlx_and_py([inp], [graph], [values])
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
@pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(size, axis, benchmark):
......
import numpy as np
import pytest
from pytensor.configdefaults import config
from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.type import matrix
from tests.link.mlx.test_basic import compare_mlx_and_py
# Skip this module's tests when ``mlx.core`` cannot be imported.
mx = pytest.importorskip("mlx.core")
def test_extra_ops():
    """Run cumsum, cumprod and scalar repeat on a 3x2 matrix, in that order."""
    a = matrix("a")
    a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))

    # Each graph is built right before it is compared, one case at a time.
    graph_builders = (
        lambda t: pt_extra_ops.cumsum(t, axis=0),
        lambda t: pt_extra_ops.cumprod(t, axis=1),
        lambda t: pt_extra_ops.repeat(t, 3, axis=1),
    )
    for build_graph in graph_builders:
        compare_mlx_and_py([a], [build_graph(a)], [a_test])
import numpy as np
import pytest
from pytensor.tensor.sort import argsort, sort
from pytensor.tensor.type import matrix
from tests.link.mlx.test_basic import compare_mlx_and_py
@pytest.mark.parametrize("axis", [None, -1])
@pytest.mark.parametrize("func", (sort, argsort))
def test_sort(func, axis):
    """Check ``sort`` and ``argsort`` on a fixed 2x2 float64 matrix."""
    inp = matrix("x", shape=(2, 2), dtype="float64")
    graph = func(inp, axis=axis)

    values = np.array([[1.0, 4.0], [5.0, 2.0]])
    compare_mlx_and_py([inp], [graph], [values])
def test_sort_invalid_kind_warning():
    """Evaluating a ``mergesort`` sort in MLX mode should emit a ``UserWarning``."""
    inp = matrix("x", shape=(2, 2), dtype="float64")
    sorted_graph = sort(inp, axis=-1, kind="mergesort")

    values = np.array([[3.0, 1.0], [2.0, 4.0]])
    expected_message = "MLX sort does not support the kind argument"

    with pytest.warns(UserWarning, match=expected_message):
        sorted_graph.eval({inp: values}, mode="MLX")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论