提交 bab68456 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename local variables in aesara.link.utils.fgraph_to_python

上级 a2671ea4
...@@ -682,13 +682,13 @@ def fgraph_to_python( ...@@ -682,13 +682,13 @@ def fgraph_to_python(
body_assigns = [] body_assigns = []
for node in order: for node in order:
jax_func = op_conversion_fn( compiled_func = op_conversion_fn(
node.op, node=node, storage_map=storage_map, **kwargs node.op, node=node, storage_map=storage_map, **kwargs
) )
# Create a local alias with a unique name # Create a local alias with a unique name
local_jax_func_name = unique_name(jax_func) local_compiled_func_name = unique_name(compiled_func)
global_env[local_jax_func_name] = jax_func global_env[local_compiled_func_name] = compiled_func
node_input_names = [] node_input_names = []
for i in node.inputs: for i in node.inputs:
...@@ -705,7 +705,7 @@ def fgraph_to_python( ...@@ -705,7 +705,7 @@ def fgraph_to_python(
node_output_names = [unique_name(v) for v in node.outputs] node_output_names = [unique_name(v) for v in node.outputs]
body_assigns.append( body_assigns.append(
f"{', '.join(node_output_names)} = {local_jax_func_name}({', '.join(node_input_names)})" f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
) )
fgraph_input_names = [unique_name(v) for v in fgraph.inputs] fgraph_input_names = [unique_name(v) for v in fgraph.inputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论