Unverified 提交 934306f2 authored 作者: Will Dean's avatar Will Dean 提交者: GitHub
上级 92c3b490
...@@ -81,6 +81,7 @@ jobs: ...@@ -81,6 +81,7 @@ jobs:
install-numba: [0] install-numba: [0]
install-jax: [0] install-jax: [0]
install-torch: [0] install-torch: [0]
install-mlx: [0]
install-xarray: [0] install-xarray: [0]
part: part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor" - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/xtensor"
...@@ -106,6 +107,7 @@ jobs: ...@@ -106,6 +107,7 @@ jobs:
install-numba: 0 install-numba: 0
install-jax: 0 install-jax: 0
install-torch: 0 install-torch: 0
install-mlx: 0
install-xarray: 0 install-xarray: 0
- install-numba: 1 - install-numba: 1
os: "ubuntu-latest" os: "ubuntu-latest"
...@@ -149,7 +151,16 @@ jobs: ...@@ -149,7 +151,16 @@ jobs:
fast-compile: 0 fast-compile: 0
float32: 0 float32: 0
part: "tests/xtensor" part: "tests/xtensor"
- os: macos-15 - os: "macos-15"
python-version: "3.11"
fast-compile: 0
float32: 0
install-mlx: 1
install-numba: 0
install-jax: 0
install-torch: 0
part: "tests/link/mlx"
- os: "macos-15"
python-version: "3.13" python-version: "3.13"
fast-compile: 0 fast-compile: 0
float32: 0 float32: 0
...@@ -194,6 +205,7 @@ jobs: ...@@ -194,6 +205,7 @@ jobs:
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mlx; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
pip install -e ./ pip install -e ./
...@@ -210,6 +222,7 @@ jobs: ...@@ -210,6 +222,7 @@ jobs:
INSTALL_JAX: ${{ matrix.install-jax }} INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}} INSTALL_TORCH: ${{ matrix.install-torch}}
INSTALL_XARRAY: ${{ matrix.install-xarray }} INSTALL_XARRAY: ${{ matrix.install-xarray }}
INSTALL_MLX: ${{ matrix.install-mlx }}
OS: ${{ matrix.os}} OS: ${{ matrix.os}}
- name: Run tests - name: Run tests
......
...@@ -27,7 +27,6 @@ __pycache__ ...@@ -27,7 +27,6 @@ __pycache__
\#*\# \#*\#
build build
compiled/*.cpp compiled/*.cpp
core.*
cutils_ext.cpp cutils_ext.cpp
dist dist
doc/.build/ doc/.build/
......
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Obtaining file:///Users/carlostrujillo/Documents/GitHub/pytensor\n",
" Installing build dependencies ... \u001b[?25ldone\n",
"\u001b[?25h Checking if build backend supports build_editable ... \u001b[?25ldone\n",
"\u001b[?25h Getting requirements to build editable ... \u001b[?25ldone\n",
"\u001b[?25h Preparing editable metadata (pyproject.toml) ... \u001b[?25ldone\n",
"\u001b[?25hBuilding wheels for collected packages: pytensor\n",
" Building editable for pytensor (pyproject.toml) ... \u001b[?25ldone\n",
"\u001b[?25h Created wheel for pytensor: filename=pytensor-2.31.7+80.g06ccf91ba.dirty-0.editable-cp312-cp312-macosx_11_0_arm64.whl size=7323 sha256=c09587a5f3141d49000666d2817c5a01436f13ff5a19aa3deda20f647660afee\n",
" Stored in directory: /private/var/folders/f0/rbz8xs8s17n3k3f_ccp31bvh0000gn/T/pip-ephem-wheel-cache-i00nb67k/wheels/52/f6/4c/e6784e2203d5405c94db1d544248730e598e4397674416af05\n",
"Successfully built pytensor\n",
"Installing collected packages: pytensor\n",
" Attempting uninstall: pytensor\n",
" Found existing installation: pytensor 2.31.7+80.g06ccf91ba.dirty\n",
" Uninstalling pytensor-2.31.7+80.g06ccf91ba.dirty:\n",
" Successfully uninstalled pytensor-2.31.7+80.g06ccf91ba.dirty\n",
"Successfully installed pytensor-2.31.7+80.g06ccf91ba.dirty\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install -e ../.. --no-deps"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import pytensor\n",
"import pytensor.tensor as pt\n",
"from pytensor.compile.function import function\n",
"from pytensor.compile.mode import Mode\n",
"from pytensor.graph import RewriteDatabaseQuery\n",
"from pytensor.link.jax import JAXLinker\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Configure JAX to use float32 for consistency with MLX\n",
"jax.config.update(\"jax_enable_x64\", False)\n",
"\n",
"# Set up PyTensor JAX mode\n",
"jax_optimizer = RewriteDatabaseQuery(include=[\"jax\"], exclude=[])\n",
"pytensor_jax_mode = \"JAX\"\n",
"\n",
"# Try to set up MLX mode\n",
"try:\n",
" from pytensor.link.mlx import MLXLinker\n",
" import mlx.core as mx\n",
" mlx_optimizer = RewriteDatabaseQuery(include=[\"mlx\"], exclude=[])\n",
" pytensor_mlx_mode = \"MLX\"\n",
" MLX_AVAILABLE = True\n",
"except ImportError:\n",
" MLX_AVAILABLE = False\n",
"\n",
"def timer_jax(func, N=1000):\n",
" \"\"\"Time function execution with proper JAX synchronization, repeated N times\"\"\"\n",
" def wrapper(*args, **kwargs):\n",
" times = []\n",
" for _ in range(N):\n",
" start = time.perf_counter()\n",
" result = func(*args, **kwargs)\n",
" if hasattr(result, 'block_until_ready'):\n",
" result.block_until_ready()\n",
" elif isinstance(result, (list, tuple)):\n",
" for r in result:\n",
" if hasattr(r, 'block_until_ready'):\n",
" r.block_until_ready()\n",
" end = time.perf_counter()\n",
" times.append(end - start)\n",
" \n",
" mean_time = np.mean(times)\n",
" std_time = np.std(times)\n",
" return result, mean_time, std_time\n",
" return wrapper\n",
"\n",
"def timer_mlx(func, N=1000):\n",
" \"\"\"Time function execution with proper MLX synchronization, repeated N times\"\"\"\n",
" def wrapper(*args, **kwargs):\n",
" times = []\n",
" for _ in range(N):\n",
" start = time.perf_counter()\n",
" result = func(*args, **kwargs)\n",
" # For MLX, we need to use mx.eval() to force computation\n",
" if MLX_AVAILABLE:\n",
" if isinstance(result, (list, tuple)):\n",
" mx.eval(*result)\n",
" else:\n",
" mx.eval(result)\n",
" end = time.perf_counter()\n",
" times.append(end - start)\n",
" \n",
" mean_time = np.mean(times)\n",
" std_time = np.std(times)\n",
" return result, mean_time, std_time\n",
" return wrapper\n",
"\n",
"def run_benchmark(N=1000):\n",
" \"\"\"Run comprehensive benchmark comparing PyTensor JAX vs MLX backends\"\"\"\n",
" import pandas as pd\n",
" \n",
" sizes = [2, 4, 1080, 2080, 3080]\n",
" results = []\n",
" \n",
" print(f\"Running benchmarks with N={N} repetitions per test...\")\n",
" \n",
" for size in sizes:\n",
" print(f\"Testing {size}x{size} matrices...\")\n",
" \n",
" # Generate test matrices with fixed seed for reproducibility\n",
" np.random.seed(42)\n",
" A = np.random.randn(size, size).astype(np.float32)\n",
" B = np.random.randn(size, size).astype(np.float32)\n",
" C = np.random.randn(size, size).astype(np.float32)\n",
"\n",
" pt_A = pt.matrix('A', dtype='float32')\n",
" pt_B = pt.matrix('B', dtype='float32') \n",
" pt_C = pt.matrix('C', dtype='float32')\n",
" result = pt.dot(pt.dot(pt_A, pt_B), pt_C)\n",
"\n",
"\n",
" f_jax = function([pt_A, pt_B, pt_C], result, mode=pytensor_jax_mode, trust_input=True)\n",
" f_mlx = function([pt_A, pt_B, pt_C], result, mode=pytensor_mlx_mode, trust_input=True)\n",
" f_jax(A, B, C)\n",
" f_mlx(A, B, C)\n",
" \n",
" # === TEST 1: Matrix Multiplication Chain ===\n",
" # PyTensor + JAX backend\n",
" @timer_jax\n",
" def pytensor_jax_matmul():\n",
" return f_jax(A, B, C)\n",
" \n",
" # PyTensor + MLX backend\n",
" @timer_mlx\n",
" def pytensor_mlx_matmul():\n",
" if not MLX_AVAILABLE:\n",
" return None, float('inf'), 0\n",
" return f_mlx(A, B, C)\n",
" \n",
" # Run matrix multiplication test\n",
" _, jax_mean, jax_std = pytensor_jax_matmul()\n",
" try:\n",
" _, mlx_mean, mlx_std = pytensor_mlx_matmul()\n",
" except Exception as e:\n",
" print(f\"MLX matmul error: {e}\")\n",
" mlx_mean, mlx_std = float('inf'), 0\n",
" \n",
" # Calculate percentage improvement (positive = MLX is faster, negative = MLX is slower)\n",
" if mlx_mean != float('inf') and mlx_mean > 0:\n",
" speedup_percentage = ((jax_mean - mlx_mean) / jax_mean) * 100\n",
" speedup_str = f'{speedup_percentage:+.1f}%'\n",
" else:\n",
" speedup_str = 'N/A'\n",
" \n",
" results.append({\n",
" 'Size': f'{size}x{size}',\n",
" 'Operation': 'Matrix Chain (A @ B @ C)',\n",
" 'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',\n",
" 'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n",
" 'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n",
" 'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n",
" 'MLX Performance': speedup_str\n",
" })\n",
" \n",
" # === TEST 2: Element-wise Operations ===\n",
" # PyTensor + JAX\n",
" result = pt.sin(pt_A) + pt.cos(pt_B)\n",
" f_jax = function([pt_A, pt_B], result, mode=pytensor_jax_mode, trust_input=True)\n",
" f_mlx = function([pt_A, pt_B], result, mode=pytensor_mlx_mode, trust_input=True)\n",
" f_jax(A, B)\n",
" f_mlx(A, B)\n",
"\n",
" @timer_jax\n",
" def pytensor_jax_elemwise():\n",
" return f_jax(A, B)\n",
" \n",
" # PyTensor + MLX\n",
" @timer_mlx\n",
" def pytensor_mlx_elemwise():\n",
" if not MLX_AVAILABLE:\n",
" return None, float('inf'), 0\n",
" return f_mlx(A, B)\n",
" \n",
" # Run element-wise test\n",
" _, jax_mean, jax_std = pytensor_jax_elemwise()\n",
" try:\n",
" _, mlx_mean, mlx_std = pytensor_mlx_elemwise()\n",
" except Exception as e:\n",
" print(f\"MLX elemwise error: {e}\")\n",
" mlx_mean, mlx_std = float('inf'), 0\n",
" \n",
" # Calculate percentage improvement\n",
" if mlx_mean != float('inf') and mlx_mean > 0:\n",
" speedup_percentage = ((jax_mean - mlx_mean) / jax_mean) * 100\n",
" speedup_str = f'{speedup_percentage:+.1f}%'\n",
" else:\n",
" speedup_str = 'N/A'\n",
" \n",
" results.append({\n",
" 'Size': f'{size}x{size}',\n",
" 'Operation': 'Element-wise (sin(A) + cos(B))',\n",
" 'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',\n",
" 'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n",
" 'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n",
" 'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n",
" 'MLX Performance': speedup_str\n",
" })\n",
" \n",
" # === TEST 3: Matrix Addition with Broadcasting ===\n",
" # PyTensor + JAX\n",
" result = pt_A + pt_B.T\n",
" f_jax = function([pt_A, pt_B], result, mode=pytensor_jax_mode, trust_input=True)\n",
" f_mlx = function([pt_A, pt_B], result, mode=pytensor_mlx_mode, trust_input=True)\n",
" f_jax(A, B)\n",
" f_mlx(A, B)\n",
" @timer_jax\n",
" def pytensor_jax_broadcast():\n",
" return f_jax(A, B)\n",
" \n",
" # PyTensor + MLX\n",
" @timer_mlx\n",
" def pytensor_mlx_broadcast():\n",
" if not MLX_AVAILABLE:\n",
" return None, float('inf'), 0\n",
" return f_mlx(A, B)\n",
" \n",
" # Run broadcasting test\n",
" _, jax_mean, jax_std = pytensor_jax_broadcast()\n",
" try:\n",
" _, mlx_mean, mlx_std = pytensor_mlx_broadcast()\n",
" except Exception as e:\n",
" print(f\"MLX broadcast error: {e}\")\n",
" mlx_mean, mlx_std = float('inf'), 0\n",
" \n",
" # Calculate percentage improvement\n",
" if mlx_mean != float('inf') and mlx_mean > 0:\n",
" speedup_percentage = ((jax_mean - mlx_mean) / jax_mean) * 100\n",
" speedup_str = f'{speedup_percentage:+.1f}%'\n",
" else:\n",
" speedup_str = 'N/A'\n",
" \n",
" results.append({\n",
" 'Size': f'{size}x{size}',\n",
" 'Operation': 'Broadcasting (A + B.T)',\n",
" 'PyTensor+JAX Mean (s)': f'{jax_mean:.6f}',\n",
" 'PyTensor+JAX Std (s)': f'{jax_std:.6f}',\n",
" 'PyTensor+MLX Mean (s)': f'{mlx_mean:.6f}' if mlx_mean != float('inf') else 'Error',\n",
" 'PyTensor+MLX Std (s)': f'{mlx_std:.6f}' if mlx_mean != float('inf') else 'N/A',\n",
" 'MLX Performance': speedup_str\n",
" })\n",
" \n",
" # Create and display results table\n",
" df = pd.DataFrame(results)\n",
" return df\n",
"\n",
"def main(N=1000):\n",
" \"\"\"Main benchmark execution\"\"\"\n",
" # Display system info\n",
" system_info = {\n",
" 'JAX version': jax.__version__,\n",
" 'PyTensor version': pytensor.__version__,\n",
" 'MLX Available': 'Yes' if MLX_AVAILABLE else 'No',\n",
" 'Platform': 'Apple Silicon' if MLX_AVAILABLE else 'Generic',\n",
" 'Repetitions (N)': N\n",
" }\n",
" \n",
" if MLX_AVAILABLE:\n",
" system_info['MLX version'] = mx.__version__\n",
" \n",
" import pandas as pd\n",
" info_df = pd.DataFrame([system_info])\n",
" \n",
" # Then run benchmarks\n",
" results_df = run_benchmark(N=N)\n",
" \n",
" return info_df, results_df\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running benchmarks with N=150 repetitions per test...\n",
"Testing 2x2 matrices...\n",
"Testing 4x4 matrices...\n",
"Testing 1080x1080 matrices...\n",
"Testing 2080x2080 matrices...\n",
"Testing 3080x3080 matrices...\n"
]
}
],
"source": [
"iteration=150\n",
"_, results = main(N=iteration)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Benchmark Results over 150 repetitions:\n",
" Size Operation PyTensor+JAX Mean (s) PyTensor+JAX Std (s) PyTensor+MLX Mean (s) PyTensor+MLX Std (s) MLX Performance\n",
" 2x2 Matrix Chain (A @ B @ C) 0.000009 0.000002 0.000305 0.000299 -3213.5%\n",
" 2x2 Element-wise (sin(A) + cos(B)) 0.000007 0.000002 0.000352 0.003757 -5078.0%\n",
" 2x2 Broadcasting (A + B.T) 0.000007 0.000001 0.000188 0.000153 -2721.1%\n",
" 4x4 Matrix Chain (A @ B @ C) 0.000009 0.000001 0.000209 0.000063 -2126.2%\n",
" 4x4 Element-wise (sin(A) + cos(B)) 0.000007 0.000001 0.000180 0.000066 -2449.5%\n",
" 4x4 Broadcasting (A + B.T) 0.000007 0.000003 0.000181 0.000065 -2564.1%\n",
"1080x1080 Matrix Chain (A @ B @ C) 0.005951 0.000356 0.001355 0.000392 +77.2%\n",
"1080x1080 Element-wise (sin(A) + cos(B)) 0.002820 0.000107 0.000432 0.000207 +84.7%\n",
"1080x1080 Broadcasting (A + B.T) 0.000212 0.000035 0.000428 0.000206 -102.0%\n",
"2080x2080 Matrix Chain (A @ B @ C) 0.027609 0.001255 0.004550 0.002528 +83.5%\n",
"2080x2080 Element-wise (sin(A) + cos(B)) 0.010086 0.000417 0.001175 0.000350 +88.3%\n",
"2080x2080 Broadcasting (A + B.T) 0.000856 0.000068 0.001124 0.000241 -31.2%\n",
"3080x3080 Matrix Chain (A @ B @ C) 0.093115 0.003823 0.013649 0.000513 +85.3%\n",
"3080x3080 Element-wise (sin(A) + cos(B)) 0.022586 0.000756 0.001930 0.000287 +91.5%\n",
"3080x3080 Broadcasting (A + B.T) 0.002580 0.000161 0.001937 0.000257 +24.9%\n"
]
}
],
"source": [
"print(f\"\\nBenchmark Results over {iteration} repetitions:\")\n",
"print(results.to_string(index=False))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# # Additional timing analysis - separate compilation vs execution time\n",
"# if MLX_AVAILABLE:\n",
"# print(\"\\n=== Detailed MLX Timing Analysis ===\")\n",
" \n",
"# # Test with medium-sized matrix\n",
"# np.random.seed(42)\n",
"# A = np.random.randn(512, 512).astype(np.float32)\n",
"# B = np.random.randn(512, 512).astype(np.float32)\n",
"# C = np.random.randn(512, 512).astype(np.float32)\n",
" \n",
"# # Create PyTensor function (compilation time)\n",
"# start = time.perf_counter()\n",
"# pt_A = pt.matrix('A', dtype='float32')\n",
"# pt_B = pt.matrix('B', dtype='float32')\n",
"# pt_C = pt.matrix('C', dtype='float32')\n",
"# result_expr = pt_A @ pt_B @ pt_C\n",
"# f_mlx = function([pt_A, pt_B, pt_C], result_expr, mode=pytensor_mlx_mode)\n",
"# compilation_time = time.perf_counter() - start\n",
" \n",
"# # First execution (may include additional compilation/optimization)\n",
"# start = time.perf_counter()\n",
"# result = f_mlx(A, B, C)\n",
"# mx.eval(result) # Force evaluation\n",
"# first_exec_time = time.perf_counter() - start\n",
" \n",
"# # Subsequent executions (should be faster)\n",
"# exec_times = []\n",
"# for _ in range(1000):\n",
"# start = time.perf_counter()\n",
"# result = f_mlx(A, B, C)\n",
"# mx.eval(result)\n",
"# exec_times.append(time.perf_counter() - start)\n",
" \n",
"# avg_exec_time = np.mean(exec_times)\n",
"# std_exec_time = np.std(exec_times)\n",
" \n",
"# print(f\"Compilation time: {compilation_time:.4f}s\")\n",
"# print(f\"First execution: {first_exec_time:.4f}s\")\n",
"# print(f\"Average execution (5 runs): {avg_exec_time:.4f}s ± {std_exec_time:.4f}s\")\n",
"# print(f\"Individual execution times: {[f'{t:.4f}' for t in exec_times]}\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mlx_env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
...@@ -27,6 +27,7 @@ from pytensor.graph.rewriting.db import ( ...@@ -27,6 +27,7 @@ from pytensor.graph.rewriting.db import (
from pytensor.link.basic import Linker, PerformLinker from pytensor.link.basic import Linker, PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.jax.linker import JAXLinker from pytensor.link.jax.linker import JAXLinker
from pytensor.link.mlx.linker import MLXLinker
from pytensor.link.numba.linker import NumbaLinker from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.pytorch.linker import PytorchLinker from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.link.vm import VMLinker from pytensor.link.vm import VMLinker
...@@ -50,6 +51,7 @@ predefined_linkers = { ...@@ -50,6 +51,7 @@ predefined_linkers = {
"jax": JAXLinker(), "jax": JAXLinker(),
"pytorch": PytorchLinker(), "pytorch": PytorchLinker(),
"numba": NumbaLinker(), "numba": NumbaLinker(),
"mlx": MLXLinker(),
} }
...@@ -504,6 +506,20 @@ PYTORCH = Mode( ...@@ -504,6 +506,20 @@ PYTORCH = Mode(
), ),
) )
# Predefined compilation mode for the MLX backend: graphs are linked with
# `MLXLinker` and rewritten with the "fast_run" query, minus the rewrite tags
# listed in `exclude` below.
MLX = Mode(
MLXLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"scan_save_mem_prealloc",
],
),
)
predefined_modes = { predefined_modes = {
"FAST_COMPILE": FAST_COMPILE, "FAST_COMPILE": FAST_COMPILE,
...@@ -511,6 +527,7 @@ predefined_modes = { ...@@ -511,6 +527,7 @@ predefined_modes = {
"JAX": JAX, "JAX": JAX,
"NUMBA": NUMBA, "NUMBA": NUMBA,
"PYTORCH": PYTORCH, "PYTORCH": PYTORCH,
"MLX": MLX,
} }
_CACHED_RUNTIME_MODES: dict[str, Mode] = {} _CACHED_RUNTIME_MODES: dict[str, Mode] = {}
......
from pytensor.link.mlx.linker import MLXLinker
# isort: off
from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify
import pytensor.link.mlx.dispatch.math
import pytensor.link.mlx.dispatch.basic
import pytensor.link.mlx.dispatch.elemwise
import pytensor.link.mlx.dispatch.shape
import pytensor.link.mlx.dispatch.subtensor
import pytensor.link.mlx.dispatch.core
import pytensor.link.mlx.dispatch.signal
import pytensor.link.mlx.dispatch.signal.conv
import pytensor.link.mlx.dispatch.blockwise
# isort: on
import warnings
from copy import deepcopy
from functools import singledispatch
from types import NoneType
import mlx.core as mx
import numpy as np
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import Assert, CheckAndRaise
@singledispatch
def mlx_typify(data, **kwargs):
    """Convert ``data`` to the representation used by the MLX backend.

    This is the fallback of the single-dispatch function; it is only reached
    when no converter has been registered for ``type(data)``.

    Raises
    ------
    NotImplementedError
        Always, naming the unsupported type.
    """
    unsupported = type(data)
    raise NotImplementedError(f"mlx_typify is not implemented for {unsupported}")
@mlx_typify.register(np.ndarray)
def mlx_typify_tensor(data, dtype=None, **kwargs):
    """Wrap a NumPy array in an ``mx.array``, forwarding the optional ``dtype``."""
    converted = mx.array(data, dtype=dtype)
    return converted
@mlx_typify.register(mx.array)
@mlx_typify.register(NoneType)
@mlx_typify.register(slice)
def mlx_typify_no_conversion_needed(data, **kwargs):
    """Pass slices, ``None`` and existing ``mx.array`` values through unchanged."""
    return data
@mlx_typify.register(float)
@mlx_typify.register(int)
def mlx_typify_python_scalar(data, **kwargs):
    """Turn a Python ``int`` or ``float`` into an ``mx.array``."""
    as_array = mx.array(data)
    return as_array
@mlx_typify.register(np.bool_)
@mlx_typify.register(bool)
def mlx_typify_bool(data, **kwargs):
    """Normalize Python and NumPy booleans to a plain Python ``bool``."""
    flag = bool(data)
    return flag
@mlx_typify.register(np.complexfloating)
@mlx_typify.register(np.floating)
@mlx_typify.register(np.integer)
def mlx_typify_numpy_scalar(data, **kwargs):
    """Turn a NumPy integer, floating or complex scalar into an ``mx.array``."""
    as_array = mx.array(data)
    return as_array
@singledispatch
def mlx_funcify(op, node=None, storage_map=None, **kwargs):
    """Create an MLX-compatible function from a PyTensor `Op`.

    This is the fallback of the single-dispatch function; it is only reached
    when no implementation has been registered for ``type(op)``.

    Raises
    ------
    NotImplementedError
        Always, naming the unsupported `Op` and pointing at the tracking issue.
    """
    message = f"No MLX conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/1350` for progress or to request we prioritize this operation"
    raise NotImplementedError(message)
@mlx_funcify.register(FunctionGraph)
def mlx_funcify_FunctionGraph(
    fgraph,
    node=None,
    fgraph_name="mlx_funcified_fgraph",
    conversion_func=mlx_funcify,
    **kwargs,
):
    """Convert a whole `FunctionGraph` by delegating to `fgraph_to_python`.

    ``conversion_func`` is passed both positionally and inside the forwarded
    keyword arguments; ``mlx_typify`` is supplied as the type converter.
    """
    # Keep "conversion_func" as the first forwarded keyword, then the extras.
    forwarded = {"conversion_func": conversion_func}
    forwarded.update(kwargs)
    return fgraph_to_python(
        fgraph,
        conversion_func,
        type_conversion_fn=mlx_typify,
        fgraph_name=fgraph_name,
        **forwarded,
    )
@mlx_funcify.register(DeepCopyOp)
def mlx_funcify_DeepCopyOp(op, **kwargs):
    """Implement `DeepCopyOp` with :func:`copy.deepcopy`."""

    def deepcopyop(x):
        duplicate = deepcopy(x)
        return duplicate

    return deepcopyop
@mlx_funcify.register(Assert)
@mlx_funcify.register(CheckAndRaise)
def mlx_funcify_CheckAndRaise(op, node, **kwargs):
    """Replace a `CheckAndRaise`/`Assert` Op with a pass-through function.

    If any condition input (``node.inputs[1:]``) is a `Constant` whose value
    is falsy, ``op.exc_type(op.msg)`` is raised immediately. Otherwise a
    warning is emitted and the returned function ignores the conditions and
    returns its first argument.
    """
    for cond in node.inputs[1:]:
        if isinstance(cond, Constant) and not bool(cond.data):
            raise op.exc_type(op.msg)

    op_name = type(op).__name__
    warnings.warn(
        f"""Skipping `{op_name}` Op (assertion: {op.msg}) as MLX tracing would remove it.""",
        stacklevel=2,
    )

    def assert_fn(x, *inputs):
        return x

    return assert_fn
import mlx.core as mx
from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.blockwise import Blockwise
@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
    """Build the MLX implementation of a `Blockwise` Op.

    The core Op is converted with `mlx_funcify` and, when at least one input
    carries a non-broadcastable batch dimension, wrapped in ``mx.vmap`` mapping
    over axis 0 of those inputs. In every other case the core function is
    returned as-is.
    """
    # Convert the core Op using a dummy core node built from the inputs.
    core_fn = mlx_funcify(op.core_op, op._create_dummy_core_node(node.inputs))

    # No batch dimensions at all: nothing to vectorize.
    if op.batch_ndim(node) == 0:
        return core_fn

    # An input is mapped over axis 0 unless all of its batch dimensions are
    # broadcastable. An input without batch dimensions yields an empty slice,
    # for which `all` is True, so it is treated as static as well.
    axes = tuple(
        None if all(inp.type.broadcastable[: inp.type.ndim - len(sig)]) else 0
        for inp, sig in zip(node.inputs, op.inputs_sig)
    )

    if 0 not in axes:
        return core_fn

    return mx.vmap(core_fn, in_axes=axes)
import mlx.core as mx
import numpy as np
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
ExtractDiag,
Eye,
Join,
MakeVector,
ScalarFromTensor,
Split,
TensorFromScalar,
Tri,
get_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
# User-facing explanation embedded in the ValueError raised by
# `_rethrow_dynamic_shape_error` when MLX reports that an array was evaluated
# during a function transformation.
MLX_DYNAMIC_SHAPE_ERROR = (
"MLX compilation limitation: Alloc operations with dynamic shapes "
"cannot be used inside compiled functions. This is because MLX "
"compilation forbids evaluating arrays to extract shape values. "
"\n\nWorkarounds:"
"\n1. Avoid using Alloc with dynamic shapes in compiled contexts"
"\n2. Use static shapes when possible"
"\n3. Move Alloc operations outside compiled functions"
)
@mlx_funcify.register(Join)
def mlx_funcify_Join(op, **kwargs):
    """Implement `Join` with ``mx.concatenate``."""

    def join(axis, *tensors):
        joined = mx.concatenate(tensors, axis=axis)
        return joined

    return join
@mlx_funcify.register(Split)
def mlx_funcify_Split(op: Split, node, **kwargs):
    """Implement `Split` with ``mx.split``.

    The axis and the split sizes are resolved once, at conversion time, when
    they are graph constants; the runtime arguments are then ignored in favour
    of those constants. The returned function validates the split sizes and
    raises ``ValueError`` when their count differs from ``op.len_splits``,
    when they do not sum to the length of ``x`` along ``axis``, or when any of
    them is negative.
    """
    _, axis_var, splits_var = node.inputs

    # Constant axis, if the graph fixes it.
    try:
        static_axis = get_scalar_constant_value(axis_var)
    except NotScalarConstantError:
        static_axis = None

    # Constant split sizes, if every entry of the vector is a constant.
    try:
        n_entries = get_vector_length(splits_var)
        static_splits = np.array(
            [get_scalar_constant_value(splits_var[i]) for i in range(n_entries)]
        )
    except (ValueError, NotScalarConstantError):
        static_splits = None

    def split(x, axis, splits):
        if static_axis is not None:
            axis = int(static_axis)

        if static_splits is None:
            # Dynamic sizes: compute the cut points with MLX.
            boundaries = mx.cumsum(mx.array(splits)[:-1]).tolist()
        else:
            splits = static_splits
            boundaries = np.cumsum(splits[:-1])

        if len(splits) != op.len_splits:
            raise ValueError("Length of 'splits' is not equal to n_splits")

        sizes = np.asarray(splits)
        if np.sum(sizes) != x.shape[axis]:
            raise ValueError(
                "Split sizes do not sum to the input length on the chosen axis."
            )
        if np.any(sizes < 0):
            raise ValueError("Split sizes cannot be negative.")

        return mx.split(x, boundaries, axis=axis)

    return split
@mlx_funcify.register(ExtractDiag)
def mlx_funcify_ExtractDiag(op, **kwargs):
    """Implement `ExtractDiag` with ``mx.diagonal``.

    The Op's ``offset``, ``axis1`` and ``axis2`` are bound as default
    arguments of the returned function.
    """

    def extract_diag(x, offset=op.offset, axis1=op.axis1, axis2=op.axis2):
        diagonal = mx.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)
        return diagonal

    return extract_diag
@mlx_funcify.register(Eye)
def mlx_funcify_Eye(op, node, **kwargs):
    """Implement `Eye` with ``mx.eye``.

    Inputs that expose a ``data`` attribute at conversion time override the
    corresponding runtime arguments.
    """
    # `None` marks an input with no `data` attribute.
    known = [getattr(inp, "data", None) for inp in node.inputs]
    mlx_dtype = convert_dtype_to_mlx(op.dtype)

    def eye(*args):
        merged = []
        for runtime, fixed in zip(args, known, strict=True):
            merged.append(runtime if fixed is None else fixed)
        n_rows, n_cols, diag = merged
        return mx.eye(int(n_rows), int(n_cols), int(diag), dtype=mlx_dtype)

    return eye
def convert_dtype_to_mlx(dtype_str, auto_cast_unsupported=True):
    """Translate a PyTensor dtype string into the matching MLX dtype object.

    Parameters
    ----------
    dtype_str : str or MLX dtype
        The dtype to convert. Anything that is not a ``str``, and any string
        that is not recognized, is returned unchanged.
    auto_cast_unsupported : bool
        When True, ``"float64"`` is mapped to ``mx.float32`` and
        ``"complex128"`` to ``mx.complex64``, each with a ``UserWarning``.
        When False, ``"float64"`` maps to ``mx.float64`` and ``"complex128"``
        still maps to ``mx.complex64``, without a warning.

    Returns
    -------
    MLX dtype object, or ``dtype_str`` itself when no conversion applies.
    """
    import warnings

    if not isinstance(dtype_str, str):
        return dtype_str

    if dtype_str == "float64":
        if not auto_cast_unsupported:
            return mx.float64
        warnings.warn(
            "MLX does not support float64 on GPU. Automatically casting to float32. "
            "This may result in reduced precision. To avoid this warning, "
            "explicitly use float32 in your code or set floatX='float32' in PyTensor config.",
            UserWarning,
            stacklevel=3,
        )
        return mx.float32

    if dtype_str == "complex128":
        if auto_cast_unsupported:
            warnings.warn(
                "MLX does not support complex128. Automatically casting to complex64. "
                "This may result in reduced precision. To avoid this warning, "
                "explicitly use complex64 in your code.",
                UserWarning,
                stacklevel=3,
            )
        # MLX doesn't have complex128, so complex64 is the result either way.
        return mx.complex64

    # PyTensor dtype name -> attribute name on `mx`. The attribute is looked
    # up only for the requested dtype.
    attribute_names = {
        "bool": "bool_",
        "int8": "int8",
        "int16": "int16",
        "int32": "int32",
        "int64": "int64",
        "uint8": "uint8",
        "uint16": "uint16",
        "uint32": "uint32",
        "uint64": "uint64",
        "float16": "float16",
        "float32": "float32",
        "bfloat16": "bfloat16",
        "complex64": "complex64",
    }
    attribute = attribute_names.get(dtype_str)
    if attribute is None:
        return dtype_str
    return getattr(mx, attribute)
@mlx_funcify.register(MakeVector)
def mlx_funcify_MakeVector(op, **kwargs):
    """Implement `MakeVector` by packing the arguments into an ``mx.array``."""
    mlx_dtype = convert_dtype_to_mlx(op.dtype)

    def makevector(*x):
        vector = mx.array(x, dtype=mlx_dtype)
        return vector

    return makevector
@mlx_funcify.register(TensorFromScalar)
def mlx_funcify_TensorFromScalar(op, **kwargs):
    """Implement `TensorFromScalar` as the identity function."""

    def tensor_from_scalar(x):
        # The input is already an MLX array / scalar.
        return x

    return tensor_from_scalar
@mlx_funcify.register(ScalarFromTensor)
def mlx_funcify_ScalarFromTensor(op, **kwargs):
    """Implement `ScalarFromTensor` as the identity function."""

    def scalar_from_tensor(x):
        """Return ``x`` unchanged: a scalar cannot be returned in MLX without triggering evaluation."""
        return x

    return scalar_from_tensor
@mlx_funcify.register(Tri)
def mlx_funcify_Tri(op, node, **kwargs):
    """Implement `Tri` with ``mx.tri``.

    The node inputs are N, M and k. Inputs that expose a ``data`` attribute at
    conversion time override the corresponding runtime arguments.
    """
    # `None` marks an input with no `data` attribute.
    known = [getattr(inp, "data", None) for inp in node.inputs]
    mlx_dtype = convert_dtype_to_mlx(op.dtype)

    def tri(*args):
        merged = []
        for runtime, fixed in zip(args, known, strict=True):
            merged.append(runtime if fixed is None else fixed)
        return mx.tri(*merged, dtype=mlx_dtype)

    return tri
@mlx_funcify.register(AllocEmpty)
def mlx_funcify_AllocEmpty(op, node, **kwargs):
    """Implement `AllocEmpty` by allocating a zero-filled ``mx.array``.

    Every input of the node is a shape component (there is no leading value
    input as in `Alloc`), so constant dimensions are looked up for all of
    ``node.inputs``. The previous ``len(node.inputs) > 1`` guard skipped that
    lookup for one-dimensional allocations and always coerced the runtime
    value instead.

    Dimensions that are not graph constants are converted from the runtime
    arguments with `_coerce_to_int`.
    """
    dtype = convert_dtype_to_mlx(op.dtype)
    node_inputs = node.inputs
    # With no shape inputs there is nothing to pre-resolve.
    static_dims = _extract_static_dims(node_inputs) if node_inputs else None

    def allocempty(*shape):
        if static_dims is None:
            resolved_shape = tuple(_coerce_to_int(dim) for dim in shape)
        else:
            resolved_shape = _resolve_shape(static_dims, shape)
        return mx.zeros(resolved_shape, dtype=dtype)

    return allocempty
@mlx_funcify.register(Alloc)
def mlx_funcify_Alloc(op, node, **kwargs):
    """Implement `Alloc` with ``mx.broadcast_to``.

    ``node.inputs[0]`` is the value and the remaining inputs are the target
    shape. Shape components that are graph constants are resolved at
    conversion time; the rest are converted from the runtime arguments with
    `_coerce_to_int`. After broadcasting, `Alloc._check_runtime_broadcast` is
    called with the value and the resolved shape.
    """
    node_inputs = node.inputs
    if node_inputs and len(node_inputs) > 1:
        static_dims = _extract_static_dims(node_inputs[1:])
    else:
        static_dims = None

    def alloc(x, *shape):
        if static_dims is None:
            target = tuple(_coerce_to_int(dim) for dim in shape)
        else:
            target = _resolve_shape(static_dims, shape)

        out = mx.broadcast_to(x, target)

        if node_inputs is not None:
            # The runtime check is handed a NumPy view when `x` has no shape.
            probe = x if hasattr(x, "shape") else np.asarray(x)
            Alloc._check_runtime_broadcast(node, probe, target)

        return out

    return alloc
def _extract_static_dims(shape_inputs):
    """Return a tuple holding, per shape input, its constant ``int`` value or ``None``.

    ``None`` marks an input for which `get_scalar_constant_value` raises
    `NotScalarConstantError`.
    """

    def as_static(dim):
        try:
            return int(get_scalar_constant_value(dim))
        except NotScalarConstantError:
            return None

    return tuple(as_static(dim) for dim in shape_inputs)
def _resolve_shape(static_dims, runtime_shape):
if len(static_dims) != len(runtime_shape):
raise ValueError("Alloc received unexpected number of shape dimensions")
resolved = []
for const_dim, dim in zip(static_dims, runtime_shape, strict=True):
resolved.append(const_dim if const_dim is not None else _coerce_to_int(dim))
return tuple(resolved)
def _coerce_to_int(value):
if isinstance(value, np.integer | int):
return int(value)
try:
if hasattr(value, "item"):
return int(value.item())
return int(value)
except (ValueError, TypeError) as exc:
_rethrow_dynamic_shape_error(exc)
raise
raise TypeError(
"MLX Alloc expects integer shape components; got value of type "
f"{type(value).__name__}."
)
def _rethrow_dynamic_shape_error(exc):
msg = str(exc)
if "[eval] Attempting to eval an array during function transformations" in msg:
raise ValueError(f"{MLX_DYNAMIC_SHAPE_ERROR}\n\nOriginal error: {msg}") from exc
from functools import singledispatch
import mlx.core as mx
import numpy as np
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx
from pytensor.scalar.basic import (
AND,
EQ,
GE,
GT,
LE,
LT,
NEQ,
OR,
Abs,
Add,
Cast,
Cos,
Exp,
IntDiv,
Invert,
IsInf,
IsNan,
Log,
Log1p,
Mul,
Neg,
Pow,
ScalarMaximum,
ScalarMinimum,
Sign,
Sin,
Sqr,
Sqrt,
Sub,
Switch,
TrueDiv,
)
from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.special import Softmax, SoftmaxGrad
@mlx_funcify.register(DimShuffle)
def mlx_funcify_DimShuffle(op, **kwargs):
    """Build the MLX callable for a ``DimShuffle`` op.

    The callable transposes by ``op.transposition``, keeps the first
    ``len(op.shuffle)`` dimensions and inserts length-1 axes at the positions
    listed in ``op.augment``.
    """

    def dimshuffle(x):
        # Python and NumPy scalars are wrapped so the MLX calls below work.
        is_python_scalar = isinstance(x, int | float)
        is_numpy_scalar = isinstance(x, np.number) and not isinstance(x, np.ndarray)
        if is_python_scalar or is_numpy_scalar:
            x = mx.array(x)

        transposed = mx.transpose(x, op.transposition)
        out_shape = list(transposed.shape[: len(op.shuffle)])
        for new_axis in op.augment:
            out_shape.insert(new_axis, 1)
        return mx.reshape(transposed, out_shape)

    return dimshuffle
# Second-level dispatch for scalar operations in CAReduce
@singledispatch
def mlx_funcify_CAReduce_scalar_op(scalar_op, axis):
    """Fallback of the per-scalar-op CAReduce dispatch: always unsupported."""
    message = f"MLX does not support CAReduce with scalar op {scalar_op}"
    raise NotImplementedError(message)
@mlx_funcify.register(CAReduce)
def mlx_funcify_CAReduce(op, **kwargs):
    """Delegate a ``CAReduce`` op to the dispatcher keyed on its scalar op."""
    scalar_op, axis = op.scalar_op, op.axis
    return mlx_funcify_CAReduce_scalar_op(scalar_op, axis)
@mlx_funcify_CAReduce_scalar_op.register(Add)
def mlx_funcify_CAReduce_scalar_Add(scalar_op, axis):
    """Reduce with ``Add``: sum over ``axis`` via ``mx.sum``."""

    def sum_reduce(x):
        summed = mx.sum(x, axis=axis)
        return summed

    return sum_reduce
@mlx_funcify_CAReduce_scalar_op.register(Mul)
def mlx_funcify_CAReduce_scalar_Mul(scalar_op, axis):
    """Reduce with ``Mul``: product over ``axis`` via ``mx.prod``."""

    def prod_reduce(x):
        product = mx.prod(x, axis=axis)
        return product

    return prod_reduce
@mlx_funcify_CAReduce_scalar_op.register(AND)
def mlx_funcify_CAReduce_scalar_AND(scalar_op, axis):
    """Reduce with ``AND``: call the input's ``all`` method over ``axis``."""

    def all_reduce(x):
        every = x.all(axis=axis)
        return every

    return all_reduce
@mlx_funcify_CAReduce_scalar_op.register(OR)
def mlx_funcify_CARreduce_OR(scalar_op, axis):
    """Reduce with ``OR``: ``mx.any`` over ``axis``."""

    def any_reduce(x):
        some = mx.any(x, axis=axis)
        return some

    return any_reduce
@mlx_funcify_CAReduce_scalar_op.register(ScalarMaximum)
def mlx_funcify_CARreduce_Maximum(scalar_op, axis):
    """Reduce with ``ScalarMaximum``: ``mx.max`` over ``axis``."""

    def max_reduce(x):
        largest = mx.max(x, axis=axis)
        return largest

    return max_reduce
@mlx_funcify_CAReduce_scalar_op.register(ScalarMinimum)
def mlx_funcify_CARreduce_Minimum(scalar_op, axis):
    """Reduce with ``ScalarMinimum``: ``mx.min`` over ``axis``."""

    def min_reduce(x):
        smallest = mx.min(x, axis=axis)
        return smallest

    return min_reduce
@mlx_funcify.register(Softmax)
def mlx_funcify_Softmax(op, **kwargs):
    """Build the MLX callable for ``Softmax`` along ``op.axis``."""
    softmax_axis = op.axis

    def softmax(x):
        return mx.softmax(x, axis=softmax_axis)

    return softmax
@mlx_funcify.register(SoftmaxGrad)
def mlx_funcify_SoftmaxGrad(op, **kwargs):
    """Build the MLX callable for ``SoftmaxGrad`` along ``op.axis``.

    The callable takes ``dy`` and ``sm`` and returns
    ``dy * sm - sum(dy * sm, axis, keepdims=True) * sm``.
    """
    reduce_axis = op.axis

    def softmax_grad(dy, sm):
        weighted = dy * sm
        total = mx.sum(weighted, axis=reduce_axis, keepdims=True)
        return weighted - total * sm

    return softmax_grad
@mlx_funcify.register(Softplus)
def mlx_funcify_Softplus(op, **kwargs):
    """Build a piecewise softplus callable.

    Branches, selected elementwise with ``mx.where``:
    ``x < -37.0`` -> ``exp(x)``; ``x < 18.0`` -> ``log1p(exp(x))``;
    ``x < 33.3`` -> ``x + exp(-x)``; otherwise ``x``.
    """

    def softplus(x):
        upper_branch = mx.where(x < 33.3, x + mx.exp(-x), x)
        middle_branch = mx.where(x < 18.0, mx.log1p(mx.exp(x)), upper_branch)
        return mx.where(x < -37.0, mx.exp(x), middle_branch)

    return softplus
@mlx_funcify.register(Cast)
def mlx_funcify_Cast(op, **kwargs):
    """Build the MLX callable that casts its input to the op's output dtype.

    ``Cast`` is a scalar op that carries ``o_type`` itself, so the output type
    is read from ``op`` directly. An op exposing a ``scalar_op`` attribute is
    still honoured for backward compatibility.

    The returned callable first tries the exact dtype. If MLX rejects it with a
    "is not supported on the GPU" ``ValueError``, a ``UserWarning`` is issued
    and the cast is retried with the auto-cast fallback dtype. Any other
    ``ValueError`` propagates unchanged.
    """
    # Previously `op.scalar_op.o_type` was read unconditionally, which fails for
    # a plain scalar `Cast` (it has `o_type` but no `scalar_op`).
    scalar_op = getattr(op, "scalar_op", op)

    def cast(x):
        dtype = convert_dtype_to_mlx(scalar_op.o_type.dtype)
        try:
            return x.astype(dtype)
        except ValueError as e:
            if "is not supported on the GPU" not in str(e):
                # Unrelated failure: re-raise with the original traceback.
                raise

            # MLX GPU limitation - try auto-casting with warning
            import warnings

            warnings.warn(
                f"MLX GPU limitation: {e}. Attempting automatic fallback casting.",
                UserWarning,
                stacklevel=2,
            )
            fallback_dtype = convert_dtype_to_mlx(
                scalar_op.o_type.dtype, auto_cast_unsupported=True
            )
            return x.astype(fallback_dtype)

    return cast
@singledispatch
def mlx_funcify_Elemwise_scalar_op(scalar_op):
    """Generic fallback: look up an ``mx`` function named like the scalar op.

    If ``scalar_op.name`` names an attribute of ``mx``, that function is
    returned. When the scalar op has an ``inputs`` attribute with more than two
    entries, a wrapper folding the function left-to-right over all arguments is
    returned instead. Otherwise ``NotImplementedError`` is raised.
    """
    op_name = getattr(scalar_op, "name", None)
    if op_name is not None:
        try:
            mlx_func = getattr(mx, op_name)
            is_variadic = hasattr(scalar_op, "inputs") and len(scalar_op.inputs) > 2
            if not is_variadic:
                return mlx_func

            # Handle variadic functions like Add
            def variadic_func(*args):
                folded = args[0]
                for operand in args[1:]:
                    folded = mlx_func(folded, operand)
                return folded

            return variadic_func
        except AttributeError:
            pass

    raise NotImplementedError(f"MLX does not support Elemwise scalar op {scalar_op}")
@mlx_funcify_Elemwise_scalar_op.register(Add)
def mlx_funcify_Elemwise_scalar_Add(scalar_op):
    """Variadic addition: fold ``mx.add`` left-to-right over all arguments."""

    def add(*args):
        total = args[0]
        for term in args[1:]:
            total = mx.add(total, term)
        return total

    return add
@mlx_funcify_Elemwise_scalar_op.register(Sub)
def mlx_funcify_Elemwise_scalar_Sub(scalar_op):
    """Map scalar ``Sub`` to ``mx.subtract``."""
    mlx_impl = mx.subtract
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Mul)
def mlx_funcify_Elemwise_scalar_Mul(scalar_op):
    """Variadic multiplication: fold ``mx.multiply`` left-to-right over all arguments."""

    def mul(*args):
        product = args[0]
        for factor in args[1:]:
            product = mx.multiply(product, factor)
        return product

    return mul
@mlx_funcify_Elemwise_scalar_op.register(TrueDiv)
def mlx_funcify_Elemwise_scalar_TrueDiv(scalar_op):
    """Map scalar ``TrueDiv`` to ``mx.divide``."""
    mlx_impl = mx.divide
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(IntDiv)
def mlx_funcify_Elemwise_scalar_IntDiv(scalar_op):
    """Map scalar ``IntDiv`` to ``mx.floor_divide``."""
    mlx_impl = mx.floor_divide
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Pow)
def mlx_funcify_Elemwise_scalar_Pow(scalar_op):
    """Map scalar ``Pow`` to ``mx.power``."""
    mlx_impl = mx.power
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Exp)
def mlx_funcify_Elemwise_scalar_Exp(scalar_op):
    """Map scalar ``Exp`` to ``mx.exp``."""
    mlx_impl = mx.exp
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Log)
def mlx_funcify_Elemwise_scalar_Log(scalar_op):
    """Map scalar ``Log`` to ``mx.log``."""
    mlx_impl = mx.log
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Log1p)
def mlx_funcify_Elemwise_scalar_Log1p(scalar_op):
    """Map scalar ``Log1p`` to ``mx.log1p``."""
    mlx_impl = mx.log1p
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Sin)
def mlx_funcify_Elemwise_scalar_Sin(scalar_op):
    """Map scalar ``Sin`` to ``mx.sin``."""
    mlx_impl = mx.sin
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Cos)
def mlx_funcify_Elemwise_scalar_Cos(scalar_op):
    """Map scalar ``Cos`` to ``mx.cos``."""
    mlx_impl = mx.cos
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Sqrt)
def mlx_funcify_Elemwise_scalar_Sqrt(scalar_op):
    """Map scalar ``Sqrt`` to ``mx.sqrt``."""
    mlx_impl = mx.sqrt
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Sqr)
def mlx_funcify_Elemwise_scalar_Sqr(scalar_op):
    """Map scalar ``Sqr`` to ``mx.square``."""
    mlx_impl = mx.square
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Abs)
def mlx_funcify_Elemwise_scalar_Abs(scalar_op):
    """Map scalar ``Abs`` to ``mx.abs``."""
    mlx_impl = mx.abs
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Neg)
def mlx_funcify_Elemwise_scalar_Neg(scalar_op):
    """Map scalar ``Neg`` to ``mx.negative``."""
    mlx_impl = mx.negative
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Sign)
def mlx_funcify_Elemwise_scalar_Sign(scalar_op):
    """Map scalar ``Sign`` to ``mx.sign``."""
    mlx_impl = mx.sign
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(LE)
def mlx_funcify_Elemwise_scalar_LE(scalar_op):
    """Map scalar ``LE`` to ``mx.less_equal``."""
    mlx_impl = mx.less_equal
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(LT)
def mlx_funcify_Elemwise_scalar_LT(scalar_op):
    """Map scalar ``LT`` to ``mx.less``."""
    mlx_impl = mx.less
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(GE)
def mlx_funcify_Elemwise_scalar_GE(scalar_op):
    """Map scalar ``GE`` to ``mx.greater_equal``."""
    mlx_impl = mx.greater_equal
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(GT)
def mlx_funcify_Elemwise_scalar_GT(scalar_op):
    """Map scalar ``GT`` to ``mx.greater``."""
    mlx_impl = mx.greater
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(EQ)
def mlx_funcify_Elemwise_scalar_EQ(scalar_op):
    """Map scalar ``EQ`` to ``mx.equal``."""
    mlx_impl = mx.equal
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(NEQ)
def mlx_funcify_Elemwise_scalar_NEQ(scalar_op):
    """Map scalar ``NEQ`` to ``mx.not_equal``."""
    mlx_impl = mx.not_equal
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Switch)
def mlx_funcify_Elemwise_scalar_Switch(scalar_op):
    """Map scalar ``Switch`` to ``mx.where``."""
    mlx_impl = mx.where
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(AND)
def mlx_funcify_Elemwise_scalar_AND(scalar_op):
    """Map scalar ``AND`` to ``mx.bitwise_and``."""
    mlx_impl = mx.bitwise_and
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(OR)
def mlx_funcify_Elemwise_scalar_OR(scalar_op):
    """Map scalar ``OR`` to ``mx.bitwise_or``."""
    mlx_impl = mx.bitwise_or
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(ScalarMaximum)
def mlx_funcify_Elemwise_scalar_ScalarMaximum(scalar_op):
    """Map scalar ``ScalarMaximum`` to ``mx.maximum``."""
    mlx_impl = mx.maximum
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(ScalarMinimum)
def mlx_funcify_Elemwise_scalar_ScalarMinimum(scalar_op):
    """Map scalar ``ScalarMinimum`` to ``mx.minimum``."""
    mlx_impl = mx.minimum
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Cast)
def mlx_funcify_Elemwise_scalar_Cast(scalar_op):
    """Build the elementwise cast callable for a scalar ``Cast`` op.

    The callable converts its input to ``scalar_op.o_type.dtype``. If MLX
    raises a ``ValueError`` mentioning "is not supported on the GPU", a
    ``UserWarning`` is emitted and the cast is retried with the auto-cast
    fallback dtype. Any other ``ValueError`` propagates unchanged.
    """

    def cast(x):
        dtype = convert_dtype_to_mlx(scalar_op.o_type.dtype)
        try:
            return x.astype(dtype)
        except ValueError as e:
            if "is not supported on the GPU" not in str(e):
                # Bare `raise` (rather than `raise e`) keeps the traceback
                # untouched, matching the sibling `mlx_funcify_Cast` handler.
                raise

            import warnings

            warnings.warn(
                f"MLX GPU limitation: {e}. Attempting automatic fallback casting.",
                UserWarning,
                stacklevel=2,
            )
            fallback_dtype = convert_dtype_to_mlx(
                scalar_op.o_type.dtype, auto_cast_unsupported=True
            )
            return x.astype(fallback_dtype)

    return cast
@mlx_funcify_Elemwise_scalar_op.register(Sigmoid)
def mlx_funcify_Elemwise_scalar_Sigmoid(scalar_op):
    """Map scalar ``Sigmoid`` to ``mx.sigmoid``."""
    mlx_impl = mx.sigmoid
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Invert)
def mlx_funcify_Elemwise_scalar_Invert(scalar_op):
    """Map scalar ``Invert`` to ``mx.bitwise_invert``."""
    mlx_impl = mx.bitwise_invert
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(IsNan)
def mlx_funcify_Elemwise_scalar_IsNan(scalar_op):
    """Map scalar ``IsNan`` to ``mx.isnan``."""
    mlx_impl = mx.isnan
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(IsInf)
def mlx_funcify_Elemwise_scalar_IsInf(scalar_op):
    """Map scalar ``IsInf`` to ``mx.isinf``."""
    mlx_impl = mx.isinf
    return mlx_impl
@mlx_funcify_Elemwise_scalar_op.register(Erfc)
def mlx_funcify_Elemwise_scalar_Erfc(scalar_op):
    """Implement ``Erfc`` as ``1.0 - mx.erf(x)``."""

    def erfc(x):
        erf_value = mx.erf(x)
        return 1.0 - erf_value

    return erfc
@mlx_funcify_Elemwise_scalar_op.register(Erfcx)
def mlx_funcify_Elemwise_scalar_Erfcx(scalar_op):
    """Implement ``Erfcx`` as ``exp(x * x) * (1.0 - erf(x))``."""

    def erfcx(x):
        scale = mx.exp(x * x)
        complement = 1.0 - mx.erf(x)
        return scale * complement

    return erfcx
@mlx_funcify_Elemwise_scalar_op.register(Softplus)
def mlx_funcify_Elemwise_scalar_softplus(scalar_op):
    """Build the elementwise softplus callable, i.e. ``log(1 + exp(x))``.

    The result is assembled piecewise with ``mx.where`` using the thresholds
    -37.0, 18.0 and 33.3.
    """

    def softplus(x):
        # Numerically stable implementation of log(1 + exp(x)),
        # evaluated from the outermost range inwards.
        upper_branch = mx.where(x < 33.3, x + mx.exp(-x), x)
        middle_branch = mx.where(x < 18.0, mx.log1p(mx.exp(x)), upper_branch)
        return mx.where(x < -37.0, mx.exp(x), middle_branch)

    return softplus
@mlx_funcify.register(Elemwise)
def mlx_funcify_Elemwise(op, node, **kwargs):
    """Delegate an ``Elemwise`` op to the dispatcher keyed on its scalar op."""
    scalar_op = op.scalar_op
    return mlx_funcify_Elemwise_scalar_op(scalar_op)
import mlx.core as mx
from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.math import Argmax, Dot, Max
@mlx_funcify.register(Dot)
def mlx_funcify_Dot(op, node=None, **kwargs):
    """Build the MLX callable for ``Dot`` using ``mx.matmul``."""

    def dot(x, y):
        product = mx.matmul(x, y)
        return product

    return dot
@mlx_funcify.register(Max)
def mlx_funcify_Max(op, node=None, **kwargs):
    """Build the MLX callable for ``Max``.

    ``op.axis`` of ``None`` reduces over every axis; otherwise its entries are
    converted to a tuple of ints. ``keepdims`` is read from the op and defaults
    to ``False`` when the attribute is absent.
    """

    def max_fn(x):
        axes = op.axis
        reduce_axes = None if axes is None else tuple(int(ax) for ax in axes)
        return mx.max(x, axis=reduce_axes, keepdims=getattr(op, "keepdims", False))

    return max_fn
@mlx_funcify.register(Argmax)
def mlx_funcify_Argmax(op, node=None, **kwargs):
    """Build the MLX callable for ``Argmax`` over one or several axes.

    The reduced axes are moved to the end and flattened into a single trailing
    axis, ``mx.argmax`` is taken over that axis, and the result is cast to
    ``mx.int64``. When the op has a truthy ``keepdims`` attribute, the reduced
    axes are re-inserted with length 1.
    """
    axis = op.axis

    def argmax_fn(x):
        ndim = x.ndim
        axes = tuple(range(ndim)) if axis is None else tuple(int(ax) for ax in axis)

        # Kept axes first, reduced axes last.
        keep_axes = [i for i in range(ndim) if i not in axes]
        transposed_x = mx.transpose(x, tuple(keep_axes + list(axes)))

        n_kept = len(keep_axes)
        kept_shape = transposed_x.shape[:n_kept]
        flat_size = 1
        for dim in transposed_x.shape[n_kept:]:
            flat_size *= int(dim)

        flattened = transposed_x.reshape((*kept_shape, flat_size))
        result = mx.argmax(flattened, axis=-1).astype(mx.int64)

        if not getattr(op, "keepdims", False):
            return result

        # Rebuild the full-rank shape: 1 on reduced axes, kept sizes elsewhere.
        reduced = set(axes)
        remaining = iter(kept_shape)
        reshape_shape = [
            1 if dim_idx in reduced else int(next(remaining))
            for dim_idx in range(ndim)
        ]
        return result.reshape(tuple(reshape_shape))

    return argmax_fn
import mlx.core as mx
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@mlx_funcify.register(Shape)
def mlx_funcify_Shape(op, **kwargs):
    """Build the MLX callable returning an input's shape as an ``int64`` array."""

    def shape(x):
        dims = x.shape
        return mx.array(dims, dtype=mx.int64)

    return shape
@mlx_funcify.register(SpecifyShape)
def mlx_funcify_SpecifyShape(op, node, **kwargs):
    """Build the MLX callable for ``SpecifyShape``.

    The callable returns its input unchanged after checking every dimension
    against the expected one. ``None`` entries are not checked; a mismatch
    raises ``ValueError``.
    """

    def specifyshape(x, *shape):
        assert x.ndim == len(shape)
        for actual, expected in zip(x.shape, shape, strict=True):
            if expected is not None and actual != expected:
                raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}")
        return x

    return specifyshape
@mlx_funcify.register(Shape_i)
def mlx_funcify_Shape_i(op, node, **kwargs):
    """Build the MLX callable returning the size of dimension ``op.i``."""

    def shape_i(x):
        dims = x.shape
        return dims[op.i]

    return shape_i
@mlx_funcify.register(Reshape)
def mlx_funcify_Reshape(op, **kwargs):
    """Build the MLX callable for ``Reshape`` using ``mx.reshape``."""

    def reshape(x, shp):
        reshaped = mx.reshape(x, shp)
        return reshaped

    return reshape
import mlx.core as mx
from pytensor.link.mlx.dispatch import mlx_funcify, mlx_typify
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.signal.conv import Convolve1d
@mlx_funcify.register(Convolve1d)
def mlx_funcify_Convolve1d(op, node, **kwargs):
    """Build the MLX callable for ``Convolve1d``.

    The third node input selects the convolution mode. If it is a scalar
    constant, the mode is fixed here; otherwise it is read from the runtime
    argument on every call. A truthy value selects ``"full"``, a falsy one
    ``"valid"``.
    """
    _, _, full_mode_var = node.inputs

    try:
        static_full_mode = bool(get_underlying_scalar_constant_value(full_mode_var))
    except NotScalarConstantError:
        # Not known at graph-construction time; decide per call.
        static_full_mode = None

    def conv1d(raw_data, raw_kernel, runtime_full_mode):
        data = mlx_typify(raw_data, dtype=None)
        kernel = mlx_typify(raw_kernel, dtype=None)

        if static_full_mode is None:
            flag = mx.array(runtime_full_mode)
            use_full = bool(flag.reshape(-1)[0])
        else:
            use_full = static_full_mode

        return mx.convolve(data, kernel, mode="full" if use_full else "valid")

    return conv1d
from copy import deepcopy
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice
@mlx_funcify.register(Subtensor)
def mlx_funcify_Subtensor(op, node, **kwargs):
    """Build the MLX callable for basic ``Subtensor`` indexing.

    Runtime index arguments are converted to Python ints and combined with
    ``op.idx_list`` through ``indices_from_subtensor``.
    """
    idx_list = getattr(op, "idx_list", None)

    def subtensor(x, *ilists):
        int_indices = [int(element) for element in ilists]
        indices = indices_from_subtensor(int_indices, idx_list)
        # A single index is passed bare rather than as a 1-tuple.
        key = indices[0] if len(indices) == 1 else indices
        return x[key]

    return subtensor
@mlx_funcify.register(AdvancedSubtensor)
@mlx_funcify.register(AdvancedSubtensor1)
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
    """Build the MLX callable for advanced indexing ops.

    Runtime index arguments are combined with ``op.idx_list`` (``None`` when
    the op has no such attribute) through ``indices_from_subtensor``.
    """
    idx_list = getattr(op, "idx_list", None)

    def advanced_subtensor(x, *ilists):
        indices = indices_from_subtensor(ilists, idx_list)
        # A single index is passed bare rather than as a 1-tuple.
        key = indices[0] if len(indices) == 1 else indices
        return x[key]

    return advanced_subtensor
@mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs):
    """Build the MLX callable for ``IncSubtensor`` / ``AdvancedIncSubtensor1``.

    With ``op.set_instead_of_inc`` the selected entries are overwritten with
    ``y``; otherwise ``y`` is added to them. The returned callable has the
    signature ``(x, y, *ilist)`` and returns the updated array.
    """
    idx_list = getattr(op, "idx_list", None)

    if getattr(op, "set_instead_of_inc", False):

        def mlx_fn(x, indices, y):
            if not op.inplace:
                x = deepcopy(x)
            x[indices] = y
            return x

    else:

        def mlx_fn(x, indices, y):
            if hasattr(x, "at"):
                # `x[idx] += y` is an assignment of `x[idx] + y`, so repeated
                # indices would keep only one contribution. `x.at[idx].add(y)`
                # accumulates all of them and returns a new array, which also
                # makes the defensive copy unnecessary.
                return x.at[indices].add(y)
            # Fallback for array types without `.at`: keep the old behaviour.
            if not op.inplace:
                x = deepcopy(x)
            x[indices] += y
            return x

    def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
        indices = indices_from_subtensor(ilist, idx_list)
        # A single index is passed bare rather than as a 1-tuple.
        if len(indices) == 1:
            indices = indices[0]
        return mlx_fn(x, indices, y)

    return incsubtensor
@mlx_funcify.register(AdvancedIncSubtensor)
def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
    """Build the MLX callable for ``AdvancedIncSubtensor``.

    With ``op.set_instead_of_inc`` the selected entries are overwritten with
    ``y``; otherwise ``y`` is added to them. The returned callable has the
    signature ``(x, y, *ilist)``; the tuple ``ilist`` is used as the index.
    """
    if getattr(op, "set_instead_of_inc", False):

        def mlx_fn(x, indices, y):
            if not op.inplace:
                x = deepcopy(x)
            x[indices] = y
            return x

    else:

        def mlx_fn(x, indices, y):
            if hasattr(x, "at"):
                # `x[idx] += y` is an assignment of `x[idx] + y`, so repeated
                # indices would keep only one contribution. `x.at[idx].add(y)`
                # accumulates all of them and returns a new array, which also
                # makes the defensive copy unnecessary.
                return x.at[indices].add(y)
            # Fallback for array types without `.at`: keep the old behaviour.
            if not op.inplace:
                x = deepcopy(x)
            x[indices] += y
            return x

    def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn):
        return mlx_fn(x, ilist, y)

    return advancedincsubtensor
@mlx_funcify.register(MakeSlice)
def mlx_funcify_MakeSlice(op, **kwargs):
    """Build the callable for ``MakeSlice``: forwards its arguments to ``slice``."""

    def makeslice(*x):
        built = slice(*x)
        return built

    return makeslice
from pytensor.link.basic import JITLinker
class MLXLinker(JITLinker):
    """A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX."""

    def __init__(self, use_compile=True, *args, **kwargs):
        """Store the ``use_compile`` switch; other arguments go to ``JITLinker``."""
        super().__init__(*args, **kwargs)
        self.use_compile = use_compile
        self.gen_functors = []

    def fgraph_convert(self, fgraph, **kwargs):
        """Convert a PyTensor FunctionGraph to an MLX-compatible function.

        Parameters
        ----------
        fgraph : FunctionGraph
            The function graph to convert

        Returns
        -------
        callable
            An MLX-compatible function
        """
        # Imported lazily, inside the method, as in the original code.
        from pytensor.link.mlx.dispatch import mlx_funcify

        return mlx_funcify(fgraph, **kwargs)

    def jit_compile(self, fn):
        """Wrap ``fn`` so its inputs are MLX-typified, optionally compiling it.

        When ``self.use_compile`` is false the wrapper calls ``fn`` directly;
        otherwise it calls ``mx.compile(fn)``.
        """
        import mlx.core as mx

        from pytensor.link.mlx.dispatch import mlx_typify

        def typified(inputs):
            return (mlx_typify(inp) for inp in inputs)

        if not self.use_compile:
            # Skip compilation and just return the function with MLX typification
            def fn_no_compile(*inputs):
                return fn(*typified(inputs))

            return fn_no_compile

        inner_fn = mx.compile(fn)

        def fn(*inputs, inner_fn=inner_fn):
            return inner_fn(*typified(inputs))

        return fn

    def create_thunk_inputs(self, storage_map):
        """Create inputs for the MLX thunk.

        Parameters
        ----------
        storage_map : dict
            Map from variables to their storage

        Returns
        -------
        list
            The inputs for the thunk
        """
        return [storage_map[graph_input] for graph_input in self.fgraph.inputs]
...@@ -31,13 +31,15 @@ class PytorchLinker(JITLinker): ...@@ -31,13 +31,15 @@ class PytorchLinker(JITLinker):
**kwargs, **kwargs,
} }
return pytorch_funcify( return pytorch_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **built_kwargs fgraph,
input_storage=input_storage,
storage_map=storage_map,
**built_kwargs,
) )
def jit_compile(self, fn): def jit_compile(self, fn):
import torch import torch
# flag that tend to help our graphs
torch._dynamo.config.capture_dynamic_output_shape_ops = True torch._dynamo.config.capture_dynamic_output_shape_ops = True
from pytensor.link.pytorch.dispatch import pytorch_typify from pytensor.link.pytorch.dispatch import pytorch_typify
......
"""
Basic tests for the MLX backend.
"""
from collections.abc import Callable, Iterable
from functools import partial
import numpy as np
import pytest
import pytensor
from pytensor import tensor as pt
from pytensor.compile.function import function
from pytensor.compile.mode import MLX, Mode
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Variable
from pytensor.link.mlx import MLXLinker
from pytensor.raise_op import assert_op
# Skip this whole test module when MLX cannot be imported.
mx = pytest.importorskip("mlx.core")
# Rewrites tagged "mlx", minus whatever the stock MLX mode excludes.
optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=MLX._optimizer.exclude)
# MLX linker with its default settings.
mlx_mode = Mode(linker=MLXLinker(), optimizer=optimizer)
# MLX linker with `mx.compile` disabled.
mlx_mode_no_compile = Mode(linker=MLXLinker(use_compile=False), optimizer=optimizer)
# MLX linker with `mx.compile` explicitly enabled.
compile_mode = Mode(linker=MLXLinker(use_compile=True), optimizer=optimizer)
# Reference mode: Python linker without graph rewrites.
py_mode = Mode(linker="py", optimizer=None)
def compare_mlx_and_py(
    graph_inputs: Iterable[Variable],
    graph_outputs: Variable | Iterable[Variable],
    test_inputs: Iterable,
    *,
    assert_fn: Callable | None = None,
    must_be_device_array: bool = True,
    mlx_mode=mlx_mode,
    py_mode=py_mode,
):
    """Compile a graph with the MLX and Python modes and compare the results.

    Parameters
    ----------
    graph_inputs
        Symbolic inputs to the graph. They must be root variables.
    graph_outputs
        Symbolic output, or list/tuple of outputs, of the graph.
    test_inputs
        Numerical inputs passed to both compiled functions.
    assert_fn
        Callable ``assert_fn(mlx_result, py_result)`` used for the comparison.
        Defaults to ``np.testing.assert_allclose`` with ``rtol=1e-4``.
    must_be_device_array
        If true, assert that every MLX result is an ``mx.array``.
    mlx_mode
        Mode used to compile the MLX function.
    py_mode
        Mode used to compile the reference function.

    Returns
    -------
    tuple
        ``(pytensor_mlx_fn, mlx_res)``: the compiled MLX function and the
        result(s) it produced for ``test_inputs``.

    Raises
    ------
    ValueError
        If any graph input has an owner, i.e. is not a root variable.
    """
    if assert_fn is None:
        assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)

    # Both arguments are traversed more than once below; materialise them so
    # one-shot iterables such as generators are not exhausted early.
    graph_inputs = list(graph_inputs)
    test_inputs = list(test_inputs)

    if any(inp.owner is not None for inp in graph_inputs):
        raise ValueError("Inputs must be root variables")

    pytensor_mlx_fn = function(graph_inputs, graph_outputs, mode=mlx_mode)
    mlx_res = pytensor_mlx_fn(*test_inputs)

    if must_be_device_array:
        if isinstance(mlx_res, list):
            assert all(isinstance(res, mx.array) for res in mlx_res)
        else:
            assert isinstance(mlx_res, mx.array)

    pytensor_py_fn = function(graph_inputs, graph_outputs, mode=py_mode)
    py_res = pytensor_py_fn(*test_inputs)

    if isinstance(graph_outputs, list | tuple):
        for j, p in zip(mlx_res, py_res, strict=True):
            assert_fn(j, p)
    else:
        assert_fn(mlx_res, py_res)

    return pytensor_mlx_fn, mlx_res
def test_scalar_from_tensor_matrix_indexing():
    """Extracting a single matrix element yields the right value as an MLX array."""
    matrix = pt.matrix("matrix", dtype="float32")
    first_entry = matrix[0, 0]  # 0-d tensor

    mlx_fn = pytensor.function([matrix], first_entry, mode="MLX")
    values = np.array([[42.0, 1.0], [2.0, 3.0]], dtype=np.float32)
    out = mlx_fn(values)

    assert float(out) == 42.0
    assert isinstance(out, mx.array)
def test_scalar_from_tensor_reduction_operations():
    """Full reductions (sum of a vector, mean of a matrix) give scalar results."""
    # Sum over a vector.
    vector = pt.vector("vector", dtype="float32")
    sum_fn = pytensor.function([vector], pt.sum(vector), mode="MLX")
    vector_values = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
    assert float(sum_fn(vector_values)) == 10.0

    # Mean over a matrix.
    matrix = pt.matrix("matrix", dtype="float32")
    mean_fn = pytensor.function([matrix], pt.mean(matrix), mode="MLX")
    matrix_values = np.array([[2.0, 4.0], [6.0, 8.0]], dtype=np.float32)
    assert float(mean_fn(matrix_values)) == 5.0
def test_scalar_from_tensor_conditional_operations():
    """``switch`` on two scalars returns the larger one in both branches."""
    x = pt.scalar("x", dtype="float32")
    y = pt.scalar("y", dtype="float32")
    larger = pt.switch(x > y, x, y)

    mlx_fn = pytensor.function([x, y], larger, mode="MLX")

    # One case per branch of the switch.
    assert float(mlx_fn(5.0, 3.0)) == 5.0
    assert float(mlx_fn(2.0, 7.0)) == 7.0
def test_scalar_from_tensor_multiple_dtypes():
    """Max-reduction to a 0-d result works for float32, int32 and int64 inputs."""
    float_case = ([1.5, 3.7, 2.1], 3.7)
    int_case = ([10, 30, 20], 30)

    for dtype in ["float32", "int32", "int64"]:
        x = pt.vector("x", dtype=dtype)
        mlx_fn = pytensor.function(
            [x], pt.max(x), mode="MLX", allow_input_downcast=True
        )

        values, expected = float_case if dtype.startswith("float") else int_case
        test_data = np.array(values, dtype=dtype)

        result = mlx_fn(test_data)
        assert abs(float(result) - expected) < 1e-5
def test_scalar_from_tensor_pytensor_integration():
    """``scalar_from_tensor`` on a symbolic input runs through the MLX backend.

    A symbolic (non-constant) input is used so the operation cannot simply be
    folded away at compile time.
    """
    x = pt.scalar("x", dtype="int64")
    as_scalar = pt.scalar_from_tensor(x)

    mlx_fn = pytensor.function([x], as_scalar, mode="MLX")
    out = mlx_fn(42)

    assert out == 42
    assert isinstance(out, mx.array)
def test_mlx_float64_auto_casting():
    """Casting to float64 falls back to float32 and emits a dtype warning."""
    import warnings

    x = pt.scalar("x", dtype="float32")
    y = pt.cast(x, "float64")

    # Record every warning raised while compiling and running the function.
    with warnings.catch_warnings(record=True) as warning_list:
        warnings.simplefilter("always")
        f = pytensor.function([x], y, mode=mlx_mode, allow_input_downcast=True)
        result = f(3.14)

        # The computation succeeds, with the result auto-cast to float32.
        assert result.dtype == mx.float32
        assert abs(float(result) - 3.14) < 1e-6

        # At least one warning mentions both dtypes.
        warning_messages = [str(w.message) for w in warning_list]
        dtype_warnings = [
            msg for msg in warning_messages if "float64" in msg and "float32" in msg
        ]
        assert (
            len(dtype_warnings) > 0
        ), f"Expected dtype warning, got warnings: {warning_messages}"
def test_mlx_float64_complex_operations():
    """A float64 cast followed by several elementwise ops yields float32 plus a warning."""
    import warnings

    x = pt.vector("x", dtype="float32")
    y = pt.cast(x, "float64")
    z = pt.exp(y) + pt.sin(y)

    with warnings.catch_warnings(record=True) as warning_list:
        warnings.simplefilter("always")
        f = pytensor.function([x], z, mode=mlx_mode, allow_input_downcast=True)
        result = f([1.0, 2.0, 3.0])

        # Results come back as float32 with the input's shape.
        assert result.dtype == mx.float32
        assert result.shape == (3,)

        # A warning about the dtype fallback was recorded.
        warning_messages = [str(w.message) for w in warning_list]
        dtype_warnings = [
            msg
            for msg in warning_messages
            if "float64" in msg or "MLX GPU limitation" in msg
        ]
        assert len(dtype_warnings) > 0
def test_mlx_float64_no_warning_when_disabled():
    """With ``auto_cast_unsupported=False`` float64 is returned as-is, silently."""
    import warnings

    from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx

    with warnings.catch_warnings(record=True) as warning_list:
        warnings.simplefilter("always")

        dtype = convert_dtype_to_mlx("float64", auto_cast_unsupported=False)
        assert dtype == mx.float64  # original dtype is kept

        # No float64-related warning may have been recorded.
        recorded = (str(w.message) for w in warning_list)
        dtype_warnings = [msg for msg in recorded if "float64" in msg]
        assert len(dtype_warnings) == 0
def test_mlx_complex128_auto_casting():
    """With auto-casting enabled, complex128 maps to complex64 and warns."""
    import warnings

    from pytensor.link.mlx.dispatch.core import convert_dtype_to_mlx

    with warnings.catch_warnings(record=True) as warning_list:
        warnings.simplefilter("always")

        dtype = convert_dtype_to_mlx("complex128", auto_cast_unsupported=True)
        assert dtype == mx.complex64

        # A warning naming both dtypes was recorded.
        warning_messages = [str(w.message) for w in warning_list]
        complex_warnings = [
            msg
            for msg in warning_messages
            if "complex128" in msg and "complex64" in msg
        ]
        assert len(complex_warnings) > 0
def test_mlx_checkandraise_constant_false():
    """An ``Assert`` with a constant-false condition is skipped with a warning."""
    x = pt.scalar("x", dtype="float32")
    always_false = pt.as_tensor_variable(np.array(False))
    res = assert_op(x, always_false)

    with pytest.warns(UserWarning, match=r"Skipping `Assert` Op"):
        mlx_fn = function([x], res, mode=mlx_mode)

    out = mlx_fn(np.array(0.5, dtype=np.float32))

    assert isinstance(out, mx.array)
    assert np.allclose(out, 0.5)
def test_mlx_checkandraise_warning_and_execution():
    """An ``Assert`` with a symbolic condition is skipped with a warning and still runs."""
    p = pt.scalar("p", dtype="float32")
    condition = p < 1.0
    res = assert_op(p, condition)

    with pytest.warns(UserWarning, match=r"Skipping `Assert` Op"):
        mlx_fn = function([p], res, mode=mlx_mode)

    out = mlx_fn(np.array(0.5, dtype=np.float32))

    assert isinstance(out, mx.array)
    assert np.allclose(out, 0.5)
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot
from tests.link.mlx.test_basic import compare_mlx_and_py
# Blockwise wrapper around `Dot` that is equivalent to matmul, but declared
# with a deliberately unusual signature (every core dimension has its own name).
# NOTE(review): not referenced by the test visible below — confirm it is used elsewhere.
odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")
def test_blockwise_conv1d():
    """Batched ``convolve1d`` in "valid" mode matches the Python backend."""
    rng = np.random.default_rng(14)

    a = tensor("a", shape=(2, 100))
    b = tensor("b", shape=(2, 8))
    out = pt.signal.convolve1d(a, b, mode="valid")

    # Draw order (a first, then b) is kept so the random values are unchanged.
    a_test = rng.normal(size=(2, 100))
    b_test = rng.normal(size=(2, 8))

    # assert isinstance(out.owner.op, Blockwise)
    compare_mlx_and_py([a, b], [out], [a_test, b_test], must_be_device_array=True)
import numpy as np
import pytest
import pytensor
from pytensor import tensor as pt
from pytensor.tensor.basic import Alloc
from tests.link.mlx.test_basic import compile_mode, mlx_mode_no_compile, mx
def test_alloc_with_different_shape_types():
    """Alloc accepts Python ints, MLX arrays, or a mix of both as shape arguments.

    Guards against the ``TypeError`` that occurred when shape parameters were
    MLX arrays instead of Python integers.
    """
    from pytensor.link.mlx.dispatch.core import (
        mlx_funcify_Alloc,
    )

    # Minimal stand-in for an Apply node; a real node is not needed here.
    class MockNode:
        def __init__(self):
            self.op = Alloc()
            self.inputs = None
            self.outputs = None

    alloc_func = mlx_funcify_Alloc(Alloc(), MockNode())
    x = mx.array(5.0)

    shape_variants = (
        (3, 4),  # Python ints
        (mx.array(3), mx.array(4)),  # MLX arrays (this used to fail)
        (3, mx.array(4)),  # mixed types
    )
    for shape_args in shape_variants:
        result = alloc_func(x, *shape_args)
        assert result.shape == (3, 4)
        assert float(result[0, 0]) == 5.0
def test_alloc_pytensor_integration():
    """``pt.alloc`` with constant shape compiles and runs in MLX mode."""
    x = pt.scalar("x", dtype="float32")
    filled = pt.alloc(x, 3, 4)

    mlx_fn = pytensor.function([x], filled, mode="MLX")
    output = mlx_fn(5.0)

    assert output.shape == (3, 4)
    assert float(output[0, 0]) == 5.0
def test_alloc_compilation_limitation():
    """Dynamic-shape Alloc works uncompiled and raises a helpful error when compiled."""
    x = pt.scalar("x", dtype="float32")
    s1 = pt.scalar("s1", dtype="int64")
    s2 = pt.scalar("s2", dtype="int64")
    result = pt.alloc(x, s1, s2)  # shape only known at run time
    graph_inputs = [x, s1, s2]

    # Without mx.compile, concrete shape values are fine.
    f = pytensor.function(graph_inputs, result, mode=mlx_mode_no_compile)
    output = f(5.0, 3, 4)
    assert output.shape == (3, 4)
    np.testing.assert_allclose(output, 5.0)

    # With mx.compile, the call must fail with the dedicated message.
    compiled_f = pytensor.function(graph_inputs, result, mode=compile_mode)
    with pytest.raises(
        ValueError,
        match="MLX compilation limitation: Alloc operations with dynamic shapes cannot be "
        "used inside compiled functions",
    ):
        compiled_f(5.0, 3, 4)
def test_alloc_static_shapes_compilation():
    """Static-shape Alloc gives the same result with and without mx.compile."""
    x = pt.scalar("x", dtype="float32")
    result = pt.alloc(x, 3, 4)  # constant shape

    f_normal = pytensor.function([x], result, mode=mlx_mode_no_compile)
    f_compiled = pytensor.function([x], result, mode=compile_mode)

    output_normal = f_normal(5.0)
    output_compiled = f_compiled(5.0)

    assert output_normal.shape == (3, 4)
    assert output_compiled.shape == (3, 4)
    np.testing.assert_allclose(output_normal, 5.0)
    np.testing.assert_allclose(output_compiled, 5.0)
    np.testing.assert_allclose(output_normal, output_compiled)
def test_empty_static_shape():
    """``pt.empty`` with a constant shape yields a (3, 4) array of zeros in MLX mode."""
    empty_graph = pt.empty((3, 4), dtype="float32")
    mlx_fn = pytensor.function([], empty_graph, mode="MLX")

    output = mlx_fn()

    assert output.shape == (3, 4)
    np.testing.assert_allclose(output, 0.0)
def test_empty_dynamic_shape():
    """Dynamic-shape ``pt.empty`` works uncompiled and raises when compiled."""
    s1 = pt.scalar("s1", dtype="int64")
    s2 = pt.scalar("s2", dtype="int64")
    result = pt.empty((s1, s2), dtype="float32")
    shape_inputs = [s1, s2]

    # Without mx.compile the runtime shape is accepted.
    f = pytensor.function(shape_inputs, result, mode=mlx_mode_no_compile)
    output = f(3, 4)
    assert output.shape == (3, 4)
    np.testing.assert_allclose(output, 0.0)

    # With mx.compile the dedicated error is expected.
    f_compiled = pytensor.function(shape_inputs, result, mode=compile_mode)
    with pytest.raises(
        ValueError,
        match="MLX compilation limitation: Alloc operations with dynamic shapes cannot be "
        "used inside compiled functions",
    ):
        f_compiled(3, 4)
import numpy as np
import pytest
import scipy
from pytensor import config, function
from pytensor.tensor.basic import switch
from pytensor.tensor.math import (
add,
cos,
eq,
exp,
ge,
gt,
int_div,
isinf,
le,
log,
lt,
mul,
neq,
power,
prod,
sigmoid,
sin,
sub,
true_div,
)
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import any as pt_any
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.math import min as pt_min
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.special import SoftmaxGrad, softmax
from pytensor.tensor.type import matrix, vector, vectors
from tests.link.mlx.test_basic import compare_mlx_and_py
# Skip this whole test module when MLX cannot be imported.
mx = pytest.importorskip("mlx.core")
@pytest.mark.parametrize("op", [pt_any, pt_all, pt_max, pt_min])
def test_input(op) -> None:
    """Reductions over a boolean mask match between the MLX and Python backends."""
    x = vector("x")
    positive_mask = x > 0
    out = op(positive_mask)

    compare_mlx_and_py([x], out, [mx.array([1.0, 2.0, 3.0])])
def test_mlx_CAReduce():
    """Sum, product and ``all`` reductions match the Python backend."""
    vec_val = np.r_[1, 2, 3].astype(config.floatX)
    mat_val = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)

    # Sum over all axes of a vector.
    a_pt = vector("a")
    a_pt.tag.test_value = vec_val
    compare_mlx_and_py([a_pt], [pt_sum(a_pt, axis=None)], [vec_val])

    # Sum over each axis of a matrix.
    a_pt = matrix("a")
    a_pt.tag.test_value = mat_val
    for axis in (0, 1):
        compare_mlx_and_py([a_pt], [pt_sum(a_pt, axis=axis)], [mat_val])

    # Product over the first axis, then `all` over the whole matrix.
    a_pt = matrix("a")
    a_pt.tag.test_value = mat_val
    compare_mlx_and_py([a_pt], [prod(a_pt, axis=0)], [mat_val])
    compare_mlx_and_py([a_pt], [pt_all(a_pt)], [mat_val])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis):
    """Check `softmax` over the given axis on a 2x3 matrix of 0..5."""
    x = matrix("x")
    values = np.arange(6, dtype=config.floatX).reshape(2, 3)
    compare_mlx_and_py([x], [softmax(x, axis=axis)], [values])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax_grad(axis):
    """Check the `SoftmaxGrad` op over the given axis with fixed `dy` and `sm` inputs."""
    dy = matrix("dy")
    sm = matrix("sm")
    grad = SoftmaxGrad(axis=axis)(dy, sm)

    dy_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
    sm_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
    compare_mlx_and_py([dy, sm], [grad], [dy_value, sm_value])
@pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
@pytest.mark.parametrize("axis", [0, 1])
def test_logsumexp_benchmark(size, axis, benchmark):
    """Benchmark a max-shifted log-sum-exp graph in MLX mode and check it against SciPy."""
    X = matrix("X")
    # Shift by the maximum along `axis`; an infinite maximum is swapped for 0
    # before it is subtracted.
    shift = pt_max(X, axis=axis, keepdims=True)
    shift = switch(isinf(shift), 0, shift)
    summed = pt_sum(exp(X - shift), axis=axis, keepdims=True)
    X_lse = log(summed) + shift

    X_val = np.random.default_rng(23920).normal(size=size)
    X_lse_fn = function([X], X_lse, mode="MLX")

    # Call once before timing (the original note says: JIT compile first).
    X_lse_fn(X_val)

    measured = benchmark(X_lse_fn, X_val)
    reference = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
    np.testing.assert_array_almost_equal(measured, reference)
def test_multiple_input_multiply():
    """Multiply three vectors with a single `mul` call and check the result."""
    first, second, third = vectors("xyz")
    product = mul(first, second, third)
    compare_mlx_and_py(
        [first, second, third],
        [product],
        test_inputs=[[1.5], [2.5], [3.5]],
    )
@pytest.mark.parametrize(
    "op",
    [
        pytest.param(exp, id="exp"),
        pytest.param(log, id="log"),
        pytest.param(sin, id="sin"),
        pytest.param(cos, id="cos"),
        pytest.param(sigmoid, id="sigmoid"),
    ],
)
def test_elemwise_one_input(op) -> None:
    """Apply a unary elementwise `op` to a vector and check it via `compare_mlx_and_py`."""
    vec = vector("x")
    result = op(vec)
    compare_mlx_and_py([vec], result, [mx.array([1.0, 2.0, 3.0])])
@pytest.mark.parametrize(
    "op",
    [
        pytest.param(add, id="add"),
        pytest.param(sub, id="sub"),
        pytest.param(mul, id="mul"),
        pytest.param(power, id="power"),
        pytest.param(le, id="le"),
        pytest.param(lt, id="lt"),
        pytest.param(ge, id="ge"),
        pytest.param(gt, id="gt"),
        pytest.param(eq, id="eq"),
        pytest.param(neq, id="neq"),
        pytest.param(true_div, id="true_div"),
        pytest.param(int_div, id="int_div"),
    ],
)
def test_elemwise_two_inputs(op) -> None:
    """Apply a binary elementwise `op` to two vectors and check it via `compare_mlx_and_py`."""
    lhs = vector("x")
    rhs = vector("y")
    result = op(lhs, rhs)
    lhs_value = mx.array([1.0, 2.0, 3.0])
    rhs_value = mx.array([4.0, 5.0, 6.0])
    compare_mlx_and_py([lhs, rhs], result, [lhs_value, rhs_value])
import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.math import Argmax, Max
from tests.link.mlx.test_basic import compare_mlx_and_py
mx = pytest.importorskip("mlx.core")
def test_dot():
    """Compile a matrix product in MLX mode and compare it with `np.dot`."""
    x = pt.matrix("x")
    y = pt.matrix("y")
    fn = pytensor.function([x, y], x.dot(y), mode="MLX")

    # Deterministic seed: the sum of the code points of a fixed string.
    rng = np.random.default_rng(sum(ord(ch) for ch in "test_mlx_dot"))
    test_x = rng.normal(size=(3, 2))
    test_y = rng.normal(size=(2, 4))

    actual = fn(test_x, test_y)
    # The compiled function is expected to hand back an MLX array.
    assert isinstance(actual, mx.array)
    np.testing.assert_allclose(actual, np.dot(test_x, test_y), rtol=1e-6)
def test_switch() -> None:
    """Check `switch(x > 0, y, x)` on inputs where the condition is mixed."""
    x = pt.vector("x")
    y = pt.vector("y")
    selected = pt.switch(x > 0, y, x)
    compare_mlx_and_py(
        [x, y],
        selected,
        [mx.array([-1.0, 2.0, 3.0]), mx.array([4.0, 5.0, 6.0])],
    )
def test_int_div_specific() -> None:
    """Check `int_div` on integer-valued floats, including negative numerators."""
    x = pt.vector("x")
    y = pt.vector("y")
    quotient = pt.int_div(x, y)

    # Numerators on both sides of zero, all divided by 3.
    numerators = mx.array([7.0, 8.0, 9.0, -7.0, -8.0])
    denominators = mx.array([3.0, 3.0, 3.0, 3.0, 3.0])
    compare_mlx_and_py([x, y], quotient, [numerators, denominators])
def test_isnan() -> None:
    """Check `isnan` on a vector mixing finite values, NaNs and infinities."""
    x = pt.vector("x")
    nan_mask = pt.isnan(x)
    values = mx.array([1.0, np.nan, 3.0, np.inf, -np.nan, 0.0, -np.inf])
    compare_mlx_and_py([x], nan_mask, [values])
def test_isnan_edge_cases() -> None:
    """Check `isnan` on a scalar input, one special value at a time."""
    x = pt.scalar("x")
    nan_flag = pt.isnan(x)
    # The same graph is re-checked for each value.
    for value in (0.0, np.nan, np.inf, -np.inf, 1e-10, 1e10):
        compare_mlx_and_py([x], nan_flag, [value])
def test_erfc() -> None:
    """Check `erfc` on zero and on small positive and negative values."""
    x = pt.vector("x")
    result = pt.erfc(x)
    values = mx.array([0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0, 0.1])
    compare_mlx_and_py([x], result, [values])
def test_erfc_extreme_values() -> None:
    """Check `erfc` at larger magnitudes using a loosened comparison."""
    from functools import partial

    x = pt.vector("x")
    result = pt.erfc(x)
    # Inputs far enough from zero that erfc is close to 0 or 2.
    values = mx.array([-3.0, -2.5, 2.5, 3.0])
    # Looser tolerances than the defaults are passed as the comparison function.
    loose_allclose = partial(np.testing.assert_allclose, rtol=1e-3, atol=1e-6)
    compare_mlx_and_py([x], result, [values], assert_fn=loose_allclose)
def test_erfcx() -> None:
    """Check `erfcx` on non-negative inputs from 0 to 2.5."""
    x = pt.vector("x")
    result = pt.erfcx(x)
    values = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5])
    compare_mlx_and_py([x], result, [values])
def test_erfcx_small_values() -> None:
    """Check `erfcx` on small positive inputs."""
    x = pt.vector("x")
    result = pt.erfcx(x)
    values = mx.array([0.001, 0.01, 0.1, 0.2])
    compare_mlx_and_py([x], result, [values])
def test_softplus() -> None:
    """Check `softplus`, i.e. log(1 + exp(x)), on moderate inputs."""
    x = pt.vector("x")
    result = pt.softplus(x)
    values = mx.array([0.0, 1.0, 2.0, -1.0, -2.0, 10.0])
    compare_mlx_and_py([x], result, [values])
def test_softplus_extreme_values() -> None:
    """Check `softplus` on large-magnitude inputs using a loosened comparison."""
    from functools import partial

    x = pt.vector("x")
    result = pt.softplus(x)
    # Strongly negative and strongly positive inputs; the original note says
    # these hit different branches of the implementation.
    values = mx.array([-40.0, -50.0, 20.0, 30.0, 35.0, 50.0])
    # Looser tolerances than the defaults are passed as the comparison function.
    loose_allclose = partial(np.testing.assert_allclose, rtol=1e-4, atol=1e-8)
    compare_mlx_and_py([x], result, [values], assert_fn=loose_allclose)
def test_mlx_max_and_argmax():
    """Check that `Max` and `Argmax` of the same vector can feed a further `Op`.

    Both reductions are taken over axis 0 of one float64 vector and then
    multiplied together, so the compiled graph has to evaluate both.
    """
    x = pt.dvector()
    # Deliberately not named `mx`: that name is the module-level `mlx.core`
    # handle, and rebinding it locally would hide it inside this test.
    max_x = Max([0])(x)
    argmax_x = Argmax([0])(x)
    out = max_x * argmax_x
    compare_mlx_and_py([x], [out], [np.r_[1, 2]])
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config
from pytensor.tensor.shape import Shape, Shape_i, reshape
from pytensor.tensor.type import iscalar, vector
from tests.link.mlx.test_basic import compare_mlx_and_py
def test_mlx_shape_ops():
    """Check `Shape` and `Shape_i(1)` applied to a constant (20, 3) array."""
    x_np = np.zeros((20, 3))
    for shape_op in (Shape(), Shape_i(1)):
        # Each op gets its own constant; the graph has no symbolic inputs.
        out = shape_op(pt.as_tensor_variable(x_np))
        compare_mlx_and_py([], [out], [], must_be_device_array=False)
def test_mlx_specify_shape():
    """Check `specify_shape` with a partly-known static shape and with a symbolic shape."""

    def ones_4x5():
        # Fresh (4, 5) array of ones in the configured float dtype.
        return np.ones((4, 5)).astype(config.floatX)

    # Only the first dimension is fixed.
    in_pt = pt.matrix("in")
    constrained = pt.specify_shape(in_pt, (4, None))
    compare_mlx_and_py([in_pt], [constrained], [ones_4x5()])

    # The target shape is taken from a second matrix input.
    in_pt = pt.matrix("in")
    shape_pt = pt.matrix("shape")
    constrained = pt.specify_shape(in_pt, shape_pt.shape)
    compare_mlx_and_py(
        [in_pt, shape_pt],
        [constrained],
        [ones_4x5(), ones_4x5()],
    )
def test_mlx_Reshape_constant():
    """Reshape a length-4 vector to the literal shape (2, 2)."""
    a = vector("a")
    reshaped = reshape(a, (2, 2))
    value = np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)
    compare_mlx_and_py([a], [reshaped], [value])
def test_mlx_Reshape_various_shapes():
    """Reshape between 1-D, 2-D and 3-D layouts using literal target shapes."""
    # (input constructor, input name, target shape, element count, input shape)
    cases = [
        (vector, "a", (2, 3), 6, (6,)),  # 1D to 2D
        (pt.matrix, "b", (6,), 6, (2, 3)),  # 2D to 1D
        (pt.matrix, "c", (2, 2, 3), 12, (4, 3)),  # 2D to 3D
        (pt.tensor3, "d", (3, 4), 12, (2, 2, 3)),  # 3D to 2D
    ]
    for make_input, name, target_shape, count, input_shape in cases:
        inp = make_input(name)
        out = reshape(inp, target_shape)
        value = np.arange(count, dtype=config.floatX).reshape(input_shape)
        compare_mlx_and_py([inp], [out], [value])
def test_mlx_Reshape_negative_one():
    """Reshape a length-8 vector with one target dimension given as -1."""
    a = vector("a")
    # First -1 stands in for the second dimension, then for the first.
    for target_shape in ((2, -1), (-1, 4)):
        out = reshape(a, target_shape)
        compare_mlx_and_py([a], [out], [np.arange(8, dtype=config.floatX)])
def test_mlx_Reshape_concrete_shape():
    """Reshape a vector using target shapes computed from the vector's own shape."""
    a = vector("a")
    target_shapes = (
        a.shape,
        (a.shape[0] // 2, a.shape[0] // 2),
    )
    for target_shape in target_shapes:
        out = reshape(a, target_shape)
        value = np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)
        compare_mlx_and_py([a], [out], [value])
@pytest.mark.xfail(reason="`shape_pt` should be specified as a static argument")
def test_mlx_Reshape_shape_graph_input():
    """Reshape a vector to a square whose side is a separate integer graph input."""
    a = vector("a")
    shape_pt = iscalar("b")
    out = reshape(a, (shape_pt, shape_pt))
    values = [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]
    compare_mlx_and_py([a, shape_pt], [out], values)
@pytest.mark.xfail(reason="ViewOp Op is not supported yet")
def test_mlx_compile_ops():
    """Apply `DeepCopyOp` to a scalar constant, then `ViewOp` to a (20, 1, 1) constant."""
    copied = DeepCopyOp()(pt.as_tensor_variable(1.1))
    compare_mlx_and_py([], [copied], [])

    viewed = ViewOp()(pt.as_tensor_variable(np.zeros((20, 1, 1))))
    compare_mlx_and_py([], [viewed], [])
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.tensor import subtensor as pt_subtensor
from pytensor.tensor import tensor
from tests.link.mlx.test_basic import compare_mlx_and_py
mx = pytest.importorskip("mlx.core")
def test_mlx_Subtensor_basic():
    """Check basic `Subtensor` indexing with constant integer and slice indices."""
    shape = (3, 4, 5)
    x_pt = tensor("x", shape=shape, dtype="float32")
    x_np = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    # `np.s_[...]` yields the same index tuple as writing it inside `x_pt[...]`.
    index_cases = [
        np.s_[1, 2, 0],  # single element
        np.s_[1:, 1, :],  # open-ended start slice
        np.s_[:2, 1, :],  # open-ended stop slice
        np.s_[1:2, 1, :],  # bounded slice
        np.s_[-1, -1, -1],  # negative indices
        np.s_[::2, ::2, ::2],  # stepped slices
        np.s_[::-1, ::-1, ::-1],  # reversed slices
    ]
    for index in index_cases:
        out_pt = x_pt[index]
        # Each case must build a basic Subtensor op.
        assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
        compare_mlx_and_py([x_pt], [out_pt], [x_np])
def test_mlx_AdvancedSubtensor():
    """Check `AdvancedSubtensor1` and `AdvancedSubtensor` with constant list indices."""
    shape = (3, 4, 5)
    x_pt = tensor("x", shape=shape, dtype="float32")
    x_np = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    # Integer-list indexing along the first axis through the dedicated helper.
    out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
    compare_mlx_and_py([x_pt], [out_pt], [x_np])

    advanced_indices = [
        ([1, 2], [2, 3]),  # lists on two axes
        ([1, 2], slice(None)),  # list mixed with a full slice
        ([1, 2], slice(None), [3, 4]),  # lists separated by a full slice
    ]
    for index in advanced_indices:
        out_pt = x_pt[index]
        assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
        compare_mlx_and_py([x_pt], [out_pt], [x_np])
@pytest.mark.xfail(
    raises=ValueError, reason="MLX does not support boolean indexing yet"
)
def test_mlx_AdvancedSubtensor_boolean():
    """Index a 3-D tensor along its first axis with a constant boolean mask."""
    shape = (3, 4, 5)
    x_pt = tensor("x", shape=shape, dtype="float32")
    x_np = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

    mask = np.array([True, False, True])
    out_pt = x_pt[mask]
    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
    compare_mlx_and_py([x_pt], [out_pt], [x_np])
def test_mlx_IncSubtensor_set():
    """Overwrite one element of a constant 3-D array through `set_subtensor`."""
    base = pt.constant(np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)))
    new_value = pt.as_tensor_variable(np.array(-10.0, dtype=np.float32))
    out_pt = pt_subtensor.set_subtensor(base[1, 2, 3], new_value)

    op = out_pt.owner.op
    assert isinstance(op, pt_subtensor.IncSubtensor)
    assert op.set_instead_of_inc
    # Everything is constant, so the graph has no inputs.
    compare_mlx_and_py([], [out_pt], [])
def test_mlx_IncSubtensor_increment():
    """Increment one element of a constant 3-D array through `inc_subtensor`."""
    base = pt.constant(np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)))
    delta = pt.as_tensor_variable(np.array(-10.0, dtype=np.float32))
    out_pt = pt_subtensor.inc_subtensor(base[1, 2, 3], delta)

    op = out_pt.owner.op
    assert isinstance(op, pt_subtensor.IncSubtensor)
    assert not op.set_instead_of_inc
    # Everything is constant, so the graph has no inputs.
    compare_mlx_and_py([], [out_pt], [])
def test_mlx_AdvancedIncSubtensor_set():
    """Overwrite rows 0 and 2 of a constant 3-D array using an integer-array index."""
    rng = np.random.default_rng(213234)
    base = pt.constant(np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)))

    # Replacement block for the two selected rows.
    update = rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32)
    st_pt = pt.as_tensor_variable(update)
    out_pt = pt_subtensor.set_subtensor(base[np.r_[0, 2]], st_pt)

    op = out_pt.owner.op
    assert isinstance(op, pt_subtensor.AdvancedIncSubtensor)
    assert op.set_instead_of_inc
    compare_mlx_and_py([], [out_pt], [])
def test_mlx_AdvancedIncSubtensor_increment():
    """Increment rows 0 and 2 of a constant 3-D array using an integer-array index."""
    rng = np.random.default_rng(213234)
    base = pt.constant(np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)))

    # Amounts added to the two selected rows.
    delta = rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32)
    st_pt = pt.as_tensor_variable(delta)
    out_pt = pt_subtensor.inc_subtensor(base[np.r_[0, 2]], st_pt)

    op = out_pt.owner.op
    assert isinstance(op, pt_subtensor.AdvancedIncSubtensor)
    assert not op.set_instead_of_inc
    compare_mlx_and_py([], [out_pt], [])
def test_mlx_AdvancedIncSubtensor1_operations():
    """Check the set variant of `AdvancedIncSubtensor1` built by `advanced_set_subtensor1`."""
    rng = np.random.default_rng(213234)
    base = pt.constant(np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)))

    update = rng.uniform(-1, 1, size=(2, 4, 5)).astype(np.float32)
    st_pt = pt.as_tensor_variable(update)

    # Build the op through the helper rather than through indexing syntax.
    out_pt = pt_subtensor.advanced_set_subtensor1(base, st_pt, [1, 2])

    op = out_pt.owner.op
    assert isinstance(op, pt_subtensor.AdvancedIncSubtensor1)
    assert op.set_instead_of_inc
    compare_mlx_and_py([], [out_pt], [])
@pytest.mark.xfail(reason="Inplace operations not yet supported in MLX mode")
def test_mlx_inplace_variants():
    """Check an `IncSubtensor` set operation built with `inplace=True`."""
    base = pt.constant(np.arange(12, dtype=np.float32).reshape((3, 4)))
    replacement = pt.as_tensor_variable(np.array([-1.0, -2.0], dtype=np.float32))
    out_pt = pt_subtensor.set_subtensor(base[0, :2], replacement, inplace=True)

    op = out_pt.owner.op
    assert isinstance(op, pt_subtensor.IncSubtensor)
    assert op.inplace
    assert op.set_instead_of_inc
    compare_mlx_and_py([], [out_pt], [])
@pytest.mark.xfail(
    reason="MLX slice indices must be integers or None, dynamic slices not supported"
)
def test_mlx_MakeSlice():
    """Index a constant vector with a slice built by `MakeSlice` from three integer inputs."""
    start = pt.iscalar("start")
    stop = pt.iscalar("stop")
    step = pt.iscalar("step")

    # The slice bounds are graph inputs, not Python integers.
    symbolic_slice = pt_subtensor.MakeSlice()(start, stop, step)

    x_pt = pt.constant(np.arange(10, dtype=np.float32))
    out_pt = x_pt[symbolic_slice]
    compare_mlx_and_py([start, stop, step], [out_pt], [1, 8, 2])
def test_mlx_subtensor_edge_cases():
    """Check boundary-style slices: empty, single element, large step, negative step."""
    # Empty slice of a constant vector.
    ten = pt.constant(np.arange(10, dtype=np.float32))
    compare_mlx_and_py([], [ten[5:5]], [])

    # Scalar index into a length-1 input.
    single = pt.tensor(shape=(1,), dtype="float32", name="x")
    single_value = np.array([42.0], dtype=np.float32)
    compare_mlx_and_py([single], [single[0]], [single_value])

    # Step of 5 over a constant vector of 20.
    twenty = pt.constant(np.arange(20, dtype=np.float32))
    compare_mlx_and_py([], [twenty[::5]], [])

    # Negative step over a constant vector of 10.
    ten = pt.constant(np.arange(10, dtype=np.float32))
    compare_mlx_and_py([], [ten[::-2]], [])
@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported")
def test_mlx_subtensor_with_variables():
    """Check `set_subtensor` when both the base matrix and the new values are graph inputs."""
    x_pt = pt.matrix("x", dtype="float32")
    y_pt = pt.vector("y", dtype="float32")

    # Write `y` into the first two entries of row 0 of `x`.
    out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt)

    x_value = np.arange(12, dtype=np.float32).reshape((3, 4))
    y_value = np.array([-1.0, -2.0], dtype=np.float32)
    compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_value, y_value])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论