提交 d9fe1974 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Use static shape info in numba DimShuffle

上级 db7ae4dc
...@@ -531,7 +531,7 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -531,7 +531,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
@numba_funcify.register(DimShuffle) @numba_funcify.register(DimShuffle)
def numba_funcify_DimShuffle(op, **kwargs): def numba_funcify_DimShuffle(op, node, **kwargs):
shuffle = tuple(op.shuffle) shuffle = tuple(op.shuffle)
transposition = tuple(op.transposition) transposition = tuple(op.transposition)
augment = tuple(op.augment) augment = tuple(op.augment)
...@@ -560,6 +560,16 @@ def numba_funcify_DimShuffle(op, **kwargs): ...@@ -560,6 +560,16 @@ def numba_funcify_DimShuffle(op, **kwargs):
# To avoid this compile-time error, we omit the expression altogether. # To avoid this compile-time error, we omit the expression altogether.
if len(shuffle) > 0: if len(shuffle) > 0:
# Use the statically known shape if available
if all(length is not None for length in node.outputs[0].type.shape):
shape = node.outputs[0].type.shape
@numba_basic.numba_njit
def find_shape(array_shape):
return shape
else:
@numba_basic.numba_njit @numba_basic.numba_njit
def find_shape(array_shape): def find_shape(array_shape):
shape = shape_template shape = shape_template
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论