提交 ec117322 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix exclude tag cxx -> cxx_only

上级 6d4e3491
...@@ -284,7 +284,7 @@ class PerformLinker(LocalLinker): ...@@ -284,7 +284,7 @@ class PerformLinker(LocalLinker):
""" """
required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only") required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only")
incompatible_rewrites: tuple[str, ...] = ("cxx",) incompatible_rewrites: tuple[str, ...] = ("cxx_only",)
def __init__( def __init__(
self, allow_gc: bool | None = None, schedule: Callable | None = None self, allow_gc: bool | None = None, schedule: Callable | None = None
......
...@@ -14,7 +14,7 @@ class JAXLinker(JITLinker): ...@@ -14,7 +14,7 @@ class JAXLinker(JITLinker):
"jax", "jax",
) # TODO: Distinguish between optional "jax" and "minimum_compile_jax" ) # TODO: Distinguish between optional "jax" and "minimum_compile_jax"
incompatible_rewrites = ( incompatible_rewrites = (
"cxx", "cxx_only",
"BlasOpt", "BlasOpt",
"local_careduce_fusion", "local_careduce_fusion",
"scan_save_mem_prealloc", "scan_save_mem_prealloc",
......
...@@ -7,7 +7,7 @@ class NumbaLinker(JITLinker): ...@@ -7,7 +7,7 @@ class NumbaLinker(JITLinker):
"numba", "numba",
) # TODO: Distinguish between optional "numba" and "minimum_compile_numba" ) # TODO: Distinguish between optional "numba" and "minimum_compile_numba"
incompatible_rewrites = ( incompatible_rewrites = (
"cxx", "cxx_only",
"BlasOpt", "BlasOpt",
"local_careduce_fusion", "local_careduce_fusion",
"scan_save_mem_prealloc", "scan_save_mem_prealloc",
......
...@@ -840,7 +840,7 @@ class VMLinker(LocalLinker): ...@@ -840,7 +840,7 @@ class VMLinker(LocalLinker):
c_thunks = bool(config.cxx) c_thunks = bool(config.cxx)
if not c_thunks: if not c_thunks:
self.required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only") self.required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only")
self.incompatible_rewrites: tuple[str, ...] = ("cxx",) self.incompatible_rewrites: tuple[str, ...] = ("cxx_only",)
self.c_thunks = c_thunks self.c_thunks = c_thunks
self.allow_partial_eval = allow_partial_eval self.allow_partial_eval = allow_partial_eval
self.updated_vars = {} self.updated_vars = {}
......
...@@ -68,5 +68,6 @@ optdb.register( ...@@ -68,5 +68,6 @@ optdb.register(
"fast_run", "fast_run",
"inplace", "inplace",
"c_blas", "c_blas",
"cxx_only",
position=70.0, position=70.0,
) )
...@@ -69,13 +69,10 @@ def test_local_csm_grad_c(): ...@@ -69,13 +69,10 @@ def test_local_csm_grad_c():
not pytensor.config.cxx, reason="G++ not available, so we need to skip this test." not pytensor.config.cxx, reason="G++ not available, so we need to skip this test."
) )
def test_local_mul_s_d(): def test_local_mul_s_d():
mode = get_default_mode()
mode = mode.including("specialize", "local_mul_s_d")
for sp_format in sparse.sparse_formats: for sp_format in sparse.sparse_formats:
inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), matrix()] inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), matrix()]
f = pytensor.function(inputs, smath.mul_s_d(*inputs), mode=mode) f = pytensor.function(inputs, smath.mul_s_d(*inputs), mode="CVM")
assert not any( assert not any(
isinstance(node.op, smath.MulSD) for node in f.maker.fgraph.toposort() isinstance(node.op, smath.MulSD) for node in f.maker.fgraph.toposort()
...@@ -92,7 +89,7 @@ def test_local_mul_s_v(): ...@@ -92,7 +89,7 @@ def test_local_mul_s_v():
for sp_format in ["csr"]: # Not implemented for other format for sp_format in ["csr"]: # Not implemented for other format
inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), vector()] inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), vector()]
f = pytensor.function(inputs, smath.mul_s_v(*inputs), mode=mode) f = pytensor.function(inputs, smath.mul_s_v(*inputs), mode="CVM")
assert not any( assert not any(
isinstance(node.op, smath.MulSV) for node in f.maker.fgraph.toposort() isinstance(node.op, smath.MulSV) for node in f.maker.fgraph.toposort()
...@@ -103,13 +100,10 @@ def test_local_mul_s_v(): ...@@ -103,13 +100,10 @@ def test_local_mul_s_v():
not pytensor.config.cxx, reason="G++ not available, so we need to skip this test." not pytensor.config.cxx, reason="G++ not available, so we need to skip this test."
) )
def test_local_structured_add_s_v(): def test_local_structured_add_s_v():
mode = get_default_mode()
mode = mode.including("specialize", "local_structured_add_s_v")
for sp_format in ["csr"]: # Not implemented for other format for sp_format in ["csr"]: # Not implemented for other format
inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), vector()] inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), vector()]
f = pytensor.function(inputs, smath.structured_add_s_v(*inputs), mode=mode) f = pytensor.function(inputs, smath.structured_add_s_v(*inputs), mode="CVM")
assert not any( assert not any(
isinstance(node.op, smath.StructuredAddSV) isinstance(node.op, smath.StructuredAddSV)
...@@ -121,9 +115,6 @@ def test_local_structured_add_s_v(): ...@@ -121,9 +115,6 @@ def test_local_structured_add_s_v():
not pytensor.config.cxx, reason="G++ not available, so we need to skip this test." not pytensor.config.cxx, reason="G++ not available, so we need to skip this test."
) )
def test_local_sampling_dot_csr(): def test_local_sampling_dot_csr():
mode = get_default_mode()
mode = mode.including("specialize", "local_sampling_dot_csr")
for sp_format in ["csr"]: # Not implemented for other format for sp_format in ["csr"]: # Not implemented for other format
inputs = [ inputs = [
matrix(), matrix(),
...@@ -131,7 +122,7 @@ def test_local_sampling_dot_csr(): ...@@ -131,7 +122,7 @@ def test_local_sampling_dot_csr():
getattr(pytensor.sparse, sp_format + "_matrix")(), getattr(pytensor.sparse, sp_format + "_matrix")(),
] ]
f = pytensor.function(inputs, smath.sampling_dot(*inputs), mode=mode) f = pytensor.function(inputs, smath.sampling_dot(*inputs), mode="CVM")
if pytensor.config.blas__ldflags: if pytensor.config.blas__ldflags:
assert not any( assert not any(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论