提交 f86a0dc1 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Don't run unrelated tests in alternative backends

上级 b5a17dd0
......@@ -15,11 +15,11 @@ from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, vector, vectors
from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_elemwise import TestElemwise
from tests.tensor.test_elemwise import check_elemwise_runtime_broadcast
def test_elemwise_runtime_broadcast():
    """Check the JAX backend raises the shared runtime-broadcast error.

    Delegates to the module-level helper from ``tests.tensor.test_elemwise``.
    """
    # The markerless diff scrape left both the pre-refactor call
    # (TestElemwise.check_runtime_broadcast) and its replacement here;
    # only the new free-function form belongs in the committed file.
    check_elemwise_runtime_broadcast(get_mode("JAX"))
def test_jax_Dimshuffle():
......
......@@ -14,7 +14,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.type import iscalar, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_basic import TestAlloc
from tests.tensor.test_basic import check_alloc_runtime_broadcast
def test_jax_Alloc():
......@@ -54,7 +54,7 @@ def test_jax_Alloc():
def test_alloc_runtime_broadcast():
    """Check the JAX backend raises the shared Alloc runtime-broadcast error.

    Delegates to the module-level helper from ``tests.tensor.test_basic``.
    """
    # The markerless diff scrape left both the pre-refactor call
    # (TestAlloc.check_runtime_broadcast) and its replacement here;
    # only the new free-function form belongs in the committed file.
    check_alloc_runtime_broadcast(get_mode("JAX"))
def test_jax_MakeVector():
......
......@@ -24,7 +24,10 @@ from tests.link.numba.test_basic import (
scalar_my_multi_out,
set_test_value,
)
from tests.tensor.test_elemwise import TestElemwise, careduce_benchmark_tester
from tests.tensor.test_elemwise import (
careduce_benchmark_tester,
check_elemwise_runtime_broadcast,
)
rng = np.random.default_rng(42849)
......@@ -124,7 +127,7 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
@pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")
def test_elemwise_runtime_broadcast():
    """Check the Numba backend raises the shared runtime-broadcast error.

    Delegates to the module-level helper from ``tests.tensor.test_elemwise``.
    """
    # Stale pre-refactor call (TestElemwise.check_runtime_broadcast) removed;
    # the commit replaces it with the module-level helper.
    check_elemwise_runtime_broadcast(get_mode("NUMBA"))
def test_elemwise_speed(benchmark):
......
......@@ -16,7 +16,7 @@ from tests.link.numba.test_basic import (
compare_shape_dtype,
set_test_value,
)
from tests.tensor.test_basic import TestAlloc
from tests.tensor.test_basic import check_alloc_runtime_broadcast
pytest.importorskip("numba")
......@@ -52,7 +52,7 @@ def test_Alloc(v, shape):
def test_alloc_runtime_broadcast():
    """Check the Numba backend raises the shared Alloc runtime-broadcast error.

    Delegates to the module-level helper from ``tests.tensor.test_basic``.
    """
    # Stale pre-refactor call (TestAlloc.check_runtime_broadcast) removed;
    # the commit replaces it with the module-level helper.
    check_alloc_runtime_broadcast(get_mode("NUMBA"))
def test_AllocEmpty():
......
......@@ -716,6 +716,32 @@ class TestAsTensorVariable:
ptb.as_tensor(x)
def check_alloc_runtime_broadcast(mode):
    """Check we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
    dtype = config.floatX
    x_vec = vector("x", shape=(None,))

    graph_out = alloc(x_vec, 5, 3)
    fn = pytensor.function([x_vec], graph_out, mode=mode)
    TestAlloc.check_allocs_in_fgraph(fn.maker.fgraph, 1)

    # A (3,)-input matches the trailing dimension of the (5, 3) alloc output.
    np.testing.assert_array_equal(
        fn(x=np.zeros((3,), dtype=dtype)),
        np.zeros((5, 3), dtype=dtype),
    )
    # A (1,)-input of unknown static shape would require runtime broadcasting,
    # which must be rejected with a clear error.
    with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
        fn(x=np.zeros((1,), dtype=dtype))

    # When the length-1 shape is declared statically via specify_shape,
    # the broadcast is known at compile time and is accepted.
    graph_out = alloc(specify_shape(x_vec, (1,)), 5, 3)
    fn = pytensor.function([x_vec], graph_out, mode=mode)
    TestAlloc.check_allocs_in_fgraph(fn.maker.fgraph, 1)

    np.testing.assert_array_equal(
        fn(x=np.zeros((1,), dtype=dtype)),
        np.zeros((5, 3), dtype=dtype),
    )
class TestAlloc:
dtype = config.floatX
mode = mode_opt
......@@ -729,32 +755,6 @@ class TestAlloc:
== n
)
# NOTE(review): pre-refactor staticmethod shown on the removal side of the diff;
# this commit deletes it in favor of the module-level
# check_alloc_runtime_broadcast helper, which has an identical body.
@staticmethod
def check_runtime_broadcast(mode):
"""Check we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
floatX = config.floatX
x_v = vector("x", shape=(None,))
out = alloc(x_v, 5, 3)
f = pytensor.function([x_v], out, mode=mode)
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
# Valid case: a (3,)-input fills the (5, 3) output exactly.
np.testing.assert_array_equal(
f(x=np.zeros((3,), dtype=floatX)),
np.zeros((5, 3), dtype=floatX),
)
# Invalid case: a (1,)-input of unknown static shape needs runtime broadcasting.
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
f(x=np.zeros((1,), dtype=floatX))
# Allowed case: the length-1 shape is declared statically via specify_shape.
out = alloc(specify_shape(x_v, (1,)), 5, 3)
f = pytensor.function([x_v], out, mode=mode)
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
np.testing.assert_array_equal(
f(x=np.zeros((1,), dtype=floatX)),
np.zeros((5, 3), dtype=floatX),
)
def setup_method(self):
# Fresh RNG per test; seeded via the tests' utt.fetch_seed() helper.
self.rng = np.random.default_rng(seed=utt.fetch_seed())
......@@ -912,7 +912,7 @@ class TestAlloc:
@pytest.mark.parametrize("mode", (Mode("py"), Mode("c")))
def test_runtime_broadcast(self, mode):
    """Alloc runtime-broadcast error is raised under both py and c linkers."""
    # Stale pre-refactor call (self.check_runtime_broadcast) removed; the
    # commit moves the check into the module-level helper.
    check_alloc_runtime_broadcast(mode)
