提交 4de2a7e6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Don't use overload for deepcopy

上级 c72a48d7
import warnings
from copy import copy
from functools import singledispatch
import numba
......@@ -7,7 +6,6 @@ import numpy as np
from numba import types
from numba.core.errors import NumbaWarning, TypingError
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import overload
from pytensor import In, config
from pytensor.compile import NUMBA
......@@ -296,21 +294,21 @@ def numba_funcify_FunctionGraph(
)
def deepcopyop(x):
return copy(x)
@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
if isinstance(node.inputs[0].type, TensorType):
@overload(deepcopyop)
def dispatch_deepcopyop(x):
if isinstance(x, types.Array):
return lambda x: np.copy(x)
@numba_njit
def deepcopy(x):
return np.copy(x)
return lambda x: x
else:
@numba_njit
def deepcopy(x):
return x
@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
return deepcopyop
return deepcopy
@numba.extending.intrinsic
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论