Unverified 提交 b4522d23 authored 作者: Pablo de Roque's avatar Pablo de Roque 提交者: GitHub

Remove uses of `numba_basic.global_numba_func`

上级 21218d77
...@@ -402,24 +402,22 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs): ...@@ -402,24 +402,22 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
return deepcopyop return deepcopyop
@numba_njit
def makeslice(*x):
return slice(*x)
@numba_funcify.register(MakeSlice) @numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs): def numba_funcify_MakeSlice(op, **kwargs):
return global_numba_func(makeslice) @numba_njit
def makeslice(*x):
return slice(*x)
@numba_njit return makeslice
def shape(x):
return np.asarray(np.shape(x))
@numba_funcify.register(Shape) @numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs): def numba_funcify_Shape(op, **kwargs):
return global_numba_func(shape) @numba_njit
def shape(x):
return np.asarray(np.shape(x))
return shape
@numba_funcify.register(Shape_i) @numba_funcify.register(Shape_i)
......
...@@ -141,17 +141,16 @@ def {scalar_op_fn_name}({', '.join(input_names)}): ...@@ -141,17 +141,16 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
)(scalar_op_fn) )(scalar_op_fn)
@numba_basic.numba_njit @numba_funcify.register(Switch)
def switch(condition, x, y): def numba_funcify_Switch(op, node, **kwargs):
@numba_basic.numba_njit
def switch(condition, x, y):
if condition: if condition:
return x return x
else: else:
return y return y
return switch
@numba_funcify.register(Switch)
def numba_funcify_Switch(op, node, **kwargs):
return numba_basic.global_numba_func(switch)
def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str): def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
...@@ -197,34 +196,32 @@ def numba_funcify_Cast(op, node, **kwargs): ...@@ -197,34 +196,32 @@ def numba_funcify_Cast(op, node, **kwargs):
return cast return cast
@numba_basic.numba_njit
def identity(x):
return x
@numba_funcify.register(Identity) @numba_funcify.register(Identity)
@numba_funcify.register(TypeCastingOp) @numba_funcify.register(TypeCastingOp)
def numba_funcify_type_casting(op, **kwargs): def numba_funcify_type_casting(op, **kwargs):
return numba_basic.global_numba_func(identity) @numba_basic.numba_njit
def identity(x):
return x
return identity
@numba_basic.numba_njit
def clip(_x, _min, _max):
x = numba_basic.to_scalar(_x)
_min_scalar = numba_basic.to_scalar(_min)
_max_scalar = numba_basic.to_scalar(_max)
if x < _min_scalar: @numba_funcify.register(Clip)
return _min_scalar def numba_funcify_Clip(op, **kwargs):
elif x > _max_scalar: @numba_basic.numba_njit
return _max_scalar def clip(x, min_val, max_val):
x = numba_basic.to_scalar(x)
min_scalar = numba_basic.to_scalar(min_val)
max_scalar = numba_basic.to_scalar(max_val)
if x < min_scalar:
return min_scalar
elif x > max_scalar:
return max_scalar
else: else:
return x return x
return clip
@numba_funcify.register(Clip)
def numba_funcify_Clip(op, **kwargs):
return numba_basic.global_numba_func(clip)
@numba_funcify.register(Composite) @numba_funcify.register(Composite)
...@@ -239,79 +236,72 @@ def numba_funcify_Composite(op, node, **kwargs): ...@@ -239,79 +236,72 @@ def numba_funcify_Composite(op, node, **kwargs):
return composite_fn return composite_fn
@numba_basic.numba_njit
def second(x, y):
return y
@numba_funcify.register(Second) @numba_funcify.register(Second)
def numba_funcify_Second(op, node, **kwargs): def numba_funcify_Second(op, node, **kwargs):
return numba_basic.global_numba_func(second) @numba_basic.numba_njit
def second(x, y):
return y
@numba_basic.numba_njit return second
def reciprocal(x):
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
# `x` is an `int`
return 1 / x
@numba_funcify.register(Reciprocal) @numba_funcify.register(Reciprocal)
def numba_funcify_Reciprocal(op, node, **kwargs): def numba_funcify_Reciprocal(op, node, **kwargs):
return numba_basic.global_numba_func(reciprocal) @numba_basic.numba_njit
def reciprocal(x):
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
# `x` is an `int`
return 1 / x
@numba_basic.numba_njit return reciprocal
def sigmoid(x):
return 1 / (1 + np.exp(-x))
@numba_funcify.register(Sigmoid) @numba_funcify.register(Sigmoid)
def numba_funcify_Sigmoid(op, node, **kwargs): def numba_funcify_Sigmoid(op, node, **kwargs):
return numba_basic.global_numba_func(sigmoid) @numba_basic.numba_njit
def sigmoid(x):
return 1 / (1 + np.exp(-x))
@numba_basic.numba_njit return sigmoid
def gammaln(x):
return math.lgamma(x)
@numba_funcify.register(GammaLn) @numba_funcify.register(GammaLn)
def numba_funcify_GammaLn(op, node, **kwargs): def numba_funcify_GammaLn(op, node, **kwargs):
return numba_basic.global_numba_func(gammaln) @numba_basic.numba_njit
def gammaln(x):
return math.lgamma(x)
return gammaln
@numba_basic.numba_njit @numba_funcify.register(Log1mexp)
def logp1mexp(x): def numba_funcify_Log1mexp(op, node, **kwargs):
@numba_basic.numba_njit
def logp1mexp(x):
if x < np.log(0.5): if x < np.log(0.5):
return np.log1p(-np.exp(x)) return np.log1p(-np.exp(x))
else: else:
return np.log(-np.expm1(x)) return np.log(-np.expm1(x))
return logp1mexp
@numba_funcify.register(Log1mexp)
def numba_funcify_Log1mexp(op, node, **kwargs):
return numba_basic.global_numba_func(logp1mexp)
@numba_basic.numba_njit
def erf(x):
return math.erf(x)
@numba_funcify.register(Erf) @numba_funcify.register(Erf)
def numba_funcify_Erf(op, **kwargs): def numba_funcify_Erf(op, **kwargs):
return numba_basic.global_numba_func(erf) @numba_basic.numba_njit
def erf(x):
return math.erf(x)
@numba_basic.numba_njit return erf
def erfc(x):
return math.erfc(x)
@numba_funcify.register(Erfc) @numba_funcify.register(Erfc)
def numba_funcify_Erfc(op, **kwargs): def numba_funcify_Erfc(op, **kwargs):
return numba_basic.global_numba_func(erfc) @numba_basic.numba_njit
def erfc(x):
return math.erfc(x)
return erfc
@numba_funcify.register(Softplus) @numba_funcify.register(Softplus)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论