Unverified 提交 ab5037e9 authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Add more MLX dispatches (#1684)

* Add mlx extra ops
* Add log_softmax dispatch for mlx
* Fix split bug, add tests
* Feedback
上级 17c675a2
...@@ -10,4 +10,6 @@ import pytensor.link.mlx.dispatch.core ...@@ -10,4 +10,6 @@ import pytensor.link.mlx.dispatch.core
import pytensor.link.mlx.dispatch.signal import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise import pytensor.link.mlx.dispatch.blockwise
import pytensor.link.mlx.dispatch.extra_ops
import pytensor.link.mlx.dispatch.sort
# isort: on # isort: on
...@@ -61,14 +61,17 @@ def mlx_funcify_Split(op: Split, node, **kwargs): ...@@ -61,14 +61,17 @@ def mlx_funcify_Split(op: Split, node, **kwargs):
# Resolve constants for significant performance improvement (14x speedup) # Resolve constants for significant performance improvement (14x speedup)
if constant_axis is not None: if constant_axis is not None:
axis = int(constant_axis) axis = int(constant_axis)
else:
raise ValueError(
"Symbolic axis is not supported in MLX Split implementation."
)
if constant_splits is not None: if constant_splits is not None:
splits = constant_splits splits_arr = mx.array(constant_splits)
cumsum_splits = np.cumsum(splits[:-1])
else: else:
# Dynamic case - use MLX operations
splits_arr = mx.array(splits) splits_arr = mx.array(splits)
cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist()
cumsum_splits = mx.cumsum(splits_arr[:-1]).tolist()
# Validation checks # Validation checks
if len(splits) != op.len_splits: if len(splits) != op.len_splits:
......
from functools import singledispatch from functools import singledispatch
import mlx.core as mx import mlx.core as mx
import mlx.nn as mlx_nn
import numpy as np import numpy as np
from pytensor.link.mlx.dispatch.basic import mlx_funcify from pytensor.link.mlx.dispatch.basic import mlx_funcify
...@@ -40,7 +41,7 @@ from pytensor.scalar.basic import ( ...@@ -40,7 +41,7 @@ from pytensor.scalar.basic import (
) )
from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.special import Softmax, SoftmaxGrad from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@mlx_funcify.register(DimShuffle) @mlx_funcify.register(DimShuffle)
...@@ -142,6 +143,16 @@ def mlx_funcify_SoftmaxGrad(op, **kwargs): ...@@ -142,6 +143,16 @@ def mlx_funcify_SoftmaxGrad(op, **kwargs):
return softmax_grad return softmax_grad
@mlx_funcify.register(LogSoftmax)
def mlx_funcify_LogSoftmax(op, **kwargs):
    """Return the MLX callable implementing a ``LogSoftmax`` op.

    ``op.axis`` is read once when the op is dispatched and captured by the
    returned closure, which forwards it to ``mlx_nn.log_softmax``.
    """
    reduce_axis = op.axis

    def log_softmax(x):
        # The axis is fixed by the op; only the data varies per call.
        return mlx_nn.log_softmax(x, axis=reduce_axis)

    return log_softmax
@mlx_funcify.register(Softplus) @mlx_funcify.register(Softplus)
def mlx_funcify_Softplus(op, **kwargs): def mlx_funcify_Softplus(op, **kwargs):
def softplus(x): def softplus(x):
......
import mlx.core as mx
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.extra_ops import CumOp, Repeat
@mlx_funcify.register(CumOp)
def mlx_funcify_CumOp(op, **kwargs):
    """Return the MLX callable implementing a ``CumOp``.

    The op's ``axis`` and ``mode`` become the defaults of the returned
    function. Mode ``"add"`` maps to ``mx.cumsum`` and ``"mul"`` to
    ``mx.cumprod``; any other mode raises ``NotImplementedError`` when the
    returned function is called.
    """
    op_axis, op_mode = op.axis, op.mode

    def cumop(x, axis=op_axis, mode=op_mode):
        if mode == "add":
            return mx.cumsum(x, axis=axis)
        if mode == "mul":
            return mx.cumprod(x, axis=axis)
        raise NotImplementedError(f"CumOp mode {mode} not implemented in MLX")

    return cumop
# NOTE(review): this function is registered on ``mlx_funcify`` but carries a
# ``jax_`` prefix -- looks like a copy-paste leftover; confirm before renaming.
@mlx_funcify.register(Repeat)
def jax_funcify_Repeat(op, **kwargs):
    """Return the MLX callable implementing a ``Repeat`` op.

    The op's ``axis`` becomes the default ``axis`` of the returned function.
    Only a Python ``int`` is accepted for ``repeats``; anything else raises
    ``NotImplementedError`` at call time.
    """
    op_axis = op.axis

    def repeat(x, repeats, axis=op_axis):
        if isinstance(repeats, int):
            return mx.repeat(x, repeats, axis=axis)
        raise NotImplementedError(
            "MLX repeat does not support sequence-valued repeat argument."
        )

    return repeat
import warnings
import mlx.core as mx
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.sort import ArgSortOp, SortOp
@mlx_funcify.register(SortOp)
def mlx_funcify_Sort(op, **kwargs):
    """Return the MLX callable implementing a ``SortOp``.

    ``op.kind`` is never forwarded to ``mx.sort``. When it is anything other
    than ``"quicksort"`` a ``UserWarning`` is emitted once, at dispatch time,
    to say that the argument is ignored.
    """
    requested_kind = op.kind

    if requested_kind != "quicksort":
        notice = (
            f"MLX sort does not support the kind argument (got kind={requested_kind}). The argument will be "
            f"ignored."
        )
        warnings.warn(notice, UserWarning)

    def sort(x, axis):
        return mx.sort(x, axis=axis)

    return sort
@mlx_funcify.register(ArgSortOp)
def mlx_funcify_ArgSort(op, **kwargs):
    """Return the MLX callable implementing an ``ArgSortOp``.

    ``op.kind`` is never forwarded to ``mx.argsort``. When it is anything
    other than ``"quicksort"`` a ``UserWarning`` is emitted once, at dispatch
    time, to say that the argument is ignored.
    """
    requested_kind = op.kind

    if requested_kind != "quicksort":
        notice = (
            f"MLX argsort does not support the kind argument (got kind={requested_kind}). The argument will be "
            f"ignored."
        )
        warnings.warn(notice, UserWarning)

    def argsort(x, axis):
        return mx.argsort(x, axis=axis)

    return argsort
...@@ -2,9 +2,15 @@ import numpy as np ...@@ -2,9 +2,15 @@ import numpy as np
import pytest import pytest
import pytensor import pytensor
from pytensor import config
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.tensor.basic import Alloc from pytensor.tensor.basic import Alloc
from tests.link.mlx.test_basic import compile_mode, mlx_mode_no_compile, mx from tests.link.mlx.test_basic import (
compare_mlx_and_py,
compile_mode,
mlx_mode_no_compile,
mx,
)
def test_alloc_with_different_shape_types(): def test_alloc_with_different_shape_types():
...@@ -137,3 +143,24 @@ def test_empty_dynamic_shape(): ...@@ -137,3 +143,24 @@ def test_empty_dynamic_shape():
"used inside compiled functions", "used inside compiled functions",
): ):
f_compiled(3, 4) f_compiled(3, 4)
def test_split_const_axis_const_splits_compiled():
    """Split a 5-element vector into pieces of size 2 and 3 along axis 0."""
    vec = pt.vector("x")
    sizes = [2, 3]
    pieces = pt.split(vec, sizes, len(sizes), axis=0)

    data = np.arange(5, dtype="float32")
    compare_mlx_and_py([vec], pieces, [data])