def test_infer_static_shape():
......
......@@ -705,6 +705,33 @@ class TestBitOpReduceGrad:
assert np.all(gx_val == 0)
def check_elemwise_runtime_broadcast(mode):
    """Check we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
    lhs = matrix("x")
    rhs = vector("m")

    diff_graph = lhs - rhs
    fn = pytensor.function([lhs, rhs], diff_graph, mode=mode)

    # Test invalid broadcasting by either x or m
    for lhs_shape, rhs_shape in [((2, 1), (3,)), ((2, 3), (1,))]:
        lhs_val = np.ones(lhs_shape).astype(config.floatX)
        rhs_val = np.zeros(rhs_shape).astype(config.floatX)
        # This error is introduced by PyTensor, so it's the same across different backends
        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
            fn(lhs_val, rhs_val)

    # NOTE(review): the next two values are immediately overwritten below; they
    # look like a leftover in the scraped diff — confirm against upstream source.
    lhs_val = np.ones((2, 3)).astype(config.floatX)
    rhs_val = np.zeros((1,)).astype(config.floatX)

    lhs_val = np.ones((2, 4)).astype(config.floatX)
    rhs_val = np.zeros((3,)).astype(config.floatX)
    # This error is backend specific, and may have different types
    with pytest.raises((ValueError, TypeError)):
        fn(lhs_val, rhs_val)
class TestElemwise(unittest_tools.InferShapeTester):
def test_elemwise_grad_bool(self):
x = scalar(dtype="bool")
......@@ -750,42 +777,15 @@ class TestElemwise(unittest_tools.InferShapeTester):
g = pytensor.function([a, b, c, d, e, f], s, mode=Mode(linker="py"))
g(*[np.zeros(2**11, config.floatX) for i in range(6)])
# NOTE(review): pre-refactor staticmethod shown on the removal side of the diff;
# this commit deletes it in favor of the module-level
# check_elemwise_runtime_broadcast helper, which has an identical body.
@staticmethod
def check_runtime_broadcast(mode):
"""Check we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
x_v = matrix("x")
m_v = vector("m")
z_v = x_v - m_v
f = pytensor.function([x_v, m_v], z_v, mode=mode)
# Test invalid broadcasting by either x or m
for x_sh, m_sh in [((2, 1), (3,)), ((2, 3), (1,))]:
x = np.ones(x_sh).astype(config.floatX)
m = np.zeros(m_sh).astype(config.floatX)
# This error is introduced by PyTensor, so it's the same across different backends
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
f(x, m)
# NOTE(review): the next two assignments are immediately overwritten below in
# this scrape — confirm against the upstream source.
x = np.ones((2, 3)).astype(config.floatX)
m = np.zeros((1,)).astype(config.floatX)
x = np.ones((2, 4)).astype(config.floatX)
m = np.zeros((3,)).astype(config.floatX)
# This error is backend specific, and may have different types
with pytest.raises((ValueError, TypeError)):
f(x, m)
def test_runtime_broadcast_python(self):
    """Elemwise runtime-broadcast error is raised under the pure-Python linker."""
    # Stale pre-refactor call (self.check_runtime_broadcast) removed; the
    # commit moves the check into the module-level helper.
    check_elemwise_runtime_broadcast(Mode(linker="py"))
@pytest.mark.skipif(
    not pytensor.config.cxx,
    reason="G++ not available, so we need to skip this test.",
)
def test_runtime_broadcast_c(self):
    """Elemwise runtime-broadcast error is raised under the C linker."""
    # Stale pre-refactor call (self.check_runtime_broadcast) removed; the
    # commit moves the check into the module-level helper.
    check_elemwise_runtime_broadcast(Mode(linker="c"))
def test_str(self):
op = Elemwise(ps.add, inplace_pattern={0: 0}, name=None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论