提交 d9fe1974 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Use static shape info in numba DimShuffle

上级 db7ae4dc
......@@ -531,7 +531,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
@numba_funcify.register(DimShuffle)
def numba_funcify_DimShuffle(op, **kwargs):
def numba_funcify_DimShuffle(op, node, **kwargs):
shuffle = tuple(op.shuffle)
transposition = tuple(op.transposition)
augment = tuple(op.augment)
......@@ -560,16 +560,26 @@ def numba_funcify_DimShuffle(op, **kwargs):
# To avoid this compile-time error, we omit the expression altogether.
if len(shuffle) > 0:
@numba_basic.numba_njit
def find_shape(array_shape):
shape = shape_template
j = 0
for i in range(ndim_new_shape):
if i not in augment:
length = array_shape[j]
shape = numba_basic.tuple_setitem(shape, i, length)
j = j + 1
return shape
# Use the statically known shape if available
if all(length is not None for length in node.outputs[0].type.shape):
shape = node.outputs[0].type.shape
@numba_basic.numba_njit
def find_shape(array_shape):
return shape
else:
@numba_basic.numba_njit
def find_shape(array_shape):
shape = shape_template
j = 0
for i in range(ndim_new_shape):
if i not in augment:
length = array_shape[j]
shape = numba_basic.tuple_setitem(shape, i, length)
j = j + 1
return shape
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论