提交 4a263f3f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Numba list Ops: do not cache when inplace

上级 47bd6fb5
......@@ -365,6 +365,36 @@ def register_funcify_default_op_cache_key(op_type):
return decorator
def default_hash_key_from_props(op, **extra_fields):
    """Derive a stable numba cache key for ``op`` from its ``_props_dict()``.

    Parameters
    ----------
    op
        An Op that implements ``_props_dict()``.
    extra_fields
        Extra key/value pairs (e.g. a cache version string) mixed into the
        key so otherwise identical ops can still be distinguished.

    Returns
    -------
    str
        A hex digest that is stable across processes for equivalent ops.
    """
    props_dict = op._props_dict()
    if not props_dict:
        # Simple op, just use the type string as key
        return sha256(
            f"({type(op)}, {tuple(extra_fields.items())})".encode()
        ).hexdigest()

    simple_types = (str, bool, int, type(None), float)
    container_types = (tuple, frozenset)
    # Props whose values (and flat container items) have deterministic reprs
    # can be hashed via their string form, avoiding pickling.
    props_are_simple = all(
        isinstance(v, simple_types)
        or (
            isinstance(v, container_types)
            and all(isinstance(i, simple_types) for i in v)
        )
        for v in props_dict.values()
    )
    if props_are_simple:
        # Simple props, can use string representation of props as key
        return sha256(
            f"({type(op)}, {tuple(props_dict.items())}, {tuple(extra_fields.items())})".encode()
        ).hexdigest()

    # Complex props, use pickle to serialize them
    return hash_from_pickle_dump(
        (str(type(op)), tuple(props_dict.items()), tuple(extra_fields.items())),
    )
