提交 a3613d13 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix cache of default subtensor

Implementation was specializing on node repeated inputs an `unique_names` would return the same name for repeated inputs. The cache key didn't account for this. We also don't want to compile different functions for different patterns of repeated inputs as it doesn't translate to an obvious handle for the compiler to specialize upon. We we wanted to inline constants that may make more sense.
上级 e2c462d5
......@@ -18,7 +18,6 @@ from pytensor.link.numba.dispatch.basic import (
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.utils import unique_name_generator
from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
......@@ -143,19 +142,14 @@ def subtensor_op_cache_key(op, **extra_fields):
def numba_funcify_default_subtensor(op, node, **kwargs):
"""Create a Python function that assembles and uses an index on an array."""
unique_names = unique_name_generator(
["subtensor", "incsubtensor", "z"], suffix_sep="_"
)
def convert_indices(indices, entry):
if indices and isinstance(entry, Type):
rval = indices.pop(0)
return unique_names(rval)
def convert_indices(indice_names, entry):
if indice_names and isinstance(entry, Type):
return next(indice_names)
elif isinstance(entry, slice):
return (
f"slice({convert_indices(indices, entry.start)}, "
f"{convert_indices(indices, entry.stop)}, "
f"{convert_indices(indices, entry.step)})"
f"slice({convert_indices(indice_names, entry.start)}, "
f"{convert_indices(indice_names, entry.stop)}, "
f"{convert_indices(indice_names, entry.step)})"
)
elif isinstance(entry, type(None)):
return "None"
......@@ -166,13 +160,15 @@ def numba_funcify_default_subtensor(op, node, **kwargs):
op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
)
index_start_idx = 1 + int(set_or_inc)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(op, "idx_list", None)
idx_names = [f"idx_{i}" for i in range(len(op_indices))]
input_names = ["x", "y", *idx_names] if set_or_inc else ["x", *idx_names]
idx_names_iterator = iter(idx_names)
indices_creation_src = (
tuple(convert_indices(op_indices, idx) for idx in idx_list)
tuple(convert_indices(idx_names_iterator, idx) for idx in idx_list)
if idx_list
else tuple(input_names[index_start_idx:])
)
......@@ -220,7 +216,9 @@ def {function_name}({", ".join(input_names)}):
function_name=function_name,
global_env=globals() | {"np": np},
)
cache_key = subtensor_op_cache_key(op, func="numba_funcify_default_subtensor")
cache_key = subtensor_op_cache_key(
op, func="numba_funcify_default_subtensor", version=1
)
return numba_basic.numba_njit(func, boundscheck=True), cache_key
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论