提交 ecd6b49c authored 作者: Thomas Wiecki's avatar Thomas Wiecki 提交者: Brandon T. Willard

Implement work-around for numba issue https://github.com/numba/numba/issues/8215…

Implement work-around for numba issue https://github.com/numba/numba/issues/8215 causing a segfault on M1 when using literal_unroll() with bools. Closes #1023.
上级 d09e222b
from textwrap import indent from textwrap import indent
import numba
import numpy as np import numpy as np
from aesara.link.numba.dispatch import basic as numba_basic from aesara.link.numba.dispatch import basic as numba_basic
...@@ -198,11 +197,13 @@ def makevector({", ".join(input_names)}): ...@@ -198,11 +197,13 @@ def makevector({", ".join(input_names)}):
@numba_funcify.register(Rebroadcast) @numba_funcify.register(Rebroadcast)
def numba_funcify_Rebroadcast(op, **kwargs): def numba_funcify_Rebroadcast(op, **kwargs):
op_axis = tuple(op.axis.items()) # Make sure op_axis only has ints. This way we can avoid literal_unroll
# which causes a segfault, see GH issue https://github.com/numba/numba/issues/8215
op_axis = tuple((axis, int(value)) for axis, value in op.axis.items())
@numba_basic.numba_njit @numba_basic.numba_njit
def rebroadcast(x): def rebroadcast(x):
for axis, value in numba.literal_unroll(op_axis): for axis, value in op_axis:
if value and x.shape[axis] != 1: if value and x.shape[axis] != 1:
raise ValueError( raise ValueError(
("Dimension in Rebroadcast's input was supposed to be 1") ("Dimension in Rebroadcast's input was supposed to be 1")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论