提交 69418200 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rename Elemwise.check_runtime_broadcast

上级 d1315406
...@@ -18,8 +18,8 @@ from tests.link.jax.test_basic import compare_jax_and_py ...@@ -18,8 +18,8 @@ from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_elemwise import TestElemwise from tests.tensor.test_elemwise import TestElemwise
def test_elemwise_runtime_shape_error(): def test_elemwise_runtime_broadcast():
TestElemwise.check_runtime_shapes_error(get_mode("JAX")) TestElemwise.check_runtime_broadcast(get_mode("JAX"))
def test_jax_Dimshuffle(): def test_jax_Dimshuffle():
......
...@@ -122,8 +122,8 @@ def test_Elemwise(inputs, input_vals, output_fn, exc): ...@@ -122,8 +122,8 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
@pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults") @pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")
def test_elemwise_runtime_shape_error(): def test_elemwise_runtime_broadcast():
TestElemwise.check_runtime_shapes_error(get_mode("NUMBA")) TestElemwise.check_runtime_broadcast(get_mode("NUMBA"))
def test_elemwise_speed(benchmark): def test_elemwise_speed(benchmark):
......
...@@ -751,7 +751,7 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -751,7 +751,7 @@ class TestElemwise(unittest_tools.InferShapeTester):
g(*[np.zeros(2**11, config.floatX) for i in range(6)]) g(*[np.zeros(2**11, config.floatX) for i in range(6)])
@staticmethod @staticmethod
def check_runtime_shapes_error(mode): def check_runtime_broadcast(mode):
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules.""" """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
x_v = matrix("x") x_v = matrix("x")
m_v = vector("m") m_v = vector("m")
...@@ -777,15 +777,15 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -777,15 +777,15 @@ class TestElemwise(unittest_tools.InferShapeTester):
with pytest.raises((ValueError, TypeError)): with pytest.raises((ValueError, TypeError)):
f(x, m) f(x, m)
def test_runtime_shapes_error_python(self): def test_runtime_broadcast_python(self):
self.check_runtime_shapes_error(Mode(linker="py")) self.check_runtime_broadcast(Mode(linker="py"))
@pytest.mark.skipif( @pytest.mark.skipif(
not pytensor.config.cxx, not pytensor.config.cxx,
reason="G++ not available, so we need to skip this test.", reason="G++ not available, so we need to skip this test.",
) )
def test_runtime_shapes_error_c(self): def test_runtime_broadcast_c(self):
self.check_runtime_shapes_error(Mode(linker="c")) self.check_runtime_broadcast(Mode(linker="c"))
def test_str(self): def test_str(self):
op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None) op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论