提交 2b12a455 作者: Brandon T. Willard 提交者: Brandon T. Willard

Inline all one-line Numba functions without varargs

上级 e352e04f
...@@ -479,7 +479,7 @@ def {inplace_elemwise_fn_name}({input_signature_str}): ...@@ -479,7 +479,7 @@ def {inplace_elemwise_fn_name}({input_signature_str}):
inplace_elemwise_fn = compile_function_src( inplace_elemwise_fn = compile_function_src(
inplace_elemwise_src, inplace_elemwise_fn_name, inplace_global_env inplace_elemwise_src, inplace_elemwise_fn_name, inplace_global_env
) )
return numba.njit(inplace_elemwise_fn) return numba.njit(inline="always")(inplace_elemwise_fn)
return elemwise_fn return elemwise_fn
...@@ -777,13 +777,13 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs): ...@@ -777,13 +777,13 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
# NumPy scalars, so we need two separate Numba functions for each case. # NumPy scalars, so we need two separate Numba functions for each case.
if node.outputs[0].type.ndim == 0: if node.outputs[0].type.ndim == 0:
# TODO: Do we really need to compile a pass-through function like this? # TODO: Do we really need to compile a pass-through function like this?
@numba.njit @numba.njit(inline="always")
def deepcopyop(x): def deepcopyop(x):
return x return x
else: else:
@numba.njit @numba.njit(inline="always")
def deepcopyop(x): def deepcopyop(x):
return x.copy() return x.copy()
...@@ -812,7 +812,7 @@ def numba_funcify_MakeVector(op, **kwargs): ...@@ -812,7 +812,7 @@ def numba_funcify_MakeVector(op, **kwargs):
@numba_funcify.register(Shape) @numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs): def numba_funcify_Shape(op, **kwargs):
@numba.njit @numba.njit(inline="always")
def shape(x): def shape(x):
return np.asarray(np.shape(x)) return np.asarray(np.shape(x))
...@@ -823,7 +823,7 @@ def numba_funcify_Shape(op, **kwargs): ...@@ -823,7 +823,7 @@ def numba_funcify_Shape(op, **kwargs):
def numba_funcify_Shape_i(op, **kwargs): def numba_funcify_Shape_i(op, **kwargs):
i = op.i i = op.i
@numba.njit @numba.njit(inline="always")
def shape_i(x): def shape_i(x):
return np.shape(x)[i] return np.shape(x)[i]
...@@ -832,7 +832,7 @@ def numba_funcify_Shape_i(op, **kwargs): ...@@ -832,7 +832,7 @@ def numba_funcify_Shape_i(op, **kwargs):
@numba_funcify.register(TensorFromScalar) @numba_funcify.register(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs): def numba_funcify_TensorFromScalar(op, **kwargs):
@numba.njit @numba.njit(inline="always")
def tensor_from_scalar(x): def tensor_from_scalar(x):
return np.array(x) return np.array(x)
...@@ -841,7 +841,7 @@ def numba_funcify_TensorFromScalar(op, **kwargs): ...@@ -841,7 +841,7 @@ def numba_funcify_TensorFromScalar(op, **kwargs):
@numba_funcify.register(ScalarFromTensor) @numba_funcify.register(ScalarFromTensor)
def numba_funcify_ScalarFromTensor(op, **kwargs): def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba.njit @numba.njit(inline="always")
def scalar_from_tensor(x): def scalar_from_tensor(x):
return x.item() return x.item()
...@@ -920,7 +920,7 @@ def alloc(val, {", ".join(shape_var_names)}): ...@@ -920,7 +920,7 @@ def alloc(val, {", ".join(shape_var_names)}):
def numba_funcify_AllocDiag(op, **kwargs): def numba_funcify_AllocDiag(op, **kwargs):
offset = op.offset offset = op.offset
@numba.njit @numba.njit(inline="always")
def allocdiag(v): def allocdiag(v):
return np.diag(v, k=offset) return np.diag(v, k=offset)
...@@ -929,7 +929,7 @@ def numba_funcify_AllocDiag(op, **kwargs): ...@@ -929,7 +929,7 @@ def numba_funcify_AllocDiag(op, **kwargs):
@numba_funcify.register(Second) @numba_funcify.register(Second)
def numba_funcify_Second(op, node, **kwargs): def numba_funcify_Second(op, node, **kwargs):
@numba.njit @numba.njit(inline="always")
def second(x, y): def second(x, y):
return y return y
...@@ -962,10 +962,9 @@ def numba_funcify_DimShuffle(op, **kwargs): ...@@ -962,10 +962,9 @@ def numba_funcify_DimShuffle(op, **kwargs):
# is typed as `getitem(Tuple(), int)`, which has no implementation # is typed as `getitem(Tuple(), int)`, which has no implementation
# (since getting an item from an empty sequence doesn't make sense). # (since getting an item from an empty sequence doesn't make sense).
# To avoid this compile-time error, we omit the expression altogether. # To avoid this compile-time error, we omit the expression altogether.
@numba.njit @numba.njit(inline="always")
def populate_new_shape(i, j, new_shape, shuffle_shape): def populate_new_shape(i, j, new_shape, shuffle_shape):
new_shape = tuple_setitem(new_shape, i, 1) return j, tuple_setitem(new_shape, i, 1)
return j, new_shape
@numba.njit @numba.njit
def dimshuffle_inner(x, shuffle): def dimshuffle_inner(x, shuffle):
...@@ -1047,7 +1046,7 @@ def numba_funcify_Cast(op, node, **kwargs): ...@@ -1047,7 +1046,7 @@ def numba_funcify_Cast(op, node, **kwargs):
dtype = np.dtype(op.o_type.dtype) dtype = np.dtype(op.o_type.dtype)
dtype = numba.np.numpy_support.from_dtype(dtype) dtype = numba.np.numpy_support.from_dtype(dtype)
@numba.njit @numba.njit(inline="always")
def cast(x): def cast(x):
return direct_cast(x, dtype) return direct_cast(x, dtype)
...@@ -1058,10 +1057,9 @@ def numba_funcify_Cast(op, node, **kwargs): ...@@ -1058,10 +1057,9 @@ def numba_funcify_Cast(op, node, **kwargs):
def numba_funcify_Reshape(op, **kwargs): def numba_funcify_Reshape(op, **kwargs):
ndim = op.ndim ndim = op.ndim
@numba.njit @numba.njit(inline="always")
def reshape(x, shape): def reshape(x, shape):
new_shape = to_fixed_tuple(shape, ndim) return np.reshape(x, to_fixed_tuple(shape, ndim))
return np.reshape(x, new_shape)
return reshape return reshape
...@@ -1079,7 +1077,7 @@ def numba_funcify_SpecifyShape(op, **kwargs): ...@@ -1079,7 +1077,7 @@ def numba_funcify_SpecifyShape(op, **kwargs):
@numba_funcify.register(Identity) @numba_funcify.register(Identity)
@numba_funcify.register(ViewOp) @numba_funcify.register(ViewOp)
def numba_funcify_ViewOp(op, **kwargs): def numba_funcify_ViewOp(op, **kwargs):
@numba.njit @numba.njit(inline="always")
def viewop(x): def viewop(x):
return x return x
...@@ -1103,7 +1101,7 @@ def numba_funcify_ARange(op, **kwargs): ...@@ -1103,7 +1101,7 @@ def numba_funcify_ARange(op, **kwargs):
dtype = np.dtype(op.dtype) dtype = np.dtype(op.dtype)
dtype = numba.np.numpy_support.from_dtype(dtype) dtype = numba.np.numpy_support.from_dtype(dtype)
@numba.njit @numba.njit(inline="always")
def arange(start, stop, step): def arange(start, stop, step):
return np.arange( return np.arange(
to_scalar(start), to_scalar(stop), to_scalar(step), dtype=dtype to_scalar(start), to_scalar(stop), to_scalar(step), dtype=dtype
...@@ -1135,7 +1133,7 @@ def numba_funcify_ExtractDiag(op, **kwargs): ...@@ -1135,7 +1133,7 @@ def numba_funcify_ExtractDiag(op, **kwargs):
# axis1 = op.axis1 # axis1 = op.axis1
# axis2 = op.axis2 # axis2 = op.axis2
@numba.njit @numba.njit(inline="always")
def extract_diag(x): def extract_diag(x):
return np.diag(x, k=offset) return np.diag(x, k=offset)
...@@ -1147,7 +1145,7 @@ def numba_funcify_Eye(op, **kwargs): ...@@ -1147,7 +1145,7 @@ def numba_funcify_Eye(op, **kwargs):
dtype = np.dtype(op.dtype) dtype = np.dtype(op.dtype)
dtype = numba.np.numpy_support.from_dtype(dtype) dtype = numba.np.numpy_support.from_dtype(dtype)
@numba.njit @numba.njit(inline="always")
def eye(N, M, k): def eye(N, M, k):
return np.eye(to_scalar(N), to_scalar(M), to_scalar(k), dtype=dtype) return np.eye(to_scalar(N), to_scalar(M), to_scalar(k), dtype=dtype)
...@@ -1156,7 +1154,7 @@ def numba_funcify_Eye(op, **kwargs): ...@@ -1156,7 +1154,7 @@ def numba_funcify_Eye(op, **kwargs):
@numba_funcify.register(Bartlett) @numba_funcify.register(Bartlett)
def numba_funcify_Bartlett(op, **kwargs): def numba_funcify_Bartlett(op, **kwargs):
@numba.njit @numba.njit(inline="always")
def bartlett(x): def bartlett(x):
return np.bartlett(to_scalar(x)) return np.bartlett(to_scalar(x))
...@@ -1360,13 +1358,13 @@ def numba_funcify_Repeat(op, node, **kwargs): ...@@ -1360,13 +1358,13 @@ def numba_funcify_Repeat(op, node, **kwargs):
if repeats_ndim == 0: if repeats_ndim == 0:
@numba.njit @numba.njit(inline="always")
def repeatop(x, repeats): def repeatop(x, repeats):
return np.repeat(x, repeats.item()) return np.repeat(x, repeats.item())
else: else:
@numba.njit @numba.njit(inline="always")
def repeatop(x, repeats): def repeatop(x, repeats):
return np.repeat(x, repeats) return np.repeat(x, repeats)
...@@ -1391,7 +1389,7 @@ def numba_funcify_Unique(op, node, **kwargs): ...@@ -1391,7 +1389,7 @@ def numba_funcify_Unique(op, node, **kwargs):
if not use_python: if not use_python:
@numba.njit @numba.njit(inline="always")
def unique(x): def unique(x):
return np.unique(x) return np.unique(x)
...@@ -1481,7 +1479,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs): ...@@ -1481,7 +1479,7 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
else: else:
@numba.njit @numba.njit(inline="always")
def searchsorted(a, v): def searchsorted(a, v):
return np.searchsorted(a, v, side) return np.searchsorted(a, v, side)
...@@ -1516,7 +1514,7 @@ def numba_funcify_Dot(op, node, **kwargs): ...@@ -1516,7 +1514,7 @@ def numba_funcify_Dot(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype) inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit @numba.njit(inline="always")
def dot(x, y): def dot(x, y):
return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype) return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)
...@@ -1670,7 +1668,7 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -1670,7 +1668,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
inputs_cast = int_to_float_fn(node.inputs, out_dtype) inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit @numba.njit(inline="always")
def cholesky(a): def cholesky(a):
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype) return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)
...@@ -1749,7 +1747,7 @@ def numba_funcify_Det(op, node, **kwargs): ...@@ -1749,7 +1747,7 @@ def numba_funcify_Det(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype) inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit @numba.njit(inline="always")
def det(x): def det(x):
return direct_cast(np.linalg.det(inputs_cast(x)), out_dtype) return direct_cast(np.linalg.det(inputs_cast(x)), out_dtype)
...@@ -1800,7 +1798,7 @@ def numba_funcify_Eigh(op, node, **kwargs): ...@@ -1800,7 +1798,7 @@ def numba_funcify_Eigh(op, node, **kwargs):
else: else:
@numba.njit @numba.njit(inline="always")
def eigh(x): def eigh(x):
return np.linalg.eigh(x) return np.linalg.eigh(x)
...@@ -1813,7 +1811,7 @@ def numba_funcify_MatrixInverse(op, node, **kwargs): ...@@ -1813,7 +1811,7 @@ def numba_funcify_MatrixInverse(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype) inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit @numba.njit(inline="always")
def matrix_inverse(x): def matrix_inverse(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype) return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
...@@ -1875,10 +1873,9 @@ def numba_funcify_QRFull(op, node, **kwargs): ...@@ -1875,10 +1873,9 @@ def numba_funcify_QRFull(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype) inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit @numba.njit(inline="always")
def qr_full(x): def qr_full(x):
res = np.linalg.qr(inputs_cast(x)) return np.linalg.qr(inputs_cast(x))
return res
return qr_full return qr_full
...@@ -1911,7 +1908,7 @@ def numba_funcify_SVD(op, node, **kwargs): ...@@ -1911,7 +1908,7 @@ def numba_funcify_SVD(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype) inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit @numba.njit(inline="always")
def svd(x): def svd(x):
return np.linalg.svd(inputs_cast(x), full_matrices) return np.linalg.svd(inputs_cast(x), full_matrices)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论