提交 58046078 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Brandon T. Willard

Fix typing error in numba MakeVector impl

上级 9665120e
...@@ -184,7 +184,7 @@ def numba_funcify_Eye(op, **kwargs): ...@@ -184,7 +184,7 @@ def numba_funcify_Eye(op, **kwargs):
def numba_funcify_MakeVector(op, node, **kwargs): def numba_funcify_MakeVector(op, node, **kwargs):
dtype = np.dtype(op.dtype) dtype = np.dtype(op.dtype)
global_env = {"np": np, "to_scalar": numba_basic.to_scalar} global_env = {"np": np, "to_scalar": numba_basic.to_scalar, "dtype": dtype}
unique_names = unique_name_generator( unique_names = unique_name_generator(
["np", "to_scalar"], ["np", "to_scalar"],
...@@ -198,7 +198,7 @@ def numba_funcify_MakeVector(op, node, **kwargs): ...@@ -198,7 +198,7 @@ def numba_funcify_MakeVector(op, node, **kwargs):
makevector_def_src = f""" makevector_def_src = f"""
def makevector({", ".join(input_names)}): def makevector({", ".join(input_names)}):
return np.array({create_list_string(input_names)}, dtype=np.{dtype}) return np.array({create_list_string(input_names)}, dtype=dtype)
""" """
makevector_fn = compile_function_src( makevector_fn = compile_function_src(
......
...@@ -998,6 +998,10 @@ def test_scalar_Elemwise_Clip(): ...@@ -998,6 +998,10 @@ def test_scalar_Elemwise_Clip():
(set_test_value(at.iscalar(), np.array(1, dtype=np.int32)),), (set_test_value(at.iscalar(), np.array(1, dtype=np.int32)),),
"float64", "float64",
), ),
(
(set_test_value(at.scalar(dtype=bool), True),),
bool,
),
], ],
) )
def test_MakeVector(vals, dtype): def test_MakeVector(vals, dtype):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论