提交 58046078 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Brandon T. Willard

Fix typing error in numba MakeVector impl

上级 9665120e
......@@ -184,7 +184,7 @@ def numba_funcify_Eye(op, **kwargs):
def numba_funcify_MakeVector(op, node, **kwargs):
dtype = np.dtype(op.dtype)
global_env = {"np": np, "to_scalar": numba_basic.to_scalar}
global_env = {"np": np, "to_scalar": numba_basic.to_scalar, "dtype": dtype}
unique_names = unique_name_generator(
["np", "to_scalar"],
......@@ -198,7 +198,7 @@ def numba_funcify_MakeVector(op, node, **kwargs):
makevector_def_src = f"""
def makevector({", ".join(input_names)}):
return np.array({create_list_string(input_names)}, dtype=np.{dtype})
return np.array({create_list_string(input_names)}, dtype=dtype)
"""
makevector_fn = compile_function_src(
......
......@@ -998,6 +998,10 @@ def test_scalar_Elemwise_Clip():
(set_test_value(at.iscalar(), np.array(1, dtype=np.int32)),),
"float64",
),
(
(set_test_value(at.scalar(dtype=bool), True),),
bool,
),
],
)
def test_MakeVector(vals, dtype):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论