Unverified 提交 b4522d23 authored 作者: Pablo de Roque's avatar Pablo de Roque 提交者: GitHub

Remove uses of `numba_basic.global_numba_func`

上级 21218d77
......@@ -402,24 +402,22 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
return deepcopyop
@numba_njit
def makeslice(*x):
return slice(*x)
@numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
return global_numba_func(makeslice)
@numba_njit
def makeslice(*x):
return slice(*x)
@numba_njit
def shape(x):
return np.asarray(np.shape(x))
return makeslice
@numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs):
return global_numba_func(shape)
@numba_njit
def shape(x):
return np.asarray(np.shape(x))
return shape
@numba_funcify.register(Shape_i)
......
......@@ -141,17 +141,16 @@ def {scalar_op_fn_name}({', '.join(input_names)}):
)(scalar_op_fn)
@numba_basic.numba_njit
def switch(condition, x, y):
@numba_funcify.register(Switch)
def numba_funcify_Switch(op, node, **kwargs):
@numba_basic.numba_njit
def switch(condition, x, y):
if condition:
return x
else:
return y
@numba_funcify.register(Switch)
def numba_funcify_Switch(op, node, **kwargs):
return numba_basic.global_numba_func(switch)
return switch
def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
......@@ -197,34 +196,32 @@ def numba_funcify_Cast(op, node, **kwargs):
return cast
@numba_basic.numba_njit
def identity(x):
return x
@numba_funcify.register(Identity)
@numba_funcify.register(TypeCastingOp)
def numba_funcify_type_casting(op, **kwargs):
return numba_basic.global_numba_func(identity)
@numba_basic.numba_njit
def identity(x):
return x
return identity
@numba_basic.numba_njit
def clip(_x, _min, _max):
x = numba_basic.to_scalar(_x)
_min_scalar = numba_basic.to_scalar(_min)
_max_scalar = numba_basic.to_scalar(_max)
if x < _min_scalar:
return _min_scalar
elif x > _max_scalar:
return _max_scalar
@numba_funcify.register(Clip)
def numba_funcify_Clip(op, **kwargs):
@numba_basic.numba_njit
def clip(x, min_val, max_val):
x = numba_basic.to_scalar(x)
min_scalar = numba_basic.to_scalar(min_val)
max_scalar = numba_basic.to_scalar(max_val)
if x < min_scalar:
return min_scalar
elif x > max_scalar:
return max_scalar
else:
return x
@numba_funcify.register(Clip)
def numba_funcify_Clip(op, **kwargs):
return numba_basic.global_numba_func(clip)
return clip
@numba_funcify.register(Composite)
......@@ -239,79 +236,72 @@ def numba_funcify_Composite(op, node, **kwargs):
return composite_fn
@numba_basic.numba_njit
def second(x, y):
return y
@numba_funcify.register(Second)
def numba_funcify_Second(op, node, **kwargs):
return numba_basic.global_numba_func(second)
@numba_basic.numba_njit
def second(x, y):
return y
@numba_basic.numba_njit
def reciprocal(x):
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
# `x` is an `int`
return 1 / x
return second
@numba_funcify.register(Reciprocal)
def numba_funcify_Reciprocal(op, node, **kwargs):
return numba_basic.global_numba_func(reciprocal)
@numba_basic.numba_njit
def reciprocal(x):
# TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when
# `x` is an `int`
return 1 / x
@numba_basic.numba_njit
def sigmoid(x):
return 1 / (1 + np.exp(-x))
return reciprocal
@numba_funcify.register(Sigmoid)
def numba_funcify_Sigmoid(op, node, **kwargs):
return numba_basic.global_numba_func(sigmoid)
@numba_basic.numba_njit
def sigmoid(x):
return 1 / (1 + np.exp(-x))
@numba_basic.numba_njit
def gammaln(x):
return math.lgamma(x)
return sigmoid
@numba_funcify.register(GammaLn)
def numba_funcify_GammaLn(op, node, **kwargs):
return numba_basic.global_numba_func(gammaln)
@numba_basic.numba_njit
def gammaln(x):
return math.lgamma(x)
return gammaln
@numba_basic.numba_njit
def logp1mexp(x):
@numba_funcify.register(Log1mexp)
def numba_funcify_Log1mexp(op, node, **kwargs):
@numba_basic.numba_njit
def logp1mexp(x):
if x < np.log(0.5):
return np.log1p(-np.exp(x))
else:
return np.log(-np.expm1(x))
@numba_funcify.register(Log1mexp)
def numba_funcify_Log1mexp(op, node, **kwargs):
return numba_basic.global_numba_func(logp1mexp)
@numba_basic.numba_njit
def erf(x):
return math.erf(x)
return logp1mexp
@numba_funcify.register(Erf)
def numba_funcify_Erf(op, **kwargs):
return numba_basic.global_numba_func(erf)
@numba_basic.numba_njit
def erf(x):
return math.erf(x)
@numba_basic.numba_njit
def erfc(x):
return math.erfc(x)
return erf
@numba_funcify.register(Erfc)
def numba_funcify_Erfc(op, **kwargs):
return numba_basic.global_numba_func(erfc)
@numba_basic.numba_njit
def erfc(x):
return math.erfc(x)
return erfc
@numba_funcify.register(Softplus)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论