提交 67a5eba9 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Improve numba DimShuffle compile time

上级 1827703c
...@@ -539,44 +539,54 @@ def numba_funcify_DimShuffle(op, **kwargs): ...@@ -539,44 +539,54 @@ def numba_funcify_DimShuffle(op, **kwargs):
ndim_new_shape = len(shuffle) + len(augment) ndim_new_shape = len(shuffle) + len(augment)
if len(shuffle) > 0: no_transpose = all(i == j for i, j in enumerate(transposition))
if no_transpose:
@numba_basic.numba_njit @numba_basic.numba_njit
def populate_new_shape(i, j, new_shape, shuffle_shape): def transpose(x):
if i in augment: return x
new_shape = numba_basic.tuple_setitem(new_shape, i, 1)
return j, new_shape
else:
new_shape = numba_basic.tuple_setitem(new_shape, i, shuffle_shape[j])
return j + 1, new_shape
else: else:
@numba_basic.numba_njit
def transpose(x):
return np.transpose(x, transposition)
shape_template = (1,) * ndim_new_shape
# When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above is # When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above is
# is typed as `getitem(Tuple(), int)`, which has no implementation # is typed as `getitem(Tuple(), int)`, which has no implementation
# (since getting an item from an empty sequence doesn't make sense). # (since getting an item from an empty sequence doesn't make sense).
# To avoid this compile-time error, we omit the expression altogether. # To avoid this compile-time error, we omit the expression altogether.
@numba_basic.numba_njit(inline="always") if len(shuffle) > 0:
def populate_new_shape(i, j, new_shape, shuffle_shape):
return j, numba_basic.tuple_setitem(new_shape, i, 1)
if ndim_new_shape > 0: @numba_basic.numba_njit
create_zeros_tuple = numba_basic.create_tuple_creator( def find_shape(array_shape):
lambda _: 0, ndim_new_shape shape = shape_template
) j = 0
for i in range(ndim_new_shape):
if i not in augment:
length = array_shape[j]
shape = numba_basic.tuple_setitem(shape, i, length)
j = j + 1
return shape
else:
@numba_basic.numba_njit @numba_basic.numba_njit
def dimshuffle_inner(x, shuffle): def find_shape(array_shape):
res = np.transpose(x, transposition) return shape_template
shuffle_shape = res.shape[: len(shuffle)]
new_shape = create_zeros_tuple() if ndim_new_shape > 0:
j = 0 @numba_basic.numba_njit
for i in range(len(new_shape)): def dimshuffle_inner(x, shuffle):
j, new_shape = populate_new_shape(i, j, new_shape, shuffle_shape) x = transpose(x)
shuffle_shape = x.shape[: len(shuffle)]
new_shape = find_shape(shuffle_shape)
# FIXME: Numba's `array.reshape` only accepts C arrays. # FIXME: Numba's `array.reshape` only accepts C arrays.
res_reshape = np.reshape(np.ascontiguousarray(res), new_shape) res_reshape = np.reshape(np.ascontiguousarray(x), new_shape)
if not inplace: if not inplace:
return res_reshape.copy() return res_reshape.copy()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论