def test_split_dynamic_axis_const_splits():
    """A symbolic split axis must be rejected with a ``ValueError``."""
    mat = pt.matrix("x")
    sym_axis = pt.scalar("axis", dtype="int64")
    sizes = [1, 2, 3]
    pieces = pt.split(mat, sizes, len(sizes), axis=sym_axis)

    data = np.arange(12).astype(config.floatX).reshape(2, 6)
    expected_msg = "Symbolic axis is not supported in MLX Split implementation"

    with pytest.raises(ValueError, match=expected_msg):
        compare_mlx_and_py([mat, sym_axis], pieces, [data, np.array(1)])
...@@ -30,7 +30,7 @@ from pytensor.tensor.math import any as pt_any ...@@ -30,7 +30,7 @@ from pytensor.tensor.math import any as pt_any
from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import min as pt_min from pytensor.tensor.math import min as pt_min
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.special import SoftmaxGrad, softmax from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, vector, vectors from pytensor.tensor.type import matrix, vector, vectors
from tests.link.mlx.test_basic import compare_mlx_and_py from tests.link.mlx.test_basic import compare_mlx_and_py
...@@ -97,6 +97,15 @@ def test_softmax_grad(axis): ...@@ -97,6 +97,15 @@ def test_softmax_grad(axis):
compare_mlx_and_py([dy, sm], [out], [dy_test_value, sm_test_value]) compare_mlx_and_py([dy, sm], [out], [dy_test_value, sm_test_value])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis):
    """Compare ``log_softmax`` on a 2x3 matrix for each parametrized axis."""
    mat = matrix("x")
    data = np.arange(6, dtype=config.floatX).reshape(2, 3)

    graph = log_softmax(mat, axis=axis)
    compare_mlx_and_py([mat], [graph], [data])
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000)]) @pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
@pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(size, axis, benchmark): def test_logsumexp_benchmark(size, axis, benchmark):
......
import numpy as np
import pytest
from pytensor.configdefaults import config
from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.type import matrix
from tests.link.mlx.test_basic import compare_mlx_and_py
mx = pytest.importorskip("mlx.core")
def test_extra_ops():
    """Run cumsum, cumprod and scalar repeat on a 3x2 matrix through MLX."""
    a = matrix("a")
    a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))

    # Each graph is checked on its own, in the order listed.
    graphs = (
        pt_extra_ops.cumsum(a, axis=0),
        pt_extra_ops.cumprod(a, axis=1),
        pt_extra_ops.repeat(a, 3, axis=1),
    )
    for graph in graphs:
        compare_mlx_and_py([a], [graph], [a_test])
import numpy as np
import pytest
from pytensor.tensor.sort import argsort, sort
from pytensor.tensor.type import matrix
from tests.link.mlx.test_basic import compare_mlx_and_py
@pytest.mark.parametrize("axis", [None, -1])
@pytest.mark.parametrize("func", (sort, argsort))
def test_sort(func, axis):
    """Compare ``sort`` / ``argsort`` on a fixed 2x2 float64 matrix."""
    mat = matrix("x", shape=(2, 2), dtype="float64")
    graph = func(mat, axis=axis)

    data = np.array([[1.0, 4.0], [5.0, 2.0]])
    compare_mlx_and_py([mat], [graph], [data])
def test_sort_invalid_kind_warning():
    """Evaluating a non-quicksort ``sort`` in MLX mode emits a ``UserWarning``."""
    mat = matrix("x", shape=(2, 2), dtype="float64")
    sorted_graph = sort(mat, axis=-1, kind="mergesort")
    data = np.array([[3.0, 1.0], [2.0, 4.0]])

    with pytest.warns(UserWarning, match="MLX sort does not support the kind argument"):
        sorted_graph.eval({mat: data}, mode="MLX")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论