提交 29032f34 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make Numba conversion of MakeVector work with mixed input types

上级 2dbeb781
...@@ -167,14 +167,29 @@ def numba_funcify_Eye(op, **kwargs): ...@@ -167,14 +167,29 @@ def numba_funcify_Eye(op, **kwargs):
@numba_funcify.register(MakeVector) @numba_funcify.register(MakeVector)
def numba_funcify_MakeVector(op, **kwargs): def numba_funcify_MakeVector(op, node, **kwargs):
dtype = np.dtype(op.dtype) dtype = np.dtype(op.dtype)
@numba.njit global_env = {"np": np, "to_scalar": numba_basic.to_scalar}
def makevector(*args):
return np.array([a.item() for a in args], dtype=dtype) unique_names = unique_name_generator(
["np", "to_scalar"],
suffix_sep="_",
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
def create_list_string(x):
args = ", ".join([f"to_scalar({i})" for i in x] + ([""] if len(x) == 1 else []))
return f"[{args}]"
makevector_def_src = f"""
def makevector({", ".join(input_names)}):
return np.array({create_list_string(input_names)}, dtype=np.{dtype})
"""
makevector_fn = compile_function_src(makevector_def_src, "makevector", global_env)
return makevector return numba.njit(makevector_fn)
@numba_funcify.register(Rebroadcast) @numba_funcify.register(Rebroadcast)
......
...@@ -957,6 +957,17 @@ def test_scalar_Elemwise_Clip(): ...@@ -957,6 +957,17 @@ def test_scalar_Elemwise_Clip():
), ),
config.floatX, config.floatX,
), ),
(
(
set_test_value(aet.dscalar(), np.array(1, dtype=np.float64)),
set_test_value(aet.lscalar(), np.array(3, dtype=np.int32)),
),
"float64",
),
(
(set_test_value(aet.iscalar(), np.array(1, dtype=np.int32)),),
"float64",
),
], ],
) )
def test_MakeVector(vals, dtype): def test_MakeVector(vals, dtype):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论