提交 ba4cca3f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix numba FunctionGraph cache key

It's necessary to encode the edge information, not only the nodes and their ordering
上级 bf60f22f
...@@ -9,7 +9,7 @@ from numba import njit as _njit ...@@ -9,7 +9,7 @@ from numba import njit as _njit
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from pytensor import config from pytensor import config
from pytensor.graph.basic import Apply, Constant from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
...@@ -498,36 +498,46 @@ def numba_funcify_FunctionGraph( ...@@ -498,36 +498,46 @@ def numba_funcify_FunctionGraph(
): ):
# Collect cache keys of every Op/Constant in the FunctionGraph # Collect cache keys of every Op/Constant in the FunctionGraph
# so we can create a global cache key for the whole FunctionGraph # so we can create a global cache key for the whole FunctionGraph
fgraph_can_be_cached = [True]
cache_keys = [] cache_keys = []
toposort = fgraph.toposort() toposort = fgraph.toposort()
clients = fgraph.clients toposort_coords: dict[Variable, tuple[int, int | str]] = {
toposort_indices = {node: i for i, node in enumerate(toposort)} inp: (0, i) for i, inp in enumerate(fgraph.inputs)
# Add dummy output clients which are not included of the toposort }
toposort_indices |= { toposort_coords |= {
clients[out][0][0]: i out: (i, j)
for i, out in enumerate(fgraph.outputs, start=len(toposort)) for i, node in enumerate(toposort, start=1)
for j, out in enumerate(node.outputs)
} }
def op_conversion_and_key_collection(*args, **kwargs): def op_conversion_and_key_collection(op, *args, node, **kwargs):
# Convert an Op to a funcified function and store the cache_key # Convert an Op to a funcified function and store the cache_key
# We also Cache each Op so Numba can do less work next time it sees it # We also Cache each Op so Numba can do less work next time it sees it
func, key = numba_funcify_ensure_cache(*args, **kwargs) func, key = numba_funcify_ensure_cache(op, node=node, *args, **kwargs)
cache_keys.append(key) if key is None:
fgraph_can_be_cached[0] = False
else:
# Add graph coordinate information (input edges and node location)
cache_keys.append(
(
tuple(toposort_coords[inp] for inp in node.inputs),
key,
)
)
return func return func
def type_conversion_and_key_collection(value, variable, **kwargs): def type_conversion_and_key_collection(value, variable, **kwargs):
# Convert a constant type to a numba compatible one and compute a cache key for it # Convert a constant type to a numba compatible one and compute a cache key for it
# We need to know where in the graph the constants are used
# Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same
# FIXME: It doesn't make sense to call type_conversion on non-constants, # FIXME: It doesn't make sense to call type_conversion on non-constants,
# but that's what fgraph_to_python currently does. We appease it, but don't consider for caching # but that's what fgraph_to_python currently does.
# We appease it, but don't consider for caching
if isinstance(variable, Constant): if isinstance(variable, Constant):
client_indices = tuple( # Store unique key in toposort_coords. It will be included by whichever nodes make use of the constant
(toposort_indices[node], inp_idx) for node, inp_idx in clients[variable] constant_cache_key = cache_key_for_constant(value)
) assert constant_cache_key is not None
cache_keys.append((client_indices, cache_key_for_constant(value))) toposort_coords[variable] = (-1, constant_cache_key)
return numba_typify(value, variable=variable, **kwargs) return numba_typify(value, variable=variable, **kwargs)
py_func = fgraph_to_python( py_func = fgraph_to_python(
...@@ -537,12 +547,15 @@ def numba_funcify_FunctionGraph( ...@@ -537,12 +547,15 @@ def numba_funcify_FunctionGraph(
fgraph_name=fgraph_name, fgraph_name=fgraph_name,
**kwargs, **kwargs,
) )
if any(key is None for key in cache_keys): if not fgraph_can_be_cached[0]:
# If a single element couldn't be cached, we can't cache the whole FunctionGraph either # If a single element couldn't be cached, we can't cache the whole FunctionGraph either
fgraph_key = None fgraph_key = None
else: else:
# Add graph coordinate information for fgraph outputs
fgraph_output_ancestors = tuple(toposort_coords[out] for out in fgraph.outputs)
# Compose individual cache_keys into a global key for the FunctionGraph # Compose individual cache_keys into a global key for the FunctionGraph
fgraph_key = sha256( fgraph_key = sha256(
f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {len(fgraph.outputs)})".encode() f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {fgraph_output_ancestors})".encode()
).hexdigest() ).hexdigest()
return numba_njit(py_func), fgraph_key return numba_njit(py_func), fgraph_key
...@@ -735,14 +735,6 @@ def fgraph_to_python( ...@@ -735,14 +735,6 @@ def fgraph_to_python(
body_assigns = [] body_assigns = []
for node in order: for node in order:
compiled_func = op_conversion_fn(
node.op, node=node, storage_map=storage_map, **kwargs
)
# Create a local alias with a unique name
local_compiled_func_name = unique_name(compiled_func)
global_env[local_compiled_func_name] = compiled_func
node_input_names = [] node_input_names = []
for inp in node.inputs: for inp in node.inputs:
local_input_name = unique_name(inp) local_input_name = unique_name(inp)
...@@ -772,6 +764,13 @@ def fgraph_to_python( ...@@ -772,6 +764,13 @@ def fgraph_to_python(
node_output_names = [unique_name(v) for v in node.outputs] node_output_names = [unique_name(v) for v in node.outputs]
compiled_func = op_conversion_fn(
node.op, node=node, storage_map=storage_map, **kwargs
)
# Create a local alias with a unique name
local_compiled_func_name = unique_name(compiled_func)
global_env[local_compiled_func_name] = compiled_func
assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})" assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
assign_comment_str = f"{indent(str(node), '# ')}" assign_comment_str = f"{indent(str(node), '# ')}"
assign_block_str = f"{assign_comment_str}\n{assign_str}" assign_block_str = f"{assign_comment_str}\n{assign_str}"
......
...@@ -7,8 +7,7 @@ import numpy as np ...@@ -7,8 +7,7 @@ import numpy as np
import pytest import pytest
import scipy import scipy
from pytensor.compile import SymbolicInput from pytensor.tensor import scalar_from_tensor
from pytensor.tensor.utils import hash_from_ndarray
numba = pytest.importorskip("numba") numba = pytest.importorskip("numba")
...@@ -16,17 +15,23 @@ numba = pytest.importorskip("numba") ...@@ -16,17 +15,23 @@ numba = pytest.importorskip("numba")
import pytensor.scalar as ps import pytensor.scalar as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, shared from pytensor import config, shared
from pytensor.compile import SymbolicInput
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import cache_key_for_constant from pytensor.link.numba.dispatch.basic import (
cache_key_for_constant,
numba_funcify_and_cache_key,
)
from pytensor.link.numba.linker import NumbaLinker from pytensor.link.numba.linker import NumbaLinker
from pytensor.scalar.basic import ScalarOp, as_scalar from pytensor.scalar.basic import Composite, ScalarOp, as_scalar
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.utils import hash_from_ndarray
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -652,3 +657,77 @@ def test_funcify_dispatch_interop(): ...@@ -652,3 +657,77 @@ def test_funcify_dispatch_interop():
outs[2].owner.op, outs[2].owner outs[2].owner.op, outs[2].owner
) )
assert numba.njit(lambda x: fn2_def_cached(x))(test_x) == 2 assert numba.njit(lambda x: fn2_def_cached(x))(test_x) == 2
class TestFgraphCacheKey:
@staticmethod
def generate_and_validate_key(fg):
_, key = numba_funcify_and_cache_key(fg)
assert key is not None
_, key_again = numba_funcify_and_cache_key(fg)
assert key == key_again # Check its stable
return key
def test_node_order(self):
x = pt.scalar("x")
log_x = pt.log(x)
graphs = [
pt.exp(x) / log_x,
log_x / pt.exp(x),
pt.exp(log_x) / x,
x / pt.exp(log_x),
pt.exp(log_x) / log_x,
log_x / pt.exp(log_x),
]
keys = []
for graph in graphs:
fg = FunctionGraph([x], [graph], clone=False)
keys.append(self.generate_and_validate_key(fg))
# Check keys are unique
assert len(set(keys)) == len(graphs)
# Extra unused input should alter the key, because it changes the function signature
y = pt.scalar("y")
for inputs in [[x, y], [y, x]]:
fg = FunctionGraph(inputs, [graphs[0]], clone=False)
keys.append(self.generate_and_validate_key(fg))
assert len(set(keys)) == len(graphs) + 2
# Adding an input as an output should also change the key
for outputs in [
[graphs[0], x],
[x, graphs[0]],
[x, x, graphs[0]],
[x, graphs[0], x],
[graphs[0], x, x],
]:
fg = FunctionGraph([x], outputs, clone=False)
keys.append(self.generate_and_validate_key(fg))
assert len(set(keys)) == len(graphs) + 2 + 5
def test_multi_output(self):
x = pt.scalar("x")
xs = scalar_from_tensor(x)
out0, out1 = Elemwise(Composite([xs], [xs * 2, xs - 2]))(x)
test_outs = [
[out0],
[out1],
[out0, out1],
[out1, out0],
]
keys = []
for test_out in test_outs:
fg = FunctionGraph([x], test_out, clone=False)
keys.append(self.generate_and_validate_key(fg))
assert len(set(keys)) == len(test_outs)
def test_constant_output(self):
fg_pi = FunctionGraph([], [pt.constant(np.pi)])
fg_e = FunctionGraph([], [pt.constant(np.e)])
assert self.generate_and_validate_key(fg_pi) != self.generate_and_validate_key(
fg_e
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论