提交 7e5f7c85 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid calling type_conversion_fn repeatedly on the same variables

上级 d833f039
...@@ -731,6 +731,8 @@ def fgraph_to_python( ...@@ -731,6 +731,8 @@ def fgraph_to_python(
if global_env is None: if global_env is None:
global_env = {} global_env = {}
tipifiyed_vars = set()
body_assigns = [] body_assigns = []
for node in order: for node in order:
compiled_func = op_conversion_fn( compiled_func = op_conversion_fn(
...@@ -742,17 +744,29 @@ def fgraph_to_python( ...@@ -742,17 +744,29 @@ def fgraph_to_python(
global_env[local_compiled_func_name] = compiled_func global_env[local_compiled_func_name] = compiled_func
node_input_names = [] node_input_names = []
for i in node.inputs: for inp in node.inputs:
local_input_name = unique_name(i) local_input_name = unique_name(inp)
is_constant = isinstance(inp, Constant)
input_storage = storage_map.setdefault( input_storage = storage_map.setdefault(
i, [None if not isinstance(i, Constant) else i.data] inp,
[
inp.data # type: ignore[attr-defined]
if is_constant
else None
],
) )
if input_storage[0] is not None or isinstance(i, Constant): if (
is_constant or input_storage[0] is not None
) and inp not in tipifiyed_vars:
# Constants need to be assigned locally and referenced # Constants need to be assigned locally and referenced
# FIXME: This is converting shared variables, but these may change later,
# so this one-time conversion is wasteful / not robust
global_env[local_input_name] = type_conversion_fn( global_env[local_input_name] = type_conversion_fn(
input_storage[0], variable=i, storage=input_storage, **kwargs input_storage[0], variable=inp, storage=input_storage, **kwargs
) )
tipifiyed_vars.add(inp)
# TODO: We could attempt to use the storage arrays directly # TODO: We could attempt to use the storage arrays directly
# Otherwise we're doubling the memory footprint of constants
# E.g. `local_input_name = f"{local_input_name}[0]"` # E.g. `local_input_name = f"{local_input_name}[0]"`
node_input_names.append(local_input_name) node_input_names.append(local_input_name)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论