Unverified 提交 95c2a744 authored 作者: Faraz Shamim's avatar Faraz Shamim 提交者: GitHub

Move MLX tests from math to elemwise (#1748)

上级 0758de4b
......@@ -8,11 +8,14 @@ from pytensor.tensor.math import (
add,
cos,
eq,
erfc,
erfcx,
exp,
ge,
gt,
int_div,
isinf,
isnan,
le,
log,
lt,
......@@ -22,6 +25,7 @@ from pytensor.tensor.math import (
prod,
sigmoid,
sin,
softplus,
sub,
true_div,
)
......@@ -189,3 +193,126 @@ def test_elemwise_two_inputs(op) -> None:
x_test = mx.array([1.0, 2.0, 3.0])
y_test = mx.array([4.0, 5.0, 6.0])
compare_mlx_and_py([x, y], out, [x_test, y_test])
def test_switch() -> None:
x = vector("x")
y = vector("y")
out = switch(x > 0, y, x)
x_test = mx.array([-1.0, 2.0, 3.0])
y_test = mx.array([4.0, 5.0, 6.0])
compare_mlx_and_py([x, y], out, [x_test, y_test])
def test_int_div_specific() -> None:
x = vector("x")
y = vector("y")
out = int_div(x, y)
# Test with integers that demonstrate floor division behavior
x_test = mx.array([7.0, 8.0, 9.0, -7.0, -8.0])
y_test = mx.array([3.0, 3.0, 3.0, 3.0, 3.0])
compare_mlx_and_py([x, y], out, [x_test, y_test])
def test_isnan() -> None:
x = vector("x")
out = isnan(x)
x_test = mx.array([1.0, np.nan, 3.0, np.inf, -np.nan, 0.0, -np.inf])
compare_mlx_and_py([x], out, [x_test])
def test_isnan_edge_cases() -> None:
from pytensor.tensor.type import scalar
x = scalar("x")
out = isnan(x)
# Test individual cases
test_cases = [0.0, np.nan, np.inf, -np.inf, 1e-10, 1e10]
for test_val in test_cases:
x_test = test_val
compare_mlx_and_py([x], out, [x_test])
def test_erfc() -> None:
"""Test complementary error function"""
x = vector("x")
out = erfc(x)
# Test with various values including negative, positive, and zero
x_test = mx.array([0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0, 0.1])
compare_mlx_and_py([x], out, [x_test])
def test_erfc_extreme_values() -> None:
"""Test erfc with extreme values"""
from functools import partial
x = vector("x")
out = erfc(x)
# Test with larger values where erfc approaches 0 or 2
x_test = mx.array([-3.0, -2.5, 2.5, 3.0])
# Use relaxed tolerance for extreme values due to numerical precision differences
relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-3, atol=1e-6)
compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert)
def test_erfcx() -> None:
"""Test scaled complementary error function"""
x = vector("x")
out = erfcx(x)
# Test with positive values where erfcx is most numerically stable
x_test = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5])
compare_mlx_and_py([x], out, [x_test])
def test_erfcx_small_values() -> None:
"""Test erfcx with small values"""
x = vector("x")
out = erfcx(x)
# Test with small values
x_test = mx.array([0.001, 0.01, 0.1, 0.2])
compare_mlx_and_py([x], out, [x_test])
def test_softplus() -> None:
"""Test softplus (log(1 + exp(x))) function"""
x = vector("x")
out = softplus(x)
# Test with normal range values
x_test = mx.array([0.0, 1.0, 2.0, -1.0, -2.0, 10.0])
compare_mlx_and_py([x], out, [x_test])
def test_softplus_extreme_values() -> None:
"""Test softplus with extreme values to verify numerical stability"""
from functools import partial
x = vector("x")
out = softplus(x)
# Test with extreme values where different branches of the implementation are used
x_test = mx.array([-40.0, -50.0, 20.0, 30.0, 35.0, 50.0])
# Use relaxed tolerance for extreme values due to numerical precision differences
relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-4, atol=1e-8)
compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert)
......@@ -29,127 +29,6 @@ def test_dot():
np.testing.assert_allclose(actual, expected, rtol=1e-6)
def test_switch() -> None:
x = pt.vector("x")
y = pt.vector("y")
out = pt.switch(x > 0, y, x)
x_test = mx.array([-1.0, 2.0, 3.0])
y_test = mx.array([4.0, 5.0, 6.0])
compare_mlx_and_py([x, y], out, [x_test, y_test])
def test_int_div_specific() -> None:
x = pt.vector("x")
y = pt.vector("y")
out = pt.int_div(x, y)
# Test with integers that demonstrate floor division behavior
x_test = mx.array([7.0, 8.0, 9.0, -7.0, -8.0])
y_test = mx.array([3.0, 3.0, 3.0, 3.0, 3.0])
compare_mlx_and_py([x, y], out, [x_test, y_test])
def test_isnan() -> None:
x = pt.vector("x")
out = pt.isnan(x)
x_test = mx.array([1.0, np.nan, 3.0, np.inf, -np.nan, 0.0, -np.inf])
compare_mlx_and_py([x], out, [x_test])
def test_isnan_edge_cases() -> None:
x = pt.scalar("x")
out = pt.isnan(x)
# Test individual cases
test_cases = [0.0, np.nan, np.inf, -np.inf, 1e-10, 1e10]
for test_val in test_cases:
x_test = test_val
compare_mlx_and_py([x], out, [x_test])
def test_erfc() -> None:
"""Test complementary error function"""
x = pt.vector("x")
out = pt.erfc(x)
# Test with various values including negative, positive, and zero
x_test = mx.array([0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0, 0.1])
compare_mlx_and_py([x], out, [x_test])
def test_erfc_extreme_values() -> None:
"""Test erfc with extreme values"""
x = pt.vector("x")
out = pt.erfc(x)
# Test with larger values where erfc approaches 0 or 2
x_test = mx.array([-3.0, -2.5, 2.5, 3.0])
# Use relaxed tolerance for extreme values due to numerical precision differences
from functools import partial
relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-3, atol=1e-6)
compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert)
def test_erfcx() -> None:
"""Test scaled complementary error function"""
x = pt.vector("x")
out = pt.erfcx(x)
# Test with positive values where erfcx is most numerically stable
x_test = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5])
compare_mlx_and_py([x], out, [x_test])
def test_erfcx_small_values() -> None:
"""Test erfcx with small values"""
x = pt.vector("x")
out = pt.erfcx(x)
# Test with small values
x_test = mx.array([0.001, 0.01, 0.1, 0.2])
compare_mlx_and_py([x], out, [x_test])
def test_softplus() -> None:
"""Test softplus (log(1 + exp(x))) function"""
x = pt.vector("x")
out = pt.softplus(x)
# Test with normal range values
x_test = mx.array([0.0, 1.0, 2.0, -1.0, -2.0, 10.0])
compare_mlx_and_py([x], out, [x_test])
def test_softplus_extreme_values() -> None:
"""Test softplus with extreme values to verify numerical stability"""
x = pt.vector("x")
out = pt.softplus(x)
# Test with extreme values where different branches of the implementation are used
x_test = mx.array([-40.0, -50.0, 20.0, 30.0, 35.0, 50.0])
# Use relaxed tolerance for extreme values due to numerical precision differences
from functools import partial
relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-4, atol=1e-8)
compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert)
def test_mlx_max_and_argmax():
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论