提交 ec117322 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix exclude tag cxx -> cxx_only

上级 6d4e3491
......@@ -284,7 +284,7 @@ class PerformLinker(LocalLinker):
"""
required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only")
incompatible_rewrites: tuple[str, ...] = ("cxx",)
incompatible_rewrites: tuple[str, ...] = ("cxx_only",)
def __init__(
self, allow_gc: bool | None = None, schedule: Callable | None = None
......
......@@ -14,7 +14,7 @@ class JAXLinker(JITLinker):
"jax",
) # TODO: Distinguish between optional "jax" and "minimum_compile_jax"
incompatible_rewrites = (
"cxx",
"cxx_only",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
......
......@@ -7,7 +7,7 @@ class NumbaLinker(JITLinker):
"numba",
) # TODO: Distinguish between optional "numba" and "minimum_compile_numba"
incompatible_rewrites = (
"cxx",
"cxx_only",
"BlasOpt",
"local_careduce_fusion",
"scan_save_mem_prealloc",
......
......@@ -840,7 +840,7 @@ class VMLinker(LocalLinker):
c_thunks = bool(config.cxx)
if not c_thunks:
self.required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only")
self.incompatible_rewrites: tuple[str, ...] = ("cxx",)
self.incompatible_rewrites: tuple[str, ...] = ("cxx_only",)
self.c_thunks = c_thunks
self.allow_partial_eval = allow_partial_eval
self.updated_vars = {}
......
......@@ -68,5 +68,6 @@ optdb.register(
"fast_run",
"inplace",
"c_blas",
"cxx_only",
position=70.0,
)
......@@ -69,13 +69,10 @@ def test_local_csm_grad_c():
not pytensor.config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_local_mul_s_d():
mode = get_default_mode()
mode = mode.including("specialize", "local_mul_s_d")
for sp_format in sparse.sparse_formats:
inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), matrix()]
f = pytensor.function(inputs, smath.mul_s_d(*inputs), mode=mode)
f = pytensor.function(inputs, smath.mul_s_d(*inputs), mode="CVM")
assert not any(
isinstance(node.op, smath.MulSD) for node in f.maker.fgraph.toposort()
......@@ -92,7 +89,7 @@ def test_local_mul_s_v():
for sp_format in ["csr"]: # Not implemented for other format
inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), vector()]
f = pytensor.function(inputs, smath.mul_s_v(*inputs), mode=mode)
f = pytensor.function(inputs, smath.mul_s_v(*inputs), mode="CVM")
assert not any(
isinstance(node.op, smath.MulSV) for node in f.maker.fgraph.toposort()
......@@ -103,13 +100,10 @@ def test_local_mul_s_v():
not pytensor.config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_local_structured_add_s_v():
mode = get_default_mode()
mode = mode.including("specialize", "local_structured_add_s_v")
for sp_format in ["csr"]: # Not implemented for other format
inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), vector()]
f = pytensor.function(inputs, smath.structured_add_s_v(*inputs), mode=mode)
f = pytensor.function(inputs, smath.structured_add_s_v(*inputs), mode="CVM")
assert not any(
isinstance(node.op, smath.StructuredAddSV)
......@@ -121,9 +115,6 @@ def test_local_structured_add_s_v():
not pytensor.config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_local_sampling_dot_csr():
mode = get_default_mode()
mode = mode.including("specialize", "local_sampling_dot_csr")
for sp_format in ["csr"]: # Not implemented for other format
inputs = [
matrix(),
......@@ -131,7 +122,7 @@ def test_local_sampling_dot_csr():
getattr(pytensor.sparse, sp_format + "_matrix")(),
]
f = pytensor.function(inputs, smath.sampling_dot(*inputs), mode=mode)
f = pytensor.function(inputs, smath.sampling_dot(*inputs), mode="CVM")
if pytensor.config.blas__ldflags:
assert not any(
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论