提交 ac51f01c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement extensible deepcopy in numba

上级 69567532
from copy import deepcopy
from hashlib import sha256 from hashlib import sha256
import numba
import numpy as np import numpy as np
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
...@@ -15,7 +17,34 @@ from pytensor.link.numba.dispatch.basic import ( ...@@ -15,7 +17,34 @@ from pytensor.link.numba.dispatch.basic import (
register_funcify_default_op_cache_key, register_funcify_default_op_cache_key,
) )
from pytensor.raise_op import CheckAndRaise from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.type import TensorType
def numba_deepcopy(x):
return deepcopy(x)
@numba.extending.overload(numba_deepcopy)
def numba_deepcopy_tensor(x):
if isinstance(x, numba.types.Number):
def number_deepcopy(x):
return x
return number_deepcopy
if isinstance(x, numba.types.Array):
def array_deepcopy(x):
return np.copy(x)
return array_deepcopy
if isinstance(x, numba.types.UnicodeType):
def string_deepcopy(x):
return x
return string_deepcopy
@register_funcify_and_cache_key(OpFromGraph) @register_funcify_and_cache_key(OpFromGraph)
...@@ -64,19 +93,11 @@ def numba_funcify_type_casting(op, **kwargs): ...@@ -64,19 +93,11 @@ def numba_funcify_type_casting(op, **kwargs):
@register_funcify_default_op_cache_key(DeepCopyOp) @register_funcify_default_op_cache_key(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs): def numba_funcify_DeepCopyOp(op, node, **kwargs):
if isinstance(node.inputs[0].type, TensorType):
@numba_basic.numba_njit @numba_basic.numba_njit
def deepcopy(x): def deepcopy(x):
return np.copy(x) return numba_deepcopy(x)
else:
@numba_basic.numba_njit
def deepcopy(x):
return x
return deepcopy return deepcopy, 1
@register_funcify_default_op_cache_key(IfElse) @register_funcify_default_op_cache_key(IfElse)
......
from collections.abc import Callable from collections.abc import Callable
from copy import copy, deepcopy from copy import deepcopy
from functools import singledispatch from functools import singledispatch
from hashlib import sha256 from hashlib import sha256
from textwrap import dedent from textwrap import dedent
...@@ -20,6 +20,7 @@ from pytensor.link.numba.dispatch.basic import ( ...@@ -20,6 +20,7 @@ from pytensor.link.numba.dispatch.basic import (
numba_funcify, numba_funcify,
register_funcify_and_cache_key, register_funcify_and_cache_key,
) )
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.vectorize_codegen import ( from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options, _jit_options,
_vectorized, _vectorized,
...@@ -35,16 +36,16 @@ from pytensor.tensor.type_other import NoneTypeT ...@@ -35,16 +36,16 @@ from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import _parse_gufunc_signature from pytensor.tensor.utils import _parse_gufunc_signature
@overload(copy) @numba.extending.overload(numba_deepcopy)
def copy_NumPyRandomGenerator(rng): def numba_deepcopy_random_generator(x):
def impl(rng): if isinstance(x, numba.types.NumPyRandomGeneratorType):
# TODO: Open issue on Numba?
with numba.objmode(new_rng=types.npy_rng):
new_rng = deepcopy(rng)
def random_generator_deepcopy(x):
with numba.objmode(new_rng=types.npy_rng):
new_rng = deepcopy(x)
return new_rng return new_rng
return impl return random_generator_deepcopy
@singledispatch @singledispatch
...@@ -449,7 +450,7 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs ...@@ -449,7 +450,7 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
def ov_random(core_shape, rng, size, *dist_params): def ov_random(core_shape, rng, size, *dist_params):
def impl(core_shape, rng, size, *dist_params): def impl(core_shape, rng, size, *dist_params):
if not inplace: if not inplace:
rng = copy(rng) rng = numba_deepcopy(rng)
draws = _vectorized( draws = _vectorized(
core_op_fn, core_op_fn,
......
...@@ -18,6 +18,7 @@ from pytensor.link.numba.dispatch.basic import ( ...@@ -18,6 +18,7 @@ from pytensor.link.numba.dispatch.basic import (
register_funcify_and_cache_key, register_funcify_and_cache_key,
register_funcify_default_op_cache_key, register_funcify_default_op_cache_key,
) )
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.tensor import TensorType from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
...@@ -104,6 +105,18 @@ def enable_slice_boxing(): ...@@ -104,6 +105,18 @@ def enable_slice_boxing():
enable_slice_boxing() enable_slice_boxing()
@numba.extending.overload(numba_deepcopy)
def numba_deepcopy_slice(x):
if isinstance(x, types.SliceType):
def deepcopy_slice(x):
return slice(
numba_deepcopy(x.start), numba_deepcopy(x.stop), numba_deepcopy(x.step)
)
return deepcopy_slice
@register_funcify_default_op_cache_key(MakeSlice) @register_funcify_default_op_cache_key(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs): def numba_funcify_MakeSlice(op, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论