提交 ac51f01c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement extensible deepcopy in numba

上级 69567532
from copy import deepcopy
from hashlib import sha256
import numba
import numpy as np
from pytensor.compile.builders import OpFromGraph
......@@ -15,7 +17,34 @@ from pytensor.link.numba.dispatch.basic import (
register_funcify_default_op_cache_key,
)
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.type import TensorType
def numba_deepcopy(x):
return deepcopy(x)
@numba.extending.overload(numba_deepcopy)
def numba_deepcopy_tensor(x):
if isinstance(x, numba.types.Number):
def number_deepcopy(x):
return x
return number_deepcopy
if isinstance(x, numba.types.Array):
def array_deepcopy(x):
return np.copy(x)
return array_deepcopy
if isinstance(x, numba.types.UnicodeType):
def string_deepcopy(x):
return x
return string_deepcopy
@register_funcify_and_cache_key(OpFromGraph)
......@@ -64,19 +93,11 @@ def numba_funcify_type_casting(op, **kwargs):
@register_funcify_default_op_cache_key(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
if isinstance(node.inputs[0].type, TensorType):
@numba_basic.numba_njit
def deepcopy(x):
return np.copy(x)
else:
@numba_basic.numba_njit
def deepcopy(x):
return x
@numba_basic.numba_njit
def deepcopy(x):
return numba_deepcopy(x)
return deepcopy
return deepcopy, 1
@register_funcify_default_op_cache_key(IfElse)
......
from collections.abc import Callable
from copy import copy, deepcopy
from copy import deepcopy
from functools import singledispatch
from hashlib import sha256
from textwrap import dedent
......@@ -20,6 +20,7 @@ from pytensor.link.numba.dispatch.basic import (
numba_funcify,
register_funcify_and_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
_vectorized,
......@@ -35,16 +36,16 @@ from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import _parse_gufunc_signature
@overload(copy)
def copy_NumPyRandomGenerator(rng):
def impl(rng):
# TODO: Open issue on Numba?
with numba.objmode(new_rng=types.npy_rng):
new_rng = deepcopy(rng)
@numba.extending.overload(numba_deepcopy)
def numba_deepcopy_random_generator(x):
if isinstance(x, numba.types.NumPyRandomGeneratorType):
return new_rng
def random_generator_deepcopy(x):
with numba.objmode(new_rng=types.npy_rng):
new_rng = deepcopy(x)
return new_rng
return impl
return random_generator_deepcopy
@singledispatch
......@@ -449,7 +450,7 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
def ov_random(core_shape, rng, size, *dist_params):
def impl(core_shape, rng, size, *dist_params):
if not inplace:
rng = copy(rng)
rng = numba_deepcopy(rng)
draws = _vectorized(
core_op_fn,
......
......@@ -18,6 +18,7 @@ from pytensor.link.numba.dispatch.basic import (
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
......@@ -104,6 +105,18 @@ def enable_slice_boxing():
enable_slice_boxing()
@numba.extending.overload(numba_deepcopy)
def numba_deepcopy_slice(x):
if isinstance(x, types.SliceType):
def deepcopy_slice(x):
return slice(
numba_deepcopy(x.start), numba_deepcopy(x.stop), numba_deepcopy(x.step)
)
return deepcopy_slice
@register_funcify_default_op_cache_key(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
@numba_basic.numba_njit
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论