@singledispatch
def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str | None]:
"""Funcify an Op and return a unique cache key that can be used by numba caching.
......@@ -408,36 +438,12 @@ def numba_funcify_and_cache_key(op, node=None, **kwargs) -> tuple[Callable, str
else:
func, integer_str = func_and_int, "None"
try:
props_dict = op._props_dict()
except AttributeError:
if not hasattr(op, "__props__"):
raise ValueError(
"The function wrapped by `numba_funcify_default_op_cache_key` can only be used with Ops with `_props`, "
f"but {op} of type {type(op)} has no _props defined (not even empty)."
)
if not props_dict:
# Simple op, just use the type string as key
hash = sha256(f"({type(op)}, {integer_str})".encode()).hexdigest()
else:
# Simple props, can use string representation of props as key
simple_types = (str, bool, int, type(None), float)
container_types = (tuple, frozenset)
if all(
isinstance(v, simple_types)
or (
isinstance(v, container_types)
and all(isinstance(i, simple_types) for i in v)
)
for v in props_dict.values()
):
hash = sha256(
f"({type(op)}, {tuple(props_dict.items())}, {integer_str})".encode()
).hexdigest()
else:
# Complex props, use pickle to serialize them
hash = hash_from_pickle_dump(
(str(type(op)), tuple(props_dict.items()), integer_str),
)
hash = default_hash_key_from_props(op, cache_version=integer_str)
return func, hash
......
......@@ -3,7 +3,11 @@ import numpy as np
from numba.types import Array, Boolean, List, Number
import pytensor.link.numba.dispatch.basic as numba_basic
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
from pytensor.link.numba.dispatch.basic import (
default_hash_key_from_props,
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.tensor.type_other import SliceType
from pytensor.typed_list import (
......@@ -48,12 +52,14 @@ def list_all_equal(x, y):
return False
return True
if isinstance(x, Array) and isinstance(y, Array):
if (isinstance(x, Array) and x.ndim > 0) and (isinstance(y, Array) and y.ndim > 0):
# (x == y).all() fails for 0d arrays
def all_equal(x, y):
return (x == y).all()
if isinstance(x, Number | Boolean) and isinstance(y, Number | Boolean):
if (isinstance(x, Number | Boolean) or (isinstance(x, Array) and x.ndim == 0)) and (
isinstance(y, Number | Boolean) or (isinstance(y, Array) and y.ndim == 0)
):
def all_equal(x, y):
return x == y
......@@ -71,6 +77,16 @@ def numba_deepcopy_list(x):
return deepcopy_list
def cache_key_if_not_inplace(op, inplace: bool):
    """Return the default props-based cache key for ``op``, or None if inplace.

    NUMBA is misbehaving with wrapped inplace ListType operations,
    which happens when we cache it in PyTensor, so inplace variants
    are never cached.
    https://github.com/numba/numba/issues/10356
    """
    if not inplace:
        return default_hash_key_from_props(op)
    return None
@register_funcify_default_op_cache_key(MakeList)
def numba_funcify_make_list(op, node, **kwargs):
@numba_basic.numba_njit
......@@ -108,7 +124,7 @@ def numba_funcify_list_get_item(op, node, **kwargs):
return list_get_item_index
@register_funcify_default_op_cache_key(Reverse)
@register_funcify_and_cache_key(Reverse)
def numba_funcify_list_reverse(op, node, **kwargs):
inplace = op.inplace
......@@ -121,10 +137,10 @@ def numba_funcify_list_reverse(op, node, **kwargs):
z.reverse()
return z
return list_reverse
return list_reverse, cache_key_if_not_inplace(op, inplace)
@register_funcify_default_op_cache_key(Append)
@register_funcify_and_cache_key(Append)
def numba_funcify_list_append(op, node, **kwargs):
inplace = op.inplace
......@@ -137,10 +153,10 @@ def numba_funcify_list_append(op, node, **kwargs):
z.append(numba_deepcopy(to_append))
return z
return list_append
return list_append, cache_key_if_not_inplace(op, inplace)
@register_funcify_default_op_cache_key(Extend)
@register_funcify_and_cache_key(Extend)
def numba_funcify_list_extend(op, node, **kwargs):
inplace = op.inplace
......@@ -153,10 +169,10 @@ def numba_funcify_list_extend(op, node, **kwargs):
z.extend(numba_deepcopy(to_append))
return z
return list_extend
return list_extend, cache_key_if_not_inplace(op, inplace)
@register_funcify_default_op_cache_key(Insert)
@register_funcify_and_cache_key(Insert)
def numba_funcify_list_insert(op, node, **kwargs):
inplace = op.inplace
......@@ -169,7 +185,7 @@ def numba_funcify_list_insert(op, node, **kwargs):
z.insert(index.item(), numba_deepcopy(to_insert))
return z
return list_insert
return list_insert, cache_key_if_not_inplace(op, inplace)
@register_funcify_default_op_cache_key(Index)
......@@ -197,7 +213,7 @@ def numba_funcify_list_count(op, node, **kwargs):
return list_count
@register_funcify_default_op_cache_key(Remove)
@register_funcify_and_cache_key(Remove)
def numba_funcify_list_remove(op, node, **kwargs):
inplace = op.inplace
......@@ -217,4 +233,4 @@ def numba_funcify_list_remove(op, node, **kwargs):
z.pop(index_to_remove)
return z
return list_remove
return list_remove, cache_key_if_not_inplace(op, inplace)
import numpy as np
from pytensor.tensor import matrix
from pytensor.typed_list import make_list
from pytensor import In
from pytensor.tensor import as_tensor, lscalar, matrix
from pytensor.typed_list import TypedListType, make_list
from tests.link.numba.test_basic import compare_numba_and_py
......@@ -44,3 +45,32 @@ def test_make_list_find_ops():
x_test = np.arange(12).reshape(3, 4)
test_y = x_test[2]
compare_numba_and_py([x, y], [l.ind(y), l.count(y), l.remove(y)], [x_test, test_y])
def test_inplace_ops():
    """Check that typed-list Ops run correctly under numba when made inplace.

    Inputs are marked mutable so the rewriter can introduce the destructive
    (inplace) variants of the list Ops; the final loop confirms the inplace
    versions were actually selected.
    """
    int64_list = TypedListType(lscalar)
    ls = [int64_list(f"list[{i}]") for i in range(5)]
    to_extend = lscalar("to_extend")

    ls_test = [np.arange(3, dtype="int64").tolist() for _ in range(5)]
    to_extend_test = np.array(99, dtype="int64")

    def as_lscalar(x):
        # Wrap a Python scalar as a 0d int64 tensor.
        return as_tensor(x, ndim=0, dtype="int64")

    fn, _ = compare_numba_and_py(
        # mutable=True is what allows the inplace rewrites to trigger
        [*(In(l, mutable=True) for l in ls), to_extend],
        [
            ls[0].reverse(),
            ls[1].append(as_lscalar(99)),
            # This fails because it gets constant folded
            # ls_to_extend = make_list([as_lscalar(99), as_lscalar(100)])
            ls[2].extend(make_list([to_extend, to_extend + 1])),
            ls[3].insert(as_lscalar(1), as_lscalar(99)),
            ls[4].remove(as_lscalar(2)),
        ],
        [*ls_test, to_extend_test],
        numba_mode="NUMBA",  # So it triggers inplace
    )
    # Every output should come from a destructive (inplace) Op.
    for out in fn.maker.fgraph.outputs:
        assert out.owner.op.destroy_map
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论