提交 ba4cca3f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix numba FunctionGraph cache key

It's necessary to encode the edge information, not only the nodes and their ordering
上级 bf60f22f
......@@ -9,7 +9,7 @@ from numba import njit as _njit
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from pytensor import config
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
......@@ -498,36 +498,46 @@ def numba_funcify_FunctionGraph(
):
# Collect cache keys of every Op/Constant in the FunctionGraph
# so we can create a global cache key for the whole FunctionGraph
fgraph_can_be_cached = [True]
cache_keys = []
toposort = fgraph.toposort()
clients = fgraph.clients
toposort_indices = {node: i for i, node in enumerate(toposort)}
# Add dummy output clients which are not included of the toposort
toposort_indices |= {
clients[out][0][0]: i
for i, out in enumerate(fgraph.outputs, start=len(toposort))
toposort_coords: dict[Variable, tuple[int, int | str]] = {
inp: (0, i) for i, inp in enumerate(fgraph.inputs)
}
toposort_coords |= {
out: (i, j)
for i, node in enumerate(toposort, start=1)
for j, out in enumerate(node.outputs)
}
def op_conversion_and_key_collection(*args, **kwargs):
def op_conversion_and_key_collection(op, *args, node, **kwargs):
# Convert an Op to a funcified function and store the cache_key
# We also Cache each Op so Numba can do less work next time it sees it
func, key = numba_funcify_ensure_cache(*args, **kwargs)
cache_keys.append(key)
func, key = numba_funcify_ensure_cache(op, node=node, *args, **kwargs)
if key is None:
fgraph_can_be_cached[0] = False
else:
# Add graph coordinate information (input edges and node location)
cache_keys.append(
(
tuple(toposort_coords[inp] for inp in node.inputs),
key,
)
)
return func
def type_conversion_and_key_collection(value, variable, **kwargs):
# Convert a constant type to a numba compatible one and compute a cache key for it
# We need to know where in the graph the constants are used
# Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same
# FIXME: It doesn't make sense to call type_conversion on non-constants,
# but that's what fgraph_to_python currently does. We appease it, but don't consider for caching
# but that's what fgraph_to_python currently does.
# We appease it, but don't consider for caching
if isinstance(variable, Constant):
client_indices = tuple(
(toposort_indices[node], inp_idx) for node, inp_idx in clients[variable]
)
cache_keys.append((client_indices, cache_key_for_constant(value)))
# Store unique key in toposort_coords. It will be included by whichever nodes make use of the constant
constant_cache_key = cache_key_for_constant(value)
assert constant_cache_key is not None
toposort_coords[variable] = (-1, constant_cache_key)
return numba_typify(value, variable=variable, **kwargs)
py_func = fgraph_to_python(
......@@ -537,12 +547,15 @@ def numba_funcify_FunctionGraph(
fgraph_name=fgraph_name,
**kwargs,
)
if any(key is None for key in cache_keys):
if not fgraph_can_be_cached[0]:
# If a single element couldn't be cached, we can't cache the whole FunctionGraph either
fgraph_key = None
else:
# Add graph coordinate information for fgraph outputs
fgraph_output_ancestors = tuple(toposort_coords[out] for out in fgraph.outputs)
# Compose individual cache_keys into a global key for the FunctionGraph
fgraph_key = sha256(
f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {len(fgraph.outputs)})".encode()
f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {fgraph_output_ancestors})".encode()
).hexdigest()
return numba_njit(py_func), fgraph_key
......@@ -735,14 +735,6 @@ def fgraph_to_python(
body_assigns = []
for node in order:
compiled_func = op_conversion_fn(
node.op, node=node, storage_map=storage_map, **kwargs
)
# Create a local alias with a unique name
local_compiled_func_name = unique_name(compiled_func)
global_env[local_compiled_func_name] = compiled_func
node_input_names = []
for inp in node.inputs:
local_input_name = unique_name(inp)
......@@ -772,6 +764,13 @@ def fgraph_to_python(
node_output_names = [unique_name(v) for v in node.outputs]
compiled_func = op_conversion_fn(
node.op, node=node, storage_map=storage_map, **kwargs
)
# Create a local alias with a unique name
local_compiled_func_name = unique_name(compiled_func)
global_env[local_compiled_func_name] = compiled_func
assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
assign_comment_str = f"{indent(str(node), '# ')}"
assign_block_str = f"{assign_comment_str}\n{assign_str}"
......
......@@ -7,8 +7,7 @@ import numpy as np
import pytest
import scipy
from pytensor.compile import SymbolicInput
from pytensor.tensor.utils import hash_from_ndarray
from pytensor.tensor import scalar_from_tensor
numba = pytest.importorskip("numba")
......@@ -16,17 +15,23 @@ numba = pytest.importorskip("numba")
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor import config, shared
from pytensor.compile import SymbolicInput
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import cache_key_for_constant
from pytensor.link.numba.dispatch.basic import (
cache_key_for_constant,
numba_funcify_and_cache_key,
)
from pytensor.link.numba.linker import NumbaLinker
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.scalar.basic import Composite, ScalarOp, as_scalar
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.utils import hash_from_ndarray
if TYPE_CHECKING:
......@@ -652,3 +657,77 @@ def test_funcify_dispatch_interop():
outs[2].owner.op, outs[2].owner
)
assert numba.njit(lambda x: fn2_def_cached(x))(test_x) == 2
class TestFgraphCacheKey:
    """Ensure FunctionGraph cache keys encode graph structure (edges), not just node identity."""

    @staticmethod
    def generate_and_validate_key(fg):
        """Funcify ``fg`` and return its cache key, asserting it exists and is stable."""
        _, first_key = numba_funcify_and_cache_key(fg)
        assert first_key is not None
        # The key must be deterministic across repeated funcifications
        _, second_key = numba_funcify_and_cache_key(fg)
        assert first_key == second_key
        return first_key

    def test_node_order(self):
        x = pt.scalar("x")
        log_x = pt.log(x)
        graphs = [
            pt.exp(x) / log_x,
            log_x / pt.exp(x),
            pt.exp(log_x) / x,
            x / pt.exp(log_x),
            pt.exp(log_x) / log_x,
            log_x / pt.exp(log_x),
        ]
        keys = [
            self.generate_and_validate_key(FunctionGraph([x], [g], clone=False))
            for g in graphs
        ]
        # Every distinct wiring of the same Ops must produce its own key
        assert len(set(keys)) == len(graphs)

        # An unused extra input still changes the function signature, hence the key
        y = pt.scalar("y")
        keys.extend(
            self.generate_and_validate_key(
                FunctionGraph(sig_inputs, [graphs[0]], clone=False)
            )
            for sig_inputs in ([x, y], [y, x])
        )
        assert len(set(keys)) == len(graphs) + 2

        # Returning an input directly (in any position/multiplicity) must also
        # be reflected in the key
        output_variants = (
            [graphs[0], x],
            [x, graphs[0]],
            [x, x, graphs[0]],
            [x, graphs[0], x],
            [graphs[0], x, x],
        )
        keys.extend(
            self.generate_and_validate_key(FunctionGraph([x], outs, clone=False))
            for outs in output_variants
        )
        assert len(set(keys)) == len(graphs) + 2 + 5

    def test_multi_output(self):
        x = pt.scalar("x")
        xs = scalar_from_tensor(x)
        out0, out1 = Elemwise(Composite([xs], [xs * 2, xs - 2]))(x)
        output_choices = [
            [out0],
            [out1],
            [out0, out1],
            [out1, out0],
        ]
        keys = [
            self.generate_and_validate_key(FunctionGraph([x], outs, clone=False))
            for outs in output_choices
        ]
        # Different selections/orderings of a multi-output node's outputs must not collide
        assert len(set(keys)) == len(output_choices)

    def test_constant_output(self):
        # Graphs that differ only in a constant's value must get different keys
        key_pi = self.generate_and_validate_key(FunctionGraph([], [pt.constant(np.pi)]))
        key_e = self.generate_and_validate_key(FunctionGraph([], [pt.constant(np.e)]))
        assert key_pi != key_e
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论