Commit d7541cd7 authored by Ricardo Vieira, committed by Ricardo Vieira

Fix numba cache mangling in forked processes

Parent bde75a2e
import re
import warnings
from collections.abc import Callable
from functools import singledispatch, wraps
......@@ -13,7 +14,10 @@ from pytensor import config
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
from pytensor.link.numba.cache import (
compile_numba_function_src,
hash_from_pickle_dump,
)
from pytensor.link.utils import (
fgraph_to_python,
)
......@@ -33,7 +37,7 @@ def _filter_numba_warnings():
"ignore",
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals'
'Cannot cache compiled function "numba_funcified_fgraph.*" as it uses dynamic globals'
),
category=NumbaWarning,
)
......@@ -446,12 +450,20 @@ def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | Non
print(f"{op} of type {type(op)} will not be cached by PyTensor.\n") # noqa: T201
return jitable_func, None
else:
op_name = jitable_func.__name__
full_cache_key = f"{cache_key}_fastmath{int(config.numba__fastmath)}"
# Include cache_key in the wrapper name to ensure unique LLVM symbol names.
# Without this, functions with the same __name__ but different behavior
# (e.g. all DimShuffle ops produce "dimshuffle")
# get identical mangled names when numba's UID counter overlaps after os.fork().
# This could cause compilation errors or silent bugs.
# See https://github.com/numba/numba/issues/10486
safe_key = re.sub(r"[^a-zA-Z0-9_]", "_", full_cache_key)
op_name = f"{jitable_func.__name__}_{safe_key}"
cached_func = compile_numba_function_src(
src=f"def {op_name}(*args): return jitable_func(*args)",
function_name=op_name,
global_env=globals() | {"jitable_func": jitable_func},
cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}",
cache_key=full_cache_key,
)
return numba_njit(cached_func, cache=True), cache_key
......
......@@ -37,7 +37,7 @@ from pytensor.tensor.type import (
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
)
......
......@@ -9,7 +9,7 @@ from tests.link.numba.test_basic import compare_numba_and_py
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
)
......
import contextlib
import copy
import os
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any
from unittest import mock
......@@ -770,3 +771,59 @@ class TestFgraphCacheKey:
assert self.generate_and_validate_key(fg_pi) != self.generate_and_validate_key(
fg_e
)
@pytest.mark.skipif(not hasattr(os, "fork"), reason="Test requires os.fork (Unix only)")
def test_fork_cache_no_type_mismatch(tmp_path, monkeypatch):
    """Regression test for fork-safety of the numba disk cache.

    After ``os.fork()``, numba's internal UID counter
    (``FunctionIdentity._unique_ids``) is shared between parent and child.
    If two ``exec()``-created wrapper functions with the same qualname get
    the same UID in different processes, their LLVM mangled names collide.
    When they have different return types (e.g. 3D vs 4D array), this causes
    a ``ValueError`` during LLVM lowering.

    PyTensor prevents this by including the cache key in the wrapper function
    name, ensuring unique LLVM symbols even when UIDs overlap after fork.

    See: https://github.com/numba/numba/issues/10486
    """
    import pytensor.link.numba.cache as cache_mod

    # Point the on-disk numba cache at an isolated temporary directory so
    # this test neither reads nor pollutes the user's real cache.
    monkeypatch.setattr(cache_mod, "NUMBA_CACHE_PATH", tmp_path)

    def _run_in_child(target):
        # Execute `target` in a forked child and return its exit status;
        # the child communicates success/failure solely via the exit code.
        child_pid = os.fork()
        if child_pid == 0:
            try:
                target()
                os._exit(0)
            except BaseException:
                os._exit(1)
        _, wait_status = os.waitpid(child_pid, 0)
        return os.WEXITSTATUS(wait_status)

    def graph_a():
        # Compile a graph containing only a transpose (3D result).
        x = pt.tensor3("x")
        fn = function([x], x.transpose(2, 0, 1), mode="NUMBA")
        assert fn(np.zeros((2, 3, 4))).shape == (4, 2, 3)

    def graph_b():
        # Compile a graph producing both a transpose (3D) and an
        # expand-dims (4D) output — different return types than graph_a.
        x = pt.tensor3("x")
        fn = function([x], [x.transpose(2, 0, 1), x[None]], mode="NUMBA")
        out_transposed, out_expanded = fn(np.zeros((2, 3, 4)))
        assert out_transposed.shape == (4, 2, 3)
        assert out_expanded.shape == (1, 2, 3, 4)

    # A forked child compiles graph_a (transpose only) into the cache.
    assert _run_in_child(graph_a) == 0, "Fork child failed"
    # The parent compiles graph_b: this loads the fork's cache entries and
    # also compiles fresh ops alongside them.
    graph_b()
    # Running in further forked children must work as well.
    assert _run_in_child(graph_a) == 0, "Fork child 1 failed"
    assert _run_in_child(graph_b) == 0, "Fork child 2 failed"
......@@ -154,7 +154,7 @@ from tests.tensor.utils import (
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
r"ignore::numba.NumbaPerformanceWarning",
)
......
......@@ -20,7 +20,7 @@ from pytensor.tensor.type import tensor
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
r"ignore::numba.NumbaPerformanceWarning",
)
......
......@@ -48,7 +48,7 @@ from tests.tensor.utils import random
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
)
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to post a comment