提交 0bd33bfc authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: ricardoV94

Fix numba cache mangling in forked processes

上级 3f9cc26a
import re
import warnings
from collections.abc import Callable
from functools import singledispatch, wraps
......@@ -13,7 +14,10 @@ from pytensor import config
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
from pytensor.link.numba.cache import (
compile_numba_function_src,
hash_from_pickle_dump,
)
from pytensor.link.utils import (
fgraph_to_python,
)
......@@ -33,7 +37,7 @@ def _filter_numba_warnings():
"ignore",
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals'
'Cannot cache compiled function "numba_funcified_fgraph.*" as it uses dynamic globals'
),
category=NumbaWarning,
)
......@@ -446,12 +450,19 @@ def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | Non
print(f"{op} of type {type(op)} will not be cached by PyTensor.\n") # noqa: T201
return jitable_func, None
else:
op_name = jitable_func.__name__
full_cache_key = f"{cache_key}_fastmath{int(config.numba__fastmath)}"
# Include cache_key in the wrapper name to ensure unique LLVM symbol
# names. Without this, functions with the same __name__ (e.g. all
# DimShuffle ops produce "dimshuffle") but different return types get
# identical mangled names when numba's UID counter overlaps after
# os.fork(), causing LLVM type mismatch errors.
safe_key = re.sub(r"[^a-zA-Z0-9_]", "_", full_cache_key)
op_name = f"{jitable_func.__name__}_{safe_key}"
cached_func = compile_numba_function_src(
src=f"def {op_name}(*args): return jitable_func(*args)",
function_name=op_name,
global_env=globals() | {"jitable_func": jitable_func},
cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}",
cache_key=full_cache_key,
)
return numba_njit(cached_func, cache=True), cache_key
......
......@@ -40,7 +40,7 @@ from tests.fixtures import * # noqa: F403
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
)
......
......@@ -13,7 +13,7 @@ from tests.tensor.signal.test_conv import convolve1d_grad_benchmarker
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
)
......
import contextlib
import copy
import os
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any
from unittest import mock
......@@ -787,3 +788,59 @@ class TestFgraphCacheKey:
assert self.generate_and_validate_key(fg_pi) != self.generate_and_validate_key(
fg_e
)
@pytest.mark.skipif(not hasattr(os, "fork"), reason="Test requires os.fork (Unix only)")
def test_fork_cache_no_type_mismatch(tmp_path, monkeypatch):
    """Regression test for fork-safety of the numba disk cache.

    After os.fork(), numba's internal UID counter (FunctionIdentity._unique_ids)
    is shared between parent and child. If two exec()-created wrapper functions
    with the same qualname get the same UID in different processes, their LLVM
    mangled names collide. When they have different return types (e.g. 3D vs 4D
    array), this causes a ValueError during LLVM lowering.

    PyTensor prevents this by including the cache key in the wrapper function
    name, ensuring unique LLVM symbols even when UIDs overlap after fork.

    See: https://github.com/numba/numba/issues/10486
    """
    import pytensor.link.numba.cache as cache_mod

    # Use a temporary cache directory so this test neither reads from nor
    # pollutes the user's real on-disk numba cache.
    monkeypatch.setattr(cache_mod, "NUMBA_CACHE_PATH", tmp_path)

    def run_in_fork(func):
        """Run ``func`` in a forked child process and return its exit code.

        Any exception in the child maps to exit code 1. The child terminates
        via os._exit() so it never unwinds back into pytest's machinery.
        """
        pid = os.fork()
        if pid == 0:
            # Child process: catch everything (incl. SystemExit/KeyboardInterrupt)
            # so failure is reported purely through the exit status.
            try:
                func()
            except BaseException:
                os._exit(1)
            os._exit(0)
        _, status = os.waitpid(pid, 0)
        # waitstatus_to_exitcode distinguishes a signal death (negative
        # return) from a normal exit. A bare os.WEXITSTATUS on a signalled
        # child is undefined and can spuriously read as 0 (false pass).
        return os.waitstatus_to_exitcode(status)

    def graph_a():
        x = pt.tensor3("x")
        fn = function([x], x.transpose(2, 0, 1), mode="NUMBA")
        assert fn(np.zeros((2, 3, 4))).shape == (4, 2, 3)

    def graph_b():
        x = pt.tensor3("x")
        fn = function([x], [x.transpose(2, 0, 1), x[None]], mode="NUMBA")
        r1, r2 = fn(np.zeros((2, 3, 4)))
        assert r1.shape == (4, 2, 3)
        assert r2.shape == (1, 2, 3, 4)

    # Fork child compiles graph_a (transpose only)
    assert run_in_fork(graph_a) == 0, "Fork child failed"
    # Parent compiles graph_b (transpose + expand dims)
    # This loads fork's cache and also compiles fresh ops
    graph_b()
    # Running in another fork is also fine
    assert run_in_fork(graph_a) == 0, "Fork child 1 failed"
    assert run_in_fork(graph_b) == 0, "Fork child 2 failed"
......@@ -155,7 +155,7 @@ from tests.tensor.utils import (
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
r"ignore::numba.NumbaPerformanceWarning",
)
......
......@@ -20,7 +20,7 @@ from pytensor.tensor.type import tensor
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
r"ignore::numba.NumbaPerformanceWarning",
)
......
......@@ -48,7 +48,7 @@ from tests.tensor.utils import random
pytestmark = pytest.mark.filterwarnings(
"error",
r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph\".*:numba.NumbaWarning",
r"ignore:Cannot cache compiled function \"numba_funcified_fgraph.*:numba.NumbaWarning",
)
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论