提交 f8771c13 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix make_numba_random_fn RandomStateType check

上级 db4e1f3a
......@@ -20,7 +20,6 @@ from aesara.link.utils import (
)
from aesara.tensor.basic import get_vector_length
from aesara.tensor.random.type import RandomStateType
from aesara.tensor.random.var import RandomStateSharedVariable
class RandomStateNumbaType(types.Type):
......@@ -96,7 +95,7 @@ def make_numba_random_fn(node, np_random_func):
The functions generated here add parameter broadcasting and the ``size``
argument to the Numba-supported scalar ``np.random`` functions.
"""
if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
if not isinstance(node.inputs[0].type, RandomStateType):
raise TypeError("Numba does not support NumPy `Generator`s")
tuple_size = int(get_vector_length(node.inputs[1]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论