提交 6d7fa2fa authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix numba cache mangling in forked processes

上级 ed230ae1
import re
import warnings import warnings
from collections.abc import Callable from collections.abc import Callable
from functools import singledispatch, wraps from functools import singledispatch, wraps
...@@ -13,7 +14,10 @@ from pytensor import config ...@@ -13,7 +14,10 @@ from pytensor import config
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump from pytensor.link.numba.cache import (
compile_numba_function_src,
hash_from_pickle_dump,
)
from pytensor.link.utils import ( from pytensor.link.utils import (
fgraph_to_python, fgraph_to_python,
) )
...@@ -33,7 +37,7 @@ def _filter_numba_warnings(): ...@@ -33,7 +37,7 @@ def _filter_numba_warnings():
"ignore", "ignore",
message=( message=(
"(\x1b\\[1m)*" # ansi escape code for bold text "(\x1b\\[1m)*" # ansi escape code for bold text
'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals' 'Cannot cache compiled function "numba_funcified_fgraph.*" as it uses dynamic globals'
), ),
category=NumbaWarning, category=NumbaWarning,
) )
...@@ -446,12 +450,20 @@ def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | Non ...@@ -446,12 +450,20 @@ def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | Non
print(f"{op} of type {type(op)} will not be cached by PyTensor.\n") # noqa: T201 print(f"{op} of type {type(op)} will not be cached by PyTensor.\n") # noqa: T201
return jitable_func, None return jitable_func, None
else: else:
op_name = jitable_func.__name__ full_cache_key = f"{cache_key}_fastmath{int(config.numba__fastmath)}"
# Include cache_key in the wrapper name to ensure unique LLVM symbol names.
# Without this, functions with the same __name__ but different behavior
# (e.g. all DimShuffle ops produce "dimshuffle")
# get identical mangled names when numba's UID counter overlaps after os.fork().
# This could cause compilation errors or silent bugs.
# See https://github.com/numba/numba/issues/10486
safe_key = re.sub(r"[^a-zA-Z0-9_]", "_", full_cache_key)
op_name = f"{jitable_func.__name__}_{safe_key}"
cached_func = compile_numba_function_src( cached_func = compile_numba_function_src(
src=f"def {op_name}(*args): return jitable_func(*args)", src=f"def {op_name}(*args): return jitable_func(*args)",
function_name=op_name, function_name=op_name,
global_env=globals() | {"jitable_func": jitable_func}, global_env=globals() | {"jitable_func": jitable_func},
cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}", cache_key=full_cache_key,
) )
return numba_njit(cached_func, cache=True), cache_key return numba_njit(cached_func, cache=True), cache_key
......
...@@ -40,7 +40,7 @@ from tests.fixtures import * # noqa: F403 ...@@ -40,7 +40,7 @@ from tests.fixtures import * # noqa: F403
pytestmark = pytest.mark.filterwarnings( pytestmark = pytest.mark.filterwarnings(
"error", "error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning", r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning", r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
) )
......
...@@ -13,7 +13,7 @@ from tests.tensor.signal.test_conv import convolve1d_grad_benchmarker ...@@ -13,7 +13,7 @@ from tests.tensor.signal.test_conv import convolve1d_grad_benchmarker
pytestmark = pytest.mark.filterwarnings( pytestmark = pytest.mark.filterwarnings(
"error", "error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning", r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning", r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
) )
......
import contextlib import contextlib
import copy import copy
import os
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from unittest import mock from unittest import mock
...@@ -787,3 +788,59 @@ class TestFgraphCacheKey: ...@@ -787,3 +788,59 @@ class TestFgraphCacheKey:
assert self.generate_and_validate_key(fg_pi) != self.generate_and_validate_key( assert self.generate_and_validate_key(fg_pi) != self.generate_and_validate_key(
fg_e fg_e
) )
@pytest.mark.skipif(not hasattr(os, "fork"), reason="Test requires os.fork (Unix only)")
def test_fork_cache_no_type_mismatch(tmp_path, monkeypatch):
    """Fork-safety regression test for the numba disk cache.

    After ``os.fork()`` numba's internal UID counter
    (``FunctionIdentity._unique_ids``) is inherited by the child, so two
    exec()-created wrapper functions with the same qualname can receive the
    same UID in different processes, making their LLVM mangled names collide.
    When the colliding functions have different return types (e.g. a 3D vs a
    4D array) this surfaces as a ``ValueError`` during LLVM lowering.

    PyTensor avoids the collision by embedding the cache key in each wrapper
    function's name, keeping LLVM symbols unique even when UIDs overlap
    between forked processes.

    See: https://github.com/numba/numba/issues/10486
    """
    import pytensor.link.numba.cache as cache_mod

    # Point the on-disk numba cache at a throwaway directory for this test.
    monkeypatch.setattr(cache_mod, "NUMBA_CACHE_PATH", tmp_path)

    def child_exit_status(target):
        # Run `target` in a forked child and report its exit status (0 = ok).
        child_pid = os.fork()
        if child_pid != 0:
            # Parent: wait for the child and decode its exit status.
            _, wait_status = os.waitpid(child_pid, 0)
            return os.WEXITSTATUS(wait_status)
        try:
            target()
            os._exit(0)
        except BaseException:
            # Any failure in the child (including SystemExit) maps to status 1.
            os._exit(1)

    def transpose_graph():
        x = pt.tensor3("x")
        fn = function([x], x.transpose(2, 0, 1), mode="NUMBA")
        assert fn(np.zeros((2, 3, 4))).shape == (4, 2, 3)

    def transpose_and_expand_graph():
        x = pt.tensor3("x")
        fn = function([x], [x.transpose(2, 0, 1), x[None]], mode="NUMBA")
        transposed, expanded = fn(np.zeros((2, 3, 4)))
        assert transposed.shape == (4, 2, 3)
        assert expanded.shape == (1, 2, 3, 4)

    # A forked child compiles the transpose-only graph first.
    assert child_exit_status(transpose_graph) == 0, "Fork child failed"
    # The parent then compiles transpose + expand-dims: it loads the child's
    # on-disk cache entries and also compiles fresh ops of its own.
    transpose_and_expand_graph()
    # Further forked children must likewise be able to reuse the shared cache.
    assert child_exit_status(transpose_graph) == 0, "Fork child 1 failed"
    assert child_exit_status(transpose_and_expand_graph) == 0, "Fork child 2 failed"
...@@ -155,7 +155,7 @@ from tests.tensor.utils import ( ...@@ -155,7 +155,7 @@ from tests.tensor.utils import (
pytestmark = pytest.mark.filterwarnings( pytestmark = pytest.mark.filterwarnings(
"error", "error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning", r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning", r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
r"ignore::numba.NumbaPerformanceWarning", r"ignore::numba.NumbaPerformanceWarning",
) )
......
...@@ -20,7 +20,7 @@ from pytensor.tensor.type import tensor ...@@ -20,7 +20,7 @@ from pytensor.tensor.type import tensor
pytestmark = pytest.mark.filterwarnings( pytestmark = pytest.mark.filterwarnings(
"error", "error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning", r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning", r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
r"ignore::numba.NumbaPerformanceWarning", r"ignore::numba.NumbaPerformanceWarning",
) )
......
...@@ -48,7 +48,7 @@ from tests.tensor.utils import random ...@@ -48,7 +48,7 @@ from tests.tensor.utils import random
pytestmark = pytest.mark.filterwarnings( pytestmark = pytest.mark.filterwarnings(
"error", "error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning", r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning", r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
) )
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论