提交 b56bff5b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify Numba implementation of Alloc

上级 5f5be921
...@@ -68,7 +68,7 @@ def numba_funcify_Alloc(op, node, **kwargs): ...@@ -68,7 +68,7 @@ def numba_funcify_Alloc(op, node, **kwargs):
shape_var_item_names = [f"{name}_item" for name in shape_var_names] shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent( shapes_to_items_src = indent(
"\n".join( "\n".join(
f"{item_name} = to_scalar({shape_name})" f"{item_name} = {shape_name}.item()"
for item_name, shape_name in zip( for item_name, shape_name in zip(
shape_var_item_names, shape_var_names, strict=True shape_var_item_names, shape_var_names, strict=True
) )
...@@ -86,12 +86,11 @@ def numba_funcify_Alloc(op, node, **kwargs): ...@@ -86,12 +86,11 @@ def numba_funcify_Alloc(op, node, **kwargs):
alloc_def_src = f""" alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}): def alloc(val, {", ".join(shape_var_names)}):
val_np = np.asarray(val)
{shapes_to_items_src} {shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)} scalar_shape = {create_tuple_string(shape_var_item_names)}
{check_runtime_broadcast_src} {check_runtime_broadcast_src}
res = np.empty(scalar_shape, dtype=val_np.dtype) res = np.empty(scalar_shape, dtype=val.dtype)
res[...] = val_np res[...] = val
return res return res
""" """
alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env}) alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论