提交 d9fe1974 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Use static shape info in numba DimShuffle

上级 db7ae4dc
...@@ -531,7 +531,7 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -531,7 +531,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
@numba_funcify.register(DimShuffle) @numba_funcify.register(DimShuffle)
def numba_funcify_DimShuffle(op, **kwargs): def numba_funcify_DimShuffle(op, node, **kwargs):
shuffle = tuple(op.shuffle) shuffle = tuple(op.shuffle)
transposition = tuple(op.transposition) transposition = tuple(op.transposition)
augment = tuple(op.augment) augment = tuple(op.augment)
...@@ -560,16 +560,26 @@ def numba_funcify_DimShuffle(op, **kwargs): ...@@ -560,16 +560,26 @@ def numba_funcify_DimShuffle(op, **kwargs):
# To avoid this compile-time error, we omit the expression altogether. # To avoid this compile-time error, we omit the expression altogether.
if len(shuffle) > 0: if len(shuffle) > 0:
@numba_basic.numba_njit # Use the statically known shape if available
def find_shape(array_shape): if all(length is not None for length in node.outputs[0].type.shape):
shape = shape_template shape = node.outputs[0].type.shape
j = 0
for i in range(ndim_new_shape): @numba_basic.numba_njit
if i not in augment: def find_shape(array_shape):
length = array_shape[j] return shape
shape = numba_basic.tuple_setitem(shape, i, length)
j = j + 1 else:
return shape
@numba_basic.numba_njit
def find_shape(array_shape):
shape = shape_template
j = 0
for i in range(ndim_new_shape):
if i not in augment:
length = array_shape[j]
shape = numba_basic.tuple_setitem(shape, i, length)
j = j + 1
return shape
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论