提交 2b12a455 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Inline all one-line Numba functions without varargs

上级 e352e04f
......@@ -479,7 +479,7 @@ def {inplace_elemwise_fn_name}({input_signature_str}):
inplace_elemwise_fn = compile_function_src(
inplace_elemwise_src, inplace_elemwise_fn_name, inplace_global_env
)
return numba.njit(inplace_elemwise_fn)
return numba.njit(inline="always")(inplace_elemwise_fn)
return elemwise_fn
......@@ -777,13 +777,13 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
# NumPy scalars, so we need two separate Numba functions for each case.
if node.outputs[0].type.ndim == 0:
# TODO: Do we really need to compile a pass-through function like this?
@numba.njit
@numba.njit(inline="always")
def deepcopyop(x):
return x
else:
@numba.njit
@numba.njit(inline="always")
def deepcopyop(x):
return x.copy()
......@@ -812,7 +812,7 @@ def numba_funcify_MakeVector(op, **kwargs):
@numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs):
@numba.njit
@numba.njit(inline="always")
def shape(x):
return np.asarray(np.shape(x))
......@@ -823,7 +823,7 @@ def numba_funcify_Shape(op, **kwargs):
def numba_funcify_Shape_i(op, **kwargs):
i = op.i
@numba.njit
@numba.njit(inline="always")
def shape_i(x):
return np.shape(x)[i]
......@@ -832,7 +832,7 @@ def numba_funcify_Shape_i(op, **kwargs):
@numba_funcify.register(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs):
@numba.njit
@numba.njit(inline="always")
def tensor_from_scalar(x):
return np.array(x)
......@@ -841,7 +841,7 @@ def numba_funcify_TensorFromScalar(op, **kwargs):
@numba_funcify.register(ScalarFromTensor)
def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba.njit
@numba.njit(inline="always")
def scalar_from_tensor(x):
return x.item()
......@@ -920,7 +920,7 @@ def alloc(val, {", ".join(shape_var_names)}):
def numba_funcify_AllocDiag(op, **kwargs):
offset = op.offset
@numba.njit
@numba.njit(inline="always")
def allocdiag(v):
return np.diag(v, k=offset)
......@@ -929,7 +929,7 @@ def numba_funcify_AllocDiag(op, **kwargs):
@numba_funcify.register(Second)
def numba_funcify_Second(op, node, **kwargs):
@numba.njit
@numba.njit(inline="always")
def second(x, y):
return y
......@@ -962,10 +962,9 @@ def numba_funcify_DimShuffle(op, **kwargs):
# is typed as `getitem(Tuple(), int)`, which has no implementation
# (since getting an item from an empty sequence doesn't make sense).
# To avoid this compile-time error, we omit the expression altogether.
@numba.njit
@numba.njit(inline="always")
def populate_new_shape(i, j, new_shape, shuffle_shape):
new_shape = tuple_setitem(new_shape, i, 1)
return j, new_shape
return j, tuple_setitem(new_shape, i, 1)
@numba.njit
def dimshuffle_inner(x, shuffle):
......@@ -1047,7 +1046,7 @@ def numba_funcify_Cast(op, node, **kwargs):
dtype = np.dtype(op.o_type.dtype)
dtype = numba.np.numpy_support.from_dtype(dtype)
@numba.njit
@numba.njit(inline="always")
def cast(x):
return direct_cast(x, dtype)
......@@ -1058,10 +1057,9 @@ def numba_funcify_Cast(op, node, **kwargs):
def numba_funcify_Reshape(op, **kwargs):
ndim = op.ndim
@numba.njit
@numba.njit(inline="always")
def reshape(x, shape):
new_shape = to_fixed_tuple(shape, ndim)
return np.reshape(x, new_shape)
return np.reshape(x, to_fixed_tuple(shape, ndim))
return reshape
......@@ -1079,7 +1077,7 @@ def numba_funcify_SpecifyShape(op, **kwargs):
@numba_funcify.register(Identity)
@numba_funcify.register(ViewOp)
def numba_funcify_ViewOp(op, **kwargs):
@numba.njit
@numba.njit(inline="always")
def viewop(x):
return x
......@@ -1103,7 +1101,7 @@ def numba_funcify_ARange(op, **kwargs):
dtype = np.dtype(op.dtype)
dtype = numba.np.numpy_support.from_dtype(dtype)
@numba.njit
@numba.njit(inline="always")
def arange(start, stop, step):
return np.arange(
to_scalar(start), to_scalar(stop), to_scalar(step), dtype=dtype
......@@ -1135,7 +1133,7 @@ def numba_funcify_ExtractDiag(op, **kwargs):
# axis1 = op.axis1
# axis2 = op.axis2
@numba.njit
@numba.njit(inline="always")
def extract_diag(x):
return np.diag(x, k=offset)
......@@ -1147,7 +1145,7 @@ def numba_funcify_Eye(op, **kwargs):
dtype = np.dtype(op.dtype)
dtype = numba.np.numpy_support.from_dtype(dtype)
@numba.njit
@numba.njit(inline="always")
def eye(N, M, k):
return np.eye(to_scalar(N), to_scalar(M), to_scalar(k), dtype=dtype)
......@@ -1156,7 +1154,7 @@ def numba_funcify_Eye(op, **kwargs):
@numba_funcify.register(Bartlett)
def numba_funcify_Bartlett(op, **kwargs):
@numba.njit
@numba.njit(inline="always")
def bartlett(x):
return np.bartlett(to_scalar(x))
......@@ -1360,13 +1358,13 @@ def numba_funcify_Repeat(op, node, **kwargs):
if repeats_ndim == 0:
@numba.njit
@numba.njit(inline="always")
def repeatop(x, repeats):
return np.repeat(x, repeats.item())
else:
@numba.njit
@numba.njit(inline="always")
def repeatop(x, repeats):
return np.repeat(x, repeats)
......@@ -1391,7 +1389,7 @@ def numba_funcify_Unique(op, node, **kwargs):
if not use_python:
@numba.njit
@numba.njit(inline="always")
def unique(x):
return np.unique(x)
......@@ -1481,7 +1479,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
else:
@numba.njit
@numba.njit(inline="always")
def searchsorted(a, v):
return np.searchsorted(a, v, side)
......@@ -1516,7 +1514,7 @@ def numba_funcify_Dot(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
@numba.njit(inline="always")
def dot(x, y):
return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)
......@@ -1670,7 +1668,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
@numba.njit(inline="always")
def cholesky(a):
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)
......@@ -1749,7 +1747,7 @@ def numba_funcify_Det(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
@numba.njit(inline="always")
def det(x):
return direct_cast(np.linalg.det(inputs_cast(x)), out_dtype)
......@@ -1800,7 +1798,7 @@ def numba_funcify_Eigh(op, node, **kwargs):
else:
@numba.njit
@numba.njit(inline="always")
def eigh(x):
return np.linalg.eigh(x)
......@@ -1813,7 +1811,7 @@ def numba_funcify_MatrixInverse(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
@numba.njit(inline="always")
def matrix_inverse(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
......@@ -1875,10 +1873,9 @@ def numba_funcify_QRFull(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
@numba.njit(inline="always")
def qr_full(x):
res = np.linalg.qr(inputs_cast(x))
return res
return np.linalg.qr(inputs_cast(x))
return qr_full
......@@ -1911,7 +1908,7 @@ def numba_funcify_SVD(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
@numba.njit(inline="always")
def svd(x):
return np.linalg.svd(inputs_cast(x), full_matrices)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论