提交 31b77a2e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Do not always remap storage in fgraph_to_python

上级 643c9734
......@@ -377,6 +377,9 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
@numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
_ = kwargs.pop("storage_map", None)
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
if len(op.fgraph.outputs) == 1:
......
......@@ -221,6 +221,9 @@ def numba_funcify_Clip(op, **kwargs):
@numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
_ = kwargs.pop("storage_map", None)
composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
)
......
......@@ -678,8 +678,6 @@ def fgraph_to_python(
*,
type_conversion_fn: Callable = lambda x, **kwargs: x,
order: Optional[List[Apply]] = None,
input_storage: Optional["InputStorageType"] = None,
output_storage: Optional["OutputStorageType"] = None,
storage_map: Optional["StorageMapType"] = None,
fgraph_name: str = "fgraph_to_python",
global_env: Optional[Dict[Any, Any]] = None,
......@@ -704,10 +702,6 @@ def fgraph_to_python(
``(value: Optional[Any], variable: Variable=None, storage: List[Optional[Any]]=None, **kwargs)``.
order
The `order` argument to `map_storage`.
input_storage
The `input_storage` argument to `map_storage`.
output_storage
The `output_storage` argument to `map_storage`.
storage_map
The `storage_map` argument to `map_storage`.
fgraph_name
......@@ -730,9 +724,9 @@ def fgraph_to_python(
if order is None:
order = fgraph.toposort()
input_storage, output_storage, storage_map = map_storage(
fgraph, order, input_storage, output_storage, storage_map
)
if storage_map is None:
storage_map = {}
unique_name = unique_name_generator([fgraph_name])
......@@ -752,10 +746,13 @@ def fgraph_to_python(
node_input_names = []
for i in node.inputs:
local_input_name = unique_name(i)
if storage_map[i][0] is not None or isinstance(i, Constant):
input_storage = storage_map.setdefault(
i, [None if not isinstance(i, Constant) else i.data]
)
if input_storage[0] is not None or isinstance(i, Constant):
# Constants need to be assigned locally and referenced
global_env[local_input_name] = type_conversion_fn(
storage_map[i][0], variable=i, storage=storage_map[i], **kwargs
input_storage[0], variable=i, storage=input_storage, **kwargs
)
# TODO: We could attempt to use the storage arrays directly
# E.g. `local_input_name = f"{local_input_name}[0]"`
......@@ -763,20 +760,24 @@ def fgraph_to_python(
node_output_names = [unique_name(v) for v in node.outputs]
assign_comment_str = f"{indent(str(node), '# ')}"
assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
body_assigns.append(f"{assign_comment_str}\n{assign_str}")
assign_comment_str = f"{indent(str(node), '# ')}"
assign_block_str = f"{assign_comment_str}\n{assign_str}"
body_assigns.append(assign_block_str)
# Handle `Constant`-only outputs (these don't have associated `Apply`
# nodes, so the above isn't applicable)
for out in fgraph.outputs:
if isinstance(out, Constant):
local_input_name = unique_name(out)
if local_input_name not in global_env:
global_env[local_input_name] = type_conversion_fn(
storage_map[out][0],
local_output_name = unique_name(out)
if local_output_name not in global_env:
output_storage = storage_map.setdefault(
out, [None if not isinstance(out, Constant) else out.data]
)
global_env[local_output_name] = type_conversion_fn(
output_storage[0],
variable=out,
storage=storage_map[out],
storage=output_storage,
**kwargs,
)
......@@ -794,7 +795,7 @@ def fgraph_to_python(
fgraph_def_src = dedent(
f"""
def {fgraph_name}({", ".join(fgraph_input_names)}):
{indent(joined_body_assigns, " " * 4)}
{indent(joined_body_assigns, " " * 4)}
return {fgraph_return_src}
"""
).strip()
......
......@@ -176,6 +176,25 @@ def test_fgraph_to_python_constant_outputs():
assert out_py()[0] is y.data
def test_fgraph_to_python_constant_inputs():
x = constant([1.0])
y = vector("y")
out = x + y
out_fg = FunctionGraph(outputs=[out], clone=False)
out_py = fgraph_to_python(out_fg, to_python, storage_map=None)
res = out_py(2.0)
assert res == (3.0,)
storage_map = {out: [None], x: [np.r_[2.0]], y: [None]}
out_py = fgraph_to_python(out_fg, to_python, storage_map=storage_map)
res = out_py(2.0)
assert res == (4.0,)
def test_unique_name_generator():
unique_names = unique_name_generator(["blah"], suffix_sep="_")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论