提交 31b77a2e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Do not always remap storage in fgraph_to_python

上级 643c9734
...@@ -377,6 +377,9 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): ...@@ -377,6 +377,9 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
@numba_funcify.register(OpFromGraph) @numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs): def numba_funcify_OpFromGraph(op, node=None, **kwargs):
_ = kwargs.pop("storage_map", None)
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs)) fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
if len(op.fgraph.outputs) == 1: if len(op.fgraph.outputs) == 1:
......
...@@ -221,6 +221,9 @@ def numba_funcify_Clip(op, **kwargs): ...@@ -221,6 +221,9 @@ def numba_funcify_Clip(op, **kwargs):
@numba_funcify.register(Composite) @numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs): def numba_funcify_Composite(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True) signature = create_numba_signature(node, force_scalar=True)
_ = kwargs.pop("storage_map", None)
composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs) numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
) )
......
...@@ -678,8 +678,6 @@ def fgraph_to_python( ...@@ -678,8 +678,6 @@ def fgraph_to_python(
*, *,
type_conversion_fn: Callable = lambda x, **kwargs: x, type_conversion_fn: Callable = lambda x, **kwargs: x,
order: Optional[List[Apply]] = None, order: Optional[List[Apply]] = None,
input_storage: Optional["InputStorageType"] = None,
output_storage: Optional["OutputStorageType"] = None,
storage_map: Optional["StorageMapType"] = None, storage_map: Optional["StorageMapType"] = None,
fgraph_name: str = "fgraph_to_python", fgraph_name: str = "fgraph_to_python",
global_env: Optional[Dict[Any, Any]] = None, global_env: Optional[Dict[Any, Any]] = None,
...@@ -704,10 +702,6 @@ def fgraph_to_python( ...@@ -704,10 +702,6 @@ def fgraph_to_python(
``(value: Optional[Any], variable: Variable=None, storage: List[Optional[Any]]=None, **kwargs)``. ``(value: Optional[Any], variable: Variable=None, storage: List[Optional[Any]]=None, **kwargs)``.
order order
The `order` argument to `map_storage`. The `order` argument to `map_storage`.
input_storage
The `input_storage` argument to `map_storage`.
output_storage
The `output_storage` argument to `map_storage`.
storage_map storage_map
The `storage_map` argument to `map_storage`. The `storage_map` argument to `map_storage`.
fgraph_name fgraph_name
...@@ -730,9 +724,9 @@ def fgraph_to_python( ...@@ -730,9 +724,9 @@ def fgraph_to_python(
if order is None: if order is None:
order = fgraph.toposort() order = fgraph.toposort()
input_storage, output_storage, storage_map = map_storage(
fgraph, order, input_storage, output_storage, storage_map if storage_map is None:
) storage_map = {}
unique_name = unique_name_generator([fgraph_name]) unique_name = unique_name_generator([fgraph_name])
...@@ -752,10 +746,13 @@ def fgraph_to_python( ...@@ -752,10 +746,13 @@ def fgraph_to_python(
node_input_names = [] node_input_names = []
for i in node.inputs: for i in node.inputs:
local_input_name = unique_name(i) local_input_name = unique_name(i)
if storage_map[i][0] is not None or isinstance(i, Constant): input_storage = storage_map.setdefault(
i, [None if not isinstance(i, Constant) else i.data]
)
if input_storage[0] is not None or isinstance(i, Constant):
# Constants need to be assigned locally and referenced # Constants need to be assigned locally and referenced
global_env[local_input_name] = type_conversion_fn( global_env[local_input_name] = type_conversion_fn(
storage_map[i][0], variable=i, storage=storage_map[i], **kwargs input_storage[0], variable=i, storage=input_storage, **kwargs
) )
# TODO: We could attempt to use the storage arrays directly # TODO: We could attempt to use the storage arrays directly
# E.g. `local_input_name = f"{local_input_name}[0]"` # E.g. `local_input_name = f"{local_input_name}[0]"`
...@@ -763,20 +760,24 @@ def fgraph_to_python( ...@@ -763,20 +760,24 @@ def fgraph_to_python(
node_output_names = [unique_name(v) for v in node.outputs] node_output_names = [unique_name(v) for v in node.outputs]
assign_comment_str = f"{indent(str(node), '# ')}"
assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})" assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
body_assigns.append(f"{assign_comment_str}\n{assign_str}") assign_comment_str = f"{indent(str(node), '# ')}"
assign_block_str = f"{assign_comment_str}\n{assign_str}"
body_assigns.append(assign_block_str)
# Handle `Constant`-only outputs (these don't have associated `Apply` # Handle `Constant`-only outputs (these don't have associated `Apply`
# nodes, so the above isn't applicable) # nodes, so the above isn't applicable)
for out in fgraph.outputs: for out in fgraph.outputs:
if isinstance(out, Constant): if isinstance(out, Constant):
local_input_name = unique_name(out) local_output_name = unique_name(out)
if local_input_name not in global_env: if local_output_name not in global_env:
global_env[local_input_name] = type_conversion_fn( output_storage = storage_map.setdefault(
storage_map[out][0], out, [None if not isinstance(out, Constant) else out.data]
)
global_env[local_output_name] = type_conversion_fn(
output_storage[0],
variable=out, variable=out,
storage=storage_map[out], storage=output_storage,
**kwargs, **kwargs,
) )
...@@ -794,7 +795,7 @@ def fgraph_to_python( ...@@ -794,7 +795,7 @@ def fgraph_to_python(
fgraph_def_src = dedent( fgraph_def_src = dedent(
f""" f"""
def {fgraph_name}({", ".join(fgraph_input_names)}): def {fgraph_name}({", ".join(fgraph_input_names)}):
{indent(joined_body_assigns, " " * 4)} {indent(joined_body_assigns, " " * 4)}
return {fgraph_return_src} return {fgraph_return_src}
""" """
).strip() ).strip()
......
...@@ -176,6 +176,25 @@ def test_fgraph_to_python_constant_outputs(): ...@@ -176,6 +176,25 @@ def test_fgraph_to_python_constant_outputs():
assert out_py()[0] is y.data assert out_py()[0] is y.data
def test_fgraph_to_python_constant_inputs():
x = constant([1.0])
y = vector("y")
out = x + y
out_fg = FunctionGraph(outputs=[out], clone=False)
out_py = fgraph_to_python(out_fg, to_python, storage_map=None)
res = out_py(2.0)
assert res == (3.0,)
storage_map = {out: [None], x: [np.r_[2.0]], y: [None]}
out_py = fgraph_to_python(out_fg, to_python, storage_map=storage_map)
res = out_py(2.0)
assert res == (4.0,)
def test_unique_name_generator(): def test_unique_name_generator():
unique_names = unique_name_generator(["blah"], suffix_sep="_") unique_names = unique_name_generator(["blah"], suffix_sep="_")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论