提交 62326376 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make fgraph_to_python process constant FunctionGraph outputs correctly

上级 3d96ee80
......@@ -8,7 +8,7 @@ from collections import Counter, defaultdict
from keyword import iskeyword
from operator import itemgetter
from tempfile import NamedTemporaryFile
from textwrap import indent
from textwrap import dedent, indent
from typing import (
TYPE_CHECKING,
Any,
......@@ -767,6 +767,19 @@ def fgraph_to_python(
assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
body_assigns.append(f"{assign_comment_str}\n{assign_str}")
# Handle `Constant`-only outputs (these don't have associated `Apply`
# nodes, so the above isn't applicable)
for out in fgraph.outputs:
if isinstance(out, Constant):
local_input_name = unique_name(out)
if local_input_name not in global_env:
global_env[local_input_name] = type_conversion_fn(
storage_map[out][0],
variable=out,
storage=storage_map[out],
**kwargs,
)
fgraph_input_names = [unique_name(v) for v in fgraph.inputs]
fgraph_output_names = [unique_name(v) for v in fgraph.outputs]
joined_body_assigns = indent("\n".join(body_assigns), " ")
......@@ -778,11 +791,13 @@ def fgraph_to_python(
else:
fgraph_return_src = ", ".join(fgraph_output_names)
fgraph_def_src = f"""
def {fgraph_name}({", ".join(fgraph_input_names)}):
{joined_body_assigns}
return {fgraph_return_src}
fgraph_def_src = dedent(
f"""
def {fgraph_name}({", ".join(fgraph_input_names)}):
{indent(joined_body_assigns, " " * 4)}
return {fgraph_return_src}
"""
).strip()
if local_env is None:
local_env = locals()
......
......@@ -13,6 +13,7 @@ from aesara.link.utils import (
unique_name_generator,
)
from aesara.scalar.basic import Add, float64
from aesara.tensor import constant
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.type import scalar, vector
from aesara.tensor.type_other import NoneConst
......@@ -163,6 +164,18 @@ def test_fgraph_to_python_multiline_str():
)
def test_fgraph_to_python_constant_outputs():
"""Make sure that constant outputs are handled properly."""
y = constant(1)
out_fg = FunctionGraph([], [y], clone=False)
out_py = fgraph_to_python(out_fg, to_python)
assert out_py()[0] is y.data
def test_unique_name_generator():
unique_names = unique_name_generator(["blah"], suffix_sep="_")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论