提交 0b8d47fe authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Move Numba RandomVariable Generator error to make_numba_random_fn

上级 e8d5e510
......@@ -96,6 +96,8 @@ def make_numba_random_fn(node, np_random_func):
The functions generated here add parameter broadcasting and the ``size``
argument to the Numba-supported scalar ``np.random`` functions.
"""
if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
raise TypeError("Numba does not support NumPy `Generator`s")
tuple_size = int(get_vector_length(node.inputs[1]))
size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])
......@@ -215,9 +217,6 @@ def numba_funcify_RandomVariable(op, node, **kwargs):
name = op.name
np_random_func = getattr(np.random, name)
if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
raise TypeError("Numba does not support NumPy `Generator`s")
return make_numba_random_fn(node, np_random_func)
......@@ -271,9 +270,6 @@ def {np_random_fn_name}({np_input_names}):
@numba_funcify.register(aer.NegBinomialRV)
def numba_funcify_NegBinomialRV(op, node, **kwargs):
if not isinstance(node.inputs[0], (RandomStateType, RandomStateSharedVariable)):
raise TypeError("Numba does not support NumPy `Generator`s")
return make_numba_random_fn(node, np.random.negative_binomial)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论