Unverified commit 934306f2 authored by Will Dean, committed by GitHub
Parent: 92c3b490
...@@ -81,6 +81,7 @@ jobs: ...@@ -81,6 +81,7 @@ jobs:
install-numba: [0] install-numba: [0]
install-jax: [0] install-jax: [0]
install-torch: [0] install-torch: [0]
install-mlx: [0]
install-xarray: [0] install-xarray: [0]
part: part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor" - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor"
...@@ -106,6 +107,7 @@ jobs: ...@@ -106,6 +107,7 @@ jobs:
install-numba: 0 install-numba: 0
install-jax: 0 install-jax: 0
install-torch: 0 install-torch: 0
install-mlx: 0
install-xarray: 0 install-xarray: 0
- install-numba: 1 - install-numba: 1
os: "ubuntu-latest" os: "ubuntu-latest"
...@@ -149,7 +151,16 @@ jobs: ...@@ -149,7 +151,16 @@ jobs:
fast-compile: 0 fast-compile: 0
float32: 0 float32: 0
part: "tests/xtensor" part: "tests/xtensor"
- os: macos-15 - os: "macos-15"
python-version: "3.11"
fast-compile: 0
float32: 0
install-mlx: 1
install-numba: 0
install-jax: 0
install-torch: 0
part: "tests/link/mlx"
- os: "macos-15"
python-version: "3.13" python-version: "3.13"
fast-compile: 0 fast-compile: 0
float32: 0 float32: 0
...@@ -194,6 +205,7 @@ jobs: ...@@ -194,6 +205,7 @@ jobs:
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mlx; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
pip install -e ./ pip install -e ./
...@@ -210,6 +222,7 @@ jobs: ...@@ -210,6 +222,7 @@ jobs:
INSTALL_JAX: ${{ matrix.install-jax }} INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}} INSTALL_TORCH: ${{ matrix.install-torch}}
INSTALL_XARRAY: ${{ matrix.install-xarray }} INSTALL_XARRAY: ${{ matrix.install-xarray }}
INSTALL_MLX: ${{ matrix.install-mlx }}
OS: ${{ matrix.os}} OS: ${{ matrix.os}}
- name: Run tests - name: Run tests
......
...@@ -27,7 +27,6 @@ __pycache__ ...@@ -27,7 +27,6 @@ __pycache__
\#*\# \#*\#
build build
compiled/*.cpp compiled/*.cpp
core.*
cutils_ext.cpp cutils_ext.cpp
dist dist
doc/.build/ doc/.build/
......
...@@ -27,6 +27,7 @@ from pytensor.graph.rewriting.db import ( ...@@ -27,6 +27,7 @@ from pytensor.graph.rewriting.db import (
from pytensor.link.basic import Linker, PerformLinker from pytensor.link.basic import Linker, PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.jax.linker import JAXLinker from pytensor.link.jax.linker import JAXLinker
from pytensor.link.mlx.linker import MLXLinker
from pytensor.link.numba.linker import NumbaLinker from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.pytorch.linker import PytorchLinker from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.link.vm import VMLinker from pytensor.link.vm import VMLinker
...@@ -50,6 +51,7 @@ predefined_linkers = { ...@@ -50,6 +51,7 @@ predefined_linkers = {
"jax": JAXLinker(), "jax": JAXLinker(),
"pytorch": PytorchLinker(), "pytorch": PytorchLinker(),
"numba": NumbaLinker(), "numba": NumbaLinker(),
"mlx": MLXLinker(),
} }
...@@ -504,6 +506,20 @@ PYTORCH = Mode( ...@@ -504,6 +506,20 @@ PYTORCH = Mode(
), ),
) )
MLX = Mode(
MLXLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
],
),
)
predefined_modes = { predefined_modes = {
"FAST_COMPILE": FAST_COMPILE, "FAST_COMPILE": FAST_COMPILE,
...@@ -511,6 +527,7 @@ predefined_modes = { ...@@ -511,6 +527,7 @@ predefined_modes = {
"JAX": JAX, "JAX": JAX,
"NUMBA": NUMBA, "NUMBA": NUMBA,
"PYTORCH": PYTORCH, "PYTORCH": PYTORCH,
"MLX": MLX,
} }
_CACHED_RUNTIME_MODES: dict[str, Mode] = {} _CACHED_RUNTIME_MODES: dict[str, Mode] = {}
......
from pytensor.link.mlx.linker import MLXLinker
# isort: off
from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify
import pytensor.link.mlx.dispatch.math
import pytensor.link.mlx.dispatch.basic
import pytensor.link.mlx.dispatch.elemwise
import pytensor.link.mlx.dispatch.shape
import pytensor.link.mlx.dispatch.subtensor
import pytensor.link.mlx.dispatch.core
import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
# isort: on
import warnings
from copy import deepcopy
from functools import singledispatch
from types import NoneType
import mlx.core as mx
import numpy as np
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import Assert, CheckAndRaise
@singledispatch
def mlx_typify(data, **kwargs):
    """Convert `data` into a value usable by MLX.

    Concrete conversions are supplied through `singledispatch`
    registrations; this fallback rejects unsupported types.
    """
    raise NotImplementedError(f"mlx_typify is not implemented for {type(data)}")
@mlx_typify.register(np.ndarray)
def mlx_typify_tensor(data, dtype=None, **kwargs):
    """Copy a NumPy array into an `mx.array`, optionally casting to `dtype`."""
    return mx.array(data, dtype=dtype)
@mlx_typify.register(slice)
@mlx_typify.register(NoneType)
@mlx_typify.register(mx.array)
def mlx_typify_no_conversion_needed(data, **kwargs):
    """Pass through values that MLX already accepts unchanged."""
    return data
@mlx_typify.register(int)
@mlx_typify.register(float)
def mlx_typify_python_scalar(data, **kwargs):
    """Wrap a Python numeric scalar in a zero-dimensional `mx.array`."""
    return mx.array(data)
@mlx_typify.register(bool)
@mlx_typify.register(np.bool_)
def mlx_typify_bool(data, **kwargs):
    """Return a plain Python `bool` — booleans are not wrapped in `mx.array`.

    NOTE(review): presumably kept as a host-side bool so it can drive
    Python control flow; confirm against callers.
    """
    return bool(data)
@mlx_typify.register(np.integer)
@mlx_typify.register(np.floating)
@mlx_typify.register(np.complexfloating)
def mlx_typify_numpy_scalar(data, **kwargs):
    """Convert a NumPy scalar into a zero-dimensional `mx.array`."""
    return mx.array(data)
@singledispatch
def mlx_funcify(op, node=None, storage_map=None, **kwargs):
    """Create a MLX compatible function from an PyTensor `Op`.

    Concrete `Op` translations register themselves on this dispatcher;
    this fallback rejects anything without an MLX implementation.
    """
    raise NotImplementedError(
        f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request we prioritize this operation"
    )
@mlx_funcify.register(FunctionGraph)
def mlx_funcify_FunctionGraph(
    fgraph,
    node=None,
    fgraph_name="mlx_funcified_fgraph",
    conversion_func=mlx_funcify,
    **kwargs,
):
    """Compile a whole `FunctionGraph` into a Python function of MLX ops."""
    # Forward `conversion_func` to per-Op conversions so nested graphs are
    # converted with the same dispatcher (caller-supplied kwargs win).
    extra_kwargs = dict(kwargs)
    extra_kwargs.setdefault("conversion_func", conversion_func)
    return fgraph_to_python(
        fgraph,
        conversion_func,
        type_conversion_fn=mlx_typify,
        fgraph_name=fgraph_name,
        **extra_kwargs,
    )
@mlx_funcify.register(DeepCopyOp)
def mlx_funcify_DeepCopyOp(op, **kwargs):
    """Return a function that deep-copies its single input."""

    def deepcopyop(x):
        return deepcopy(x)

    return deepcopyop
@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, node, **kwargs):
    """Evaluate constant conditions at build time, then drop the check.

    A condition that is a `Constant` known to be False raises immediately;
    any remaining (non-constant) checks are skipped with a warning because
    MLX tracing would eliminate them anyway.
    """
    for cond in node.inputs[1:]:
        if isinstance(cond, Constant) and not bool(cond.data):
            raise op.exc_type(op.msg)

    warnings.warn(
        f"""Skipping `{type(op).__name__}` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
        stacklevel=2,
    )

    def assert_fn(x, *inputs):
        # The first input passes through untouched; the condition inputs
        # are ignored at runtime.
        return x

    return assert_fn
import mlx.core as mx
from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.blockwise import Blockwise
@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
    """Vectorize the core Op of a `Blockwise` over its batch dimensions."""
    # Build the scalar ("core") implementation for this Blockwise.
    core_node = op._create_dummy_core_node(node.inputs)
    core_f = mlx_funcify(op.core_op, core_node)

    # With no batch dimensions the core function is already correct.
    if op.batch_ndim(node) == 0:
        return core_f

    # Map each input to the axis mx.vmap should iterate: 0 for genuinely
    # batched inputs, None for unbatched or fully-broadcastable ones.
    in_axes = []
    for inp, sig in zip(node.inputs, op.inputs_sig):
        n_batch_dims = inp.type.ndim - len(sig)
        if n_batch_dims == 0:
            in_axes.append(None)
        elif all(inp.type.broadcastable[:n_batch_dims]):
            # Every batch dim is broadcastable (size 1): treat as static.
            in_axes.append(None)
        else:
            in_axes.append(0)

    if 0 not in in_axes:
        return core_f

    return mx.vmap(core_f, in_axes=tuple(in_axes))
Diff collapsed.
Diff collapsed.
import mlx.core as mx
from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.math import Argmax, Dot, Max
@mlx_funcify.register(Dot)
def mlx_funcify_Dot(op, node=None, **kwargs):
    """Return an MLX implementation of `Dot` via `mx.matmul`."""

    def dot(x, y):
        return mx.matmul(x, y)

    return dot
@mlx_funcify.register(Max)
def mlx_funcify_Max(op, node=None, **kwargs):
    """Return an MLX implementation of the `Max` reduction."""

    def max_fn(x):
        # `op.axis` is either None (reduce everything) or an iterable of ints.
        reduce_axes = None if op.axis is None else tuple(int(a) for a in op.axis)
        # NOTE(review): `Max` does not obviously carry a `keepdims`
        # attribute; the getattr default of False preserves behavior.
        return mx.max(x, axis=reduce_axes, keepdims=getattr(op, "keepdims", False))

    return max_fn
@mlx_funcify.register(Argmax)
def mlx_funcify_Argmax(op, node=None, **kwargs):
    """Return an MLX implementation of `Argmax` over (possibly several) axes."""
    axis = op.axis

    def argmax_fn(x):
        # Normalize the reduced axes: all axes when `axis` is None.
        if axis is None:
            axes = tuple(range(x.ndim))
        else:
            axes = tuple(int(ax) for ax in axis)
        # Move kept axes to the front and reduced axes to the back, then
        # flatten the reduced axes into one trailing dimension so a single
        # argmax over axis=-1 covers all of them.
        keep_axes = [i for i in range(x.ndim) if i not in axes]
        transposed_x = mx.transpose(x, tuple(keep_axes + list(axes)))
        kept_shape = transposed_x.shape[: len(keep_axes)]
        reduced_shape = transposed_x.shape[len(keep_axes) :]
        flat_size = 1
        for dim in reduced_shape:
            flat_size *= int(dim)
        reshaped_x = transposed_x.reshape((*kept_shape, flat_size))
        max_idx = mx.argmax(reshaped_x, axis=-1)
        # Indices are flat offsets into the flattened reduced block.
        result = max_idx.astype(mx.int64)
        # NOTE(review): `Argmax` does not obviously define `keepdims`; the
        # getattr default of False makes this branch effectively dormant.
        if getattr(op, "keepdims", False):
            # Rebuild the full-rank shape, inserting a size-1 dimension at
            # every reduced axis and the kept sizes elsewhere, in order.
            reshape_shape = []
            keep_iter = iter(kept_shape)
            axis_iter = iter(sorted(axes))
            next_axis = next(axis_iter, None)
            for dim_idx in range(x.ndim):
                if next_axis is not None and dim_idx == next_axis:
                    reshape_shape.append(1)
                    next_axis = next(axis_iter, None)
                else:
                    reshape_shape.append(int(next(keep_iter)))
            return result.reshape(tuple(reshape_shape))
        return result

    return argmax_fn
import mlx.core as mx
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@mlx_funcify.register(Shape)
def mlx_funcify_Shape(op, **kwargs):
    """Return a function computing `x.shape` as an int64 `mx.array`."""

    def shape(x):
        return mx.array(x.shape, dtype=mx.int64)

    return shape
@mlx_funcify.register(SpecifyShape)
def mlx_funcify_SpecifyShape(op, node, **kwargs):
    """Return a function validating `x` against a (partially) known shape."""

    def specifyshape(x, *shape):
        assert x.ndim == len(shape)
        for actual, expected in zip(x.shape, shape, strict=True):
            # `None` means this dimension is unconstrained.
            if expected is not None and actual != expected:
                raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")
        return x

    return specifyshape
@mlx_funcify.register(Shape_i)
def mlx_funcify_Shape_i(op, node, **kwargs):
    """Return a function extracting dimension `op.i` of the input's shape."""
    dim_index = op.i

    def shape_i(x):
        return x.shape[dim_index]

    return shape_i
@mlx_funcify.register(Reshape)
def mlx_funcify_Reshape(op, **kwargs):
    """Return a function reshaping `x` to the runtime-provided shape `shp`."""

    def reshape(x, shp):
        return mx.reshape(x, shp)

    return reshape
import mlx.core as mx
from pytensor.link.mlx.dispatch import mlx_funcify, mlx_typify
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.signal.conv import Convolve1d
@mlx_funcify.register(Convolve1d)
def mlx_funcify_Convolve1d(op, node, **kwargs):
    """Build an MLX 1-d convolution whose full/valid mode may be static or runtime."""
    # The third graph input selects "full" vs. "valid" mode.  If it is a
    # compile-time constant, bake the choice in; otherwise resolve per call.
    _, _, full_mode_var = node.inputs
    try:
        full_mode = bool(get_underlying_scalar_constant_value(full_mode_var))
        runtime_mode_static = True
    except NotScalarConstantError:
        # Mode is data-dependent; this default is never read when
        # `runtime_mode_static` is False.
        full_mode = True
        runtime_mode_static = False

    def conv1d(raw_data, raw_kernel, runtime_full_mode):
        data = mlx_typify(raw_data, dtype=None)
        kernel = mlx_typify(raw_kernel, dtype=None)
        if runtime_mode_static:
            runtime_mode = full_mode
        else:
            # Evaluate the runtime flag; reshape(-1)[0] extracts the scalar
            # regardless of the flag's incoming rank.
            runtime_full_mode = mx.array(runtime_full_mode)
            runtime_mode = bool(runtime_full_mode.reshape(-1)[0])
        mode = "full" if runtime_mode else "valid"
        return mx.convolve(data, kernel, mode=mode)

    return conv1d
from copy import deepcopy
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice
@mlx_funcify.register(Subtensor)
def mlx_funcify_Subtensor(op, node, **kwargs):
    """Return a basic-indexing function built from the op's static index template."""
    idx_list = getattr(op, "idx_list", None)

    def subtensor(x, *ilists):
        # Scalar index inputs may arrive as arrays; basic indexing needs
        # plain Python ints when filling the slice template.
        scalar_indices = [int(element) for element in ilists]
        indices = indices_from_subtensor(scalar_indices, idx_list)
        if len(indices) == 1:
            indices = indices[0]
        return x[indices]

    return subtensor
@mlx_funcify.register(AdvancedSubtensor)
@mlx_funcify.register(AdvancedSubtensor1)
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
    """Return an advanced-indexing function (array-valued indices)."""
    idx_list = getattr(op, "idx_list", None)

    def advanced_subtensor(x, *ilists):
        indices = indices_from_subtensor(ilists, idx_list)
        if len(indices) == 1:
            indices = indices[0]
        return x[indices]

    return advanced_subtensor
@mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs):
    """Return a function that sets or increments `x[indices]` with `y`."""
    idx_list = getattr(op, "idx_list", None)
    set_instead_of_inc = getattr(op, "set_instead_of_inc", False)

    def mlx_fn(x, indices, y):
        # Preserve the input unless the op is explicitly in-place.
        if not op.inplace:
            x = deepcopy(x)
        if set_instead_of_inc:
            x[indices] = y
        else:
            x[indices] += y
        return x

    def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
        indices = indices_from_subtensor(ilist, idx_list)
        if len(indices) == 1:
            indices = indices[0]
        return mlx_fn(x, indices, y)

    return incsubtensor
@mlx_funcify.register(AdvancedIncSubtensor)
def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
    """Return a function that sets or increments `x` at advanced indices."""
    set_instead_of_inc = getattr(op, "set_instead_of_inc", False)

    def mlx_fn(x, indices, y):
        # Preserve the input unless the op is explicitly in-place.
        if not op.inplace:
            x = deepcopy(x)
        if set_instead_of_inc:
            x[indices] = y
        else:
            x[indices] += y
        return x

    def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn):
        # Advanced index inputs are forwarded directly as a tuple.
        return mlx_fn(x, ilist, y)

    return advancedincsubtensor
@mlx_funcify.register(MakeSlice)
def mlx_funcify_MakeSlice(op, **kwargs):
    """Return a function assembling a Python `slice` from its components."""

    def makeslice(*x):
        return slice(*x)

    return makeslice
from pytensor.link.basic import JITLinker
class MLXLinker(JITLinker):
    """A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX."""

    def __init__(self, use_compile=True, *args, **kwargs):
        """
        Parameters
        ----------
        use_compile : bool
            If True (default), wrap the converted function in ``mx.compile``;
            otherwise only input typification is applied.
        """
        super().__init__(*args, **kwargs)
        self.gen_functors = []
        self.use_compile = use_compile

    def fgraph_convert(self, fgraph, **kwargs):
        """Convert a PyTensor FunctionGraph to an MLX-compatible function.

        Parameters
        ----------
        fgraph : FunctionGraph
            The function graph to convert

        Returns
        -------
        callable
            An MLX-compatible function
        """
        from pytensor.link.mlx.dispatch import mlx_funcify

        return mlx_funcify(
            fgraph,
            **kwargs,
        )

    def jit_compile(self, fn):
        """Wrap `fn` so its inputs are typified for MLX and, unless
        ``use_compile`` is False, compile it with ``mx.compile``.
        """
        import mlx.core as mx

        from pytensor.link.mlx.dispatch import mlx_typify

        if not self.use_compile:
            # Skip compilation and just return the function with MLX typification
            def fn_no_compile(*inputs):
                return fn(*(mlx_typify(inp) for inp in inputs))

            return fn_no_compile

        # Bug fix: the original rebound the name `fn` for the wrapper,
        # shadowing the `fn` parameter; use distinct names instead.
        compiled_fn = mx.compile(fn)

        def wrapped_fn(*inputs):
            return compiled_fn(*(mlx_typify(inp) for inp in inputs))

        return wrapped_fn

    def create_thunk_inputs(self, storage_map):
        """Create inputs for the MLX thunk.

        Parameters
        ----------
        storage_map : dict
            Map from variables to their storage

        Returns
        -------
        list
            The inputs for the thunk
        """
        # One storage cell per graph input, in graph-input order.
        return [storage_map[n] for n in self.fgraph.inputs]
...@@ -31,13 +31,15 @@ class PytorchLinker(JITLinker): ...@@ -31,13 +31,15 @@ class PytorchLinker(JITLinker):
**kwargs, **kwargs,
} }
return pytorch_funcify( return pytorch_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs fgraph,
input_storage=input_storage,
storage_map=storage_map,
**built_kwargs,
) )
def jit_compile(self, fn): def jit_compile(self, fn):
import torch import torch
# flag that tend to help our graphs
torch._dynamo.config.capture_dynamic_output_shape_ops = True torch._dynamo.config.capture_dynamic_output_shape_ops = True
from pytensor.link.pytorch.dispatch import pytorch_typify from pytensor.link.pytorch.dispatch import pytorch_typify
......
Diff collapsed.
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot
from tests.link.mlx.test_basic import compare_mlx_and_py
# Equivalent blockwise to matmul but with dumb signature
odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")
def test_blockwise_conv1d():
    """`convolve1d` in "valid" mode should match the Python backend."""
    rng = np.random.default_rng(14)

    a = tensor("a", shape=(2, 100))
    b = tensor("b", shape=(2, 8))
    out = pt.signal.convolve1d(a, b, mode="valid")

    test_values = [rng.normal(size=(2, 100)), rng.normal(size=(2, 8))]
    compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True)
import numpy as np
import pytest
import pytensor
from pytensor import tensor as pt
from pytensor.tensor.basic import Alloc
from tests.link.mlx.test_basic import compile_mode, mlx_mode_no_compile, mx
def test_alloc_with_different_shape_types():
    """Test Alloc works with different types of shape parameters.

    This addresses the TypeError that occurred when shape parameters
    contained MLX arrays instead of Python integers.
    """
    from pytensor.link.mlx.dispatch.core import (
        mlx_funcify_Alloc,
    )

    # Create a mock node (we don't need a real node for this test)
    # The funcify function only duck-types the node, so a bare stand-in
    # with op/inputs/outputs attributes is sufficient.
    class MockNode:
        def __init__(self):
            self.op = Alloc()
            self.inputs = None
            self.outputs = None

    alloc_func = mlx_funcify_Alloc(Alloc(), MockNode())
    x = mx.array(5.0)

    # Test with Python ints
    result = alloc_func(x, 3, 4)
    assert result.shape == (3, 4)
    assert float(result[0, 0]) == 5.0

    # Test with MLX arrays (this used to fail)
    result = alloc_func(x, mx.array(3), mx.array(4))
    assert result.shape == (3, 4)
    assert float(result[0, 0]) == 5.0

    # Test with mixed types
    result = alloc_func(x, 3, mx.array(4))
    assert result.shape == (3, 4)
    assert float(result[0, 0]) == 5.0
def test_alloc_pytensor_integration():
    """Test Alloc in a PyTensor graph context."""
    x = pt.scalar("x", dtype="float32")
    alloc_out = pt.alloc(x, 3, 4)
    fn = pytensor.function([x], alloc_out, mode="MLX")

    result = fn(5.0)
    assert result.shape == (3, 4)
    assert float(result[0, 0]) == 5.0
def test_alloc_compilation_limitation():
    """Alloc with dynamic shapes should raise a clear error when compiled."""
    x = pt.scalar("x", dtype="float32")
    s1 = pt.scalar("s1", dtype="int64")
    s2 = pt.scalar("s2", dtype="int64")
    dynamic_alloc = pt.alloc(x, s1, s2)

    # The non-compiled MLX mode handles dynamic shapes fine.
    uncompiled = pytensor.function([x, s1, s2], dynamic_alloc, mode=mlx_mode_no_compile)
    result = uncompiled(5.0, 3, 4)
    assert result.shape == (3, 4)
    np.testing.assert_allclose(result, 5.0)

    # The compiled path fails with a helpful, targeted message.
    compiled = pytensor.function([x, s1, s2], dynamic_alloc, mode=compile_mode)
    with pytest.raises(
        ValueError,
        match="MLX compilation limitation: Alloc operations with dynamic shapes cannot be "
        "used inside compiled functions",
    ):
        compiled(5.0, 3, 4)
def test_alloc_static_shapes_compilation():
    """Alloc with static shapes works in compiled and non-compiled modes."""
    x = pt.scalar("x", dtype="float32")
    static_alloc = pt.alloc(x, 3, 4)  # shape fixed at graph-build time

    fn_plain = pytensor.function([x], static_alloc, mode=mlx_mode_no_compile)
    fn_jit = pytensor.function([x], static_alloc, mode=compile_mode)

    out_plain = fn_plain(5.0)
    out_jit = fn_jit(5.0)

    assert out_plain.shape == (3, 4)
    assert out_jit.shape == (3, 4)
    np.testing.assert_allclose(out_plain, 5.0)
    np.testing.assert_allclose(out_jit, 5.0)
    np.testing.assert_allclose(out_plain, out_jit)
def test_empty_static_shape():
    """`pt.empty` with a constant shape runs under the MLX mode."""
    empty_out = pt.empty((3, 4), dtype="float32")
    fn = pytensor.function([], empty_out, mode="MLX")
    result = fn()
    assert result.shape == (3, 4)
    np.testing.assert_allclose(result, 0.0)
def test_empty_dynamic_shape():
    """`pt.empty` with graph-input shapes works uncompiled, fails compiled."""
    s1 = pt.scalar("s1", dtype="int64")
    s2 = pt.scalar("s2", dtype="int64")
    empty_out = pt.empty((s1, s2), dtype="float32")

    fn = pytensor.function([s1, s2], empty_out, mode=mlx_mode_no_compile)
    result = fn(3, 4)
    assert result.shape == (3, 4)
    np.testing.assert_allclose(result, 0.0)

    fn_compiled = pytensor.function([s1, s2], empty_out, mode=compile_mode)
    with pytest.raises(
        ValueError,
        match="MLX compilation limitation: Alloc operations with dynamic shapes cannot be "
        "used inside compiled functions",
    ):
        fn_compiled(3, 4)
import numpy as np
import pytest
import scipy
from pytensor import config, function
from pytensor.tensor.basic import switch
from pytensor.tensor.math import (
add,
cos,
eq,
exp,
ge,
gt,
int_div,
isinf,
le,
log,
lt,
mul,
neq,
power,
prod,
sigmoid,
sin,
sub,
true_div,
)
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import any as pt_any
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import min as pt_min
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.special import SoftmaxGrad, softmax
from pytensor.tensor.type import matrix, vector, vectors
from tests.link.mlx.test_basic import compare_mlx_and_py
mx = pytest.importorskip("mlx.core")
@pytest.mark.parametrize("op", [pt_any, pt_all, pt_max, pt_min])
def test_input(op) -> None:
    """Boolean reductions over a comparison result match the Python backend."""
    x = vector("x")
    reduced = op(x > 0)
    compare_mlx_and_py([x], reduced, [mx.array([1.0, 2.0, 3.0])])
def test_mlx_CAReduce():
    """CAReduce ops (sum/prod/all) over vectors and matrices match Python."""
    vec_val = np.r_[1, 2, 3].astype(config.floatX)
    mat_val = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)

    a_vec = vector("a")
    a_vec.tag.test_value = vec_val
    compare_mlx_and_py([a_vec], [pt_sum(a_vec, axis=None)], [vec_val])

    a_mat = matrix("a")
    a_mat.tag.test_value = mat_val
    compare_mlx_and_py([a_mat], [pt_sum(a_mat, axis=0)], [mat_val])
    compare_mlx_and_py([a_mat], [pt_sum(a_mat, axis=1)], [mat_val])

    b_mat = matrix("a")
    b_mat.tag.test_value = mat_val
    compare_mlx_and_py([b_mat], [prod(b_mat, axis=0)], [mat_val])
    compare_mlx_and_py([b_mat], [pt_all(b_mat)], [mat_val])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis):
    """`softmax` along each axis matches the Python backend."""
    x = matrix("x")
    x_val = np.arange(6, dtype=config.floatX).reshape(2, 3)
    compare_mlx_and_py([x], [softmax(x, axis=axis)], [x_val])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax_grad(axis):
    """`SoftmaxGrad` along each axis matches the Python backend."""
    dy = matrix("dy")
    sm = matrix("sm")
    dy_val = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
    sm_val = np.arange(6, dtype=config.floatX).reshape(2, 3)
    compare_mlx_and_py([dy, sm], [SoftmaxGrad(axis=axis)(dy, sm)], [dy_val, sm_val])
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
@pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(size, axis, benchmark):
    """Benchmark a numerically-stable logsumexp graph and verify its result."""
    X = matrix("X")
    X_max = pt_max(X, axis=axis, keepdims=True)
    X_max = switch(isinf(X_max), 0, X_max)
    X_lse = log(pt_sum(exp(X - X_max), axis=axis, keepdims=True)) + X_max

    rng = np.random.default_rng(23920)
    X_val = rng.normal(size=size)

    logsumexp_fn = function([X], X_lse, mode="MLX")
    _ = logsumexp_fn(X_val)  # warm-up call so JIT compilation is excluded

    res = benchmark(logsumexp_fn, X_val)
    expected = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
    np.testing.assert_array_almost_equal(res, expected)
def test_multiple_input_multiply():
    """`mul` accepts more than two operands."""
    x, y, z = vectors("xyz")
    product = mul(x, y, z)
    compare_mlx_and_py([x, y, z], [product], test_inputs=[[1.5], [2.5], [3.5]])
@pytest.mark.parametrize(
    "op",
    [
        pytest.param(exp, id="exp"),
        pytest.param(log, id="log"),
        pytest.param(sin, id="sin"),
        pytest.param(cos, id="cos"),
        pytest.param(sigmoid, id="sigmoid"),
    ],
)
def test_elemwise_one_input(op) -> None:
    """Unary elemwise ops match the Python backend."""
    x = vector("x")
    compare_mlx_and_py([x], op(x), [mx.array([1.0, 2.0, 3.0])])
@pytest.mark.parametrize(
    "op",
    [
        pytest.param(add, id="add"),
        pytest.param(sub, id="sub"),
        pytest.param(mul, id="mul"),
        pytest.param(power, id="power"),
        pytest.param(le, id="le"),
        pytest.param(lt, id="lt"),
        pytest.param(ge, id="ge"),
        pytest.param(gt, id="gt"),
        pytest.param(eq, id="eq"),
        pytest.param(neq, id="neq"),
        pytest.param(true_div, id="true_div"),
        pytest.param(int_div, id="int_div"),
    ],
)
def test_elemwise_two_inputs(op) -> None:
    """Binary elemwise ops match the Python backend."""
    x = vector("x")
    y = vector("y")
    lhs_val = mx.array([1.0, 2.0, 3.0])
    rhs_val = mx.array([4.0, 5.0, 6.0])
    compare_mlx_and_py([x, y], op(x, y), [lhs_val, rhs_val])
import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.math import Argmax, Max
from tests.link.mlx.test_basic import compare_mlx_and_py
mx = pytest.importorskip("mlx.core")
def test_dot():
    """`x.dot(y)` under MLX mode returns an mx.array matching np.dot."""
    x = pt.matrix("x")
    y = pt.matrix("y")
    fn = pytensor.function([x, y], x.dot(y), mode="MLX")

    rng = np.random.default_rng(sum(map(ord, "test_mlx_dot")))
    x_val = rng.normal(size=(3, 2))
    y_val = rng.normal(size=(2, 4))

    result = fn(x_val, y_val)
    assert isinstance(result, mx.array)
    np.testing.assert_allclose(result, np.dot(x_val, y_val), rtol=1e-6)
def test_switch() -> None:
    """`switch` selects elementwise between two vectors."""
    x = pt.vector("x")
    y = pt.vector("y")
    selected = pt.switch(x > 0, y, x)
    compare_mlx_and_py(
        [x, y], selected, [mx.array([-1.0, 2.0, 3.0]), mx.array([4.0, 5.0, 6.0])]
    )
def test_int_div_specific() -> None:
    """`int_div` follows floor-division semantics, including negatives."""
    x = pt.vector("x")
    y = pt.vector("y")
    quotient = pt.int_div(x, y)
    numerators = mx.array([7.0, 8.0, 9.0, -7.0, -8.0])
    denominators = mx.array([3.0, 3.0, 3.0, 3.0, 3.0])
    compare_mlx_and_py([x, y], quotient, [numerators, denominators])
def test_isnan() -> None:
    """`isnan` distinguishes NaN from finite and infinite values."""
    x = pt.vector("x")
    values = mx.array([1.0, np.nan, 3.0, np.inf, -np.nan, 0.0, -np.inf])
    compare_mlx_and_py([x], pt.isnan(x), [values])
def test_isnan_edge_cases() -> None:
    """`isnan` on scalars: zero, NaN, infinities, and extreme magnitudes."""
    x = pt.scalar("x")
    checked = pt.isnan(x)
    for value in (0.0, np.nan, np.inf, -np.inf, 1e-10, 1e10):
        compare_mlx_and_py([x], checked, [value])
def test_erfc() -> None:
    """Complementary error function over mixed-sign inputs and zero."""
    x = pt.vector("x")
    values = mx.array([0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0, 0.1])
    compare_mlx_and_py([x], pt.erfc(x), [values])
def test_erfc_extreme_values() -> None:
    """erfc near its saturation points (toward 0 and 2), relaxed tolerance."""
    from functools import partial

    x = pt.vector("x")
    values = mx.array([-3.0, -2.5, 2.5, 3.0])
    # Numerical precision differs between backends at the extremes.
    loose_allclose = partial(np.testing.assert_allclose, rtol=1e-3, atol=1e-6)
    compare_mlx_and_py([x], pt.erfc(x), [values], assert_fn=loose_allclose)
def test_erfcx() -> None:
    """Scaled complementary error function on its stable (positive) range."""
    x = pt.vector("x")
    values = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5])
    compare_mlx_and_py([x], pt.erfcx(x), [values])
def test_erfcx_small_values() -> None:
    """erfcx for small positive arguments."""
    x = pt.vector("x")
    values = mx.array([0.001, 0.01, 0.1, 0.2])
    compare_mlx_and_py([x], pt.erfcx(x), [values])
def test_softplus() -> None:
    """softplus (log(1 + exp(x))) on moderate inputs."""
    x = pt.vector("x")
    values = mx.array([0.0, 1.0, 2.0, -1.0, -2.0, 10.0])
    compare_mlx_and_py([x], pt.softplus(x), [values])
def test_softplus_extreme_values() -> None:
    """softplus at extremes, where its piecewise implementation switches."""
    from functools import partial

    x = pt.vector("x")
    values = mx.array([-40.0, -50.0, 20.0, 30.0, 35.0, 50.0])
    # Different implementation branches have different precision.
    loose_allclose = partial(np.testing.assert_allclose, rtol=1e-4, atol=1e-8)
    compare_mlx_and_py([x], pt.softplus(x), [values], assert_fn=loose_allclose)
def test_mlx_max_and_argmax():
    """A single output of the multi-output Max/Argmax pair can feed another Op.

    Bug fix: the original bound the Max output to a local named ``mx``,
    shadowing the module-level ``mlx.core`` import; renamed for clarity.
    """
    x = pt.dvector()
    max_out = Max([0])(x)
    argmax_out = Argmax([0])(x)
    out = max_out * argmax_out
    compare_mlx_and_py([x], [out], [np.r_[1, 2]])
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config
from pytensor.tensor.shape import Shape, Shape_i, reshape
from pytensor.tensor.type import iscalar, vector
from tests.link.mlx.test_basic import compare_mlx_and_py
def test_mlx_shape_ops():
    """`Shape` and `Shape_i` on a constant input match the Python backend."""
    x_np = np.zeros((20, 3))

    shape_out = Shape()(pt.as_tensor_variable(x_np))
    compare_mlx_and_py([], [shape_out], [], must_be_device_array=False)

    dim_out = Shape_i(1)(pt.as_tensor_variable(x_np))
    compare_mlx_and_py([], [dim_out], [], must_be_device_array=False)
def test_mlx_specify_shape():
    """`specify_shape` with a partial shape and with a graph-derived shape."""
    in_pt = pt.matrix("in")
    constrained = pt.specify_shape(in_pt, (4, None))
    compare_mlx_and_py([in_pt], [constrained], [np.ones((4, 5)).astype(config.floatX)])

    # Assert that two arrays have matching shapes by tying one to the
    # other's symbolic shape.
    in_pt = pt.matrix("in")
    shape_pt = pt.matrix("shape")
    tied = pt.specify_shape(in_pt, shape_pt.shape)
    compare_mlx_and_py(
        [in_pt, shape_pt],
        [tied],
        [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
    )
def test_mlx_Reshape_constant():
    """Reshape to a constant target shape."""
    a = vector("a")
    reshaped = reshape(a, (2, 2))
    compare_mlx_and_py([a], [reshaped], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
def test_mlx_Reshape_various_shapes():
    """Test reshape with various different shapes to ensure robustness."""
    cases = [
        # (input variable, target shape, test value)
        (vector("a"), (2, 3), np.arange(6, dtype=config.floatX)),  # 1D -> 2D
        (pt.matrix("b"), (6,), np.arange(6, dtype=config.floatX).reshape(2, 3)),  # 2D -> 1D
        (pt.matrix("c"), (2, 2, 3), np.arange(12, dtype=config.floatX).reshape(4, 3)),  # 2D -> 3D
        (pt.tensor3("d"), (3, 4), np.arange(12, dtype=config.floatX).reshape(2, 2, 3)),  # 3D -> 2D
    ]
    for var, target_shape, value in cases:
        compare_mlx_and_py([var], [reshape(var, target_shape)], [value])
def test_mlx_Reshape_negative_one():
    """Test reshape with a -1 (inferred) dimension in either position."""
    a = vector("a")
    data = np.arange(8, dtype=config.floatX)
    compare_mlx_and_py([a], [reshape(a, (2, -1))], [data])  # infer second dim
    compare_mlx_and_py([a], [reshape(a, (-1, 4))], [data])  # infer first dim
def test_mlx_Reshape_concrete_shape():
    """MLX should compile when a concrete value is passed for the `shape` parameter."""
    a = vector("a")
    data = np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)
    compare_mlx_and_py([a], [reshape(a, a.shape)], [data])
    compare_mlx_and_py([a], [reshape(a, (a.shape[0] // 2, a.shape[0] // 2))], [data])
@pytest.mark.xfail(reason="`shape_pt` should be specified as a static argument")
def test_mlx_Reshape_shape_graph_input():
    """Reshape whose target shape is itself a graph input (expected failure)."""
    a = vector("a")
    shape_pt = iscalar("b")
    reshaped = reshape(a, (shape_pt, shape_pt))
    compare_mlx_and_py(
        [a, shape_pt], [reshaped], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]
    )
@pytest.mark.xfail(reason="ViewOp Op is not supported yet")
def test_mlx_compile_ops():
    """DeepCopyOp and ViewOp under MLX mode (ViewOp still unsupported)."""
    deep_copied = DeepCopyOp()(pt.as_tensor_variable(1.1))
    compare_mlx_and_py([], [deep_copied], [])

    x_np = np.zeros((20, 1, 1))
    viewed = ViewOp()(pt.as_tensor_variable(x_np))
    compare_mlx_and_py([], [viewed], [])
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.tensor import subtensor as pt_subtensor
from pytensor.tensor import tensor
from tests.link.mlx.test_basic import compare_mlx_and_py
mx = pytest.importorskip("mlx.core")
def test_mlx_Subtensor_basic():
    """Basic (non-advanced) indexing with constant scalars and slices."""
    shape = (3, 4, 5)
    x_pt = tensor("x", shape=shape, dtype="float32")
    x_np = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    def check(out):
        # Every expression here must lower to a plain Subtensor op.
        assert isinstance(out.owner.op, pt_subtensor.Subtensor)
        compare_mlx_and_py([x_pt], [out], [x_np])

    check(x_pt[1, 2, 0])  # single-element lookup
    check(x_pt[1:, 1, :])  # slice with start only
    check(x_pt[:2, 1, :])  # slice with stop only
    check(x_pt[1:2, 1, :])  # slice with start and stop
    check(x_pt[-1, -1, -1])  # negative indices
    check(x_pt[::2, ::2, ::2])  # positive step
    check(x_pt[::-1, ::-1, ::-1])  # negative step (full reversal)
def test_mlx_AdvancedSubtensor():
    """Advanced (array-based) indexing."""
    shape = (3, 4, 5)
    x_pt = tensor("x", shape=shape, dtype="float32")
    x_np = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    # A vector index on the leading axis lowers to AdvancedSubtensor1.
    out = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
    assert isinstance(out.owner.op, pt_subtensor.AdvancedSubtensor1)
    compare_mlx_and_py([x_pt], [out], [x_np])

    # The remaining forms lower to the general AdvancedSubtensor op.
    for out in (
        x_pt[[1, 2], [2, 3]],  # multi-dimensional integer-array indexing
        x_pt[[1, 2], :],  # advanced index mixed with a basic slice
        x_pt[[1, 2], :, [3, 4]],  # advanced indices separated by a slice
    ):
        assert isinstance(out.owner.op, pt_subtensor.AdvancedSubtensor)
        compare_mlx_and_py([x_pt], [out], [x_np])
@pytest.mark.xfail(
    raises=ValueError, reason="MLX does not support boolean indexing yet"
)
def test_mlx_AdvancedSubtensor_boolean():
    """Advanced indexing with a constant boolean mask (expected to fail on MLX)."""
    shape = (3, 4, 5)
    x_pt = tensor("x", shape=shape, dtype="float32")
    x_np = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    mask = np.array([True, False, True])
    out = x_pt[mask]
    assert isinstance(out.owner.op, pt_subtensor.AdvancedSubtensor)
    compare_mlx_and_py([x_pt], [out], [x_np])
def test_mlx_IncSubtensor_set():
    """set_subtensor on one element lowers to IncSubtensor with set semantics."""
    base = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
    x_pt = pt.constant(base)

    new_value = pt.as_tensor_variable(np.array(-10.0, dtype=np.float32))
    out = pt_subtensor.set_subtensor(x_pt[1, 2, 3], new_value)
    assert isinstance(out.owner.op, pt_subtensor.IncSubtensor)
    # set_subtensor must produce "set", not "increment", semantics.
    assert out.owner.op.set_instead_of_inc
    compare_mlx_and_py([], [out], [])
def test_mlx_IncSubtensor_increment():
    """inc_subtensor on one element lowers to IncSubtensor with increment semantics."""
    base = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
    x_pt = pt.constant(base)

    increment = pt.as_tensor_variable(np.array(-10.0, dtype=np.float32))
    out = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], increment)
    assert isinstance(out.owner.op, pt_subtensor.IncSubtensor)
    # inc_subtensor must produce "increment", not "set", semantics.
    assert not out.owner.op.set_instead_of_inc
    compare_mlx_and_py([], [out], [])
def test_mlx_AdvancedIncSubtensor_set():
    """Advanced set_subtensor lowers to AdvancedIncSubtensor with set semantics."""
    rng = np.random.default_rng(213234)
    x_pt = pt.constant(np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)))

    # Setting rows selected by an integer array — this actually works in MLX!
    new_rows = pt.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32)
    )
    out = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], new_rows)
    assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor)
    assert out.owner.op.set_instead_of_inc
    compare_mlx_and_py([], [out], [])
def test_mlx_AdvancedIncSubtensor_increment():
    """Advanced inc_subtensor lowers to AdvancedIncSubtensor with increment semantics."""
    rng = np.random.default_rng(213234)
    x_pt = pt.constant(np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)))

    # Incrementing rows selected by an integer array — this actually works in MLX!
    increment = pt.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32)
    )
    out = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], increment)
    assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor)
    assert not out.owner.op.set_instead_of_inc
    compare_mlx_and_py([], [out], [])
def test_mlx_AdvancedIncSubtensor1_operations():
    """advanced_set_subtensor1 lowers to AdvancedIncSubtensor1 (IncSubtensor dispatcher)."""
    rng = np.random.default_rng(213234)
    x_pt = pt.constant(np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)))

    # Build the AdvancedIncSubtensor1 set operation explicitly — works in MLX!
    new_rows = pt.as_tensor_variable(
        rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32)
    )
    out = pt_subtensor.advanced_set_subtensor1(x_pt, new_rows, [1, 2])
    assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor1)
    assert out.owner.op.set_instead_of_inc
    compare_mlx_and_py([], [out], [])
@pytest.mark.xfail(reason="Inplace operations not yet supported in MLX mode")
def test_mlx_inplace_variants():
    """Inplace IncSubtensor set variant (expected to fail on MLX)."""
    x_pt = pt.constant(np.arange(12, dtype=np.float32).reshape((3, 4)))

    new_values = pt.as_tensor_variable(np.array([-1.0, -2.0], dtype=np.float32))
    out = pt_subtensor.set_subtensor(x_pt[0, :2], new_values, inplace=True)
    op = out.owner.op
    assert isinstance(op, pt_subtensor.IncSubtensor)
    assert op.inplace
    assert op.set_instead_of_inc
    compare_mlx_and_py([], [out], [])
@pytest.mark.xfail(
    reason="MLX slice indices must be integers or None, dynamic slices not supported"
)
def test_mlx_MakeSlice():
    """Slice built from symbolic scalars via MakeSlice (expected to fail on MLX)."""
    start = pt.iscalar("start")
    stop = pt.iscalar("stop")
    step = pt.iscalar("step")

    # Build a dynamic slice object from the three symbolic scalars.
    dynamic_slice = pt_subtensor.MakeSlice()(start, stop, step)

    # Index a simple constant array with it.
    x_pt = pt.constant(np.arange(10, dtype=np.float32))
    out = x_pt[dynamic_slice]
    compare_mlx_and_py([start, stop, step], [out], [1, 8, 2])
def test_mlx_subtensor_edge_cases():
    """Boundary conditions: empty slices, length-1 arrays, large and negative steps."""
    # Degenerate slice producing an empty result.
    out = pt.constant(np.arange(10, dtype=np.float32))[5:5]
    compare_mlx_and_py([], [out], [])

    # Indexing the only element of a length-1 vector.
    x_pt = pt.tensor(shape=(1,), dtype="float32", name="x")
    compare_mlx_and_py([x_pt], [x_pt[0]], [np.array([42.0], dtype=np.float32)])

    # Step larger than a quarter of the array length.
    out = pt.constant(np.arange(20, dtype=np.float32))[::5]
    compare_mlx_and_py([], [out], [])

    # Negative step walks the array backwards.
    out = pt.constant(np.arange(10, dtype=np.float32))[::-2]
    compare_mlx_and_py([], [out], [])
@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported")
def test_mlx_subtensor_with_variables():
    """set_subtensor where both the target and the update are graph inputs."""
    x_pt = pt.matrix("x", dtype="float32")
    y_pt = pt.vector("y", dtype="float32")
    x_np = np.arange(12, dtype=np.float32).reshape((3, 4))
    y_np = np.array([-1.0, -2.0], dtype=np.float32)

    out = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt)
    compare_mlx_and_py([x_pt, y_pt], [out], [x_np, y_np])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论