提交 62326376 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make fgraph_to_python process constant FunctionGraph outputs correctly

上级 3d96ee80
...@@ -8,7 +8,7 @@ from collections import Counter, defaultdict ...@@ -8,7 +8,7 @@ from collections import Counter, defaultdict
from keyword import iskeyword from keyword import iskeyword
from operator import itemgetter from operator import itemgetter
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from textwrap import indent from textwrap import dedent, indent
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
...@@ -767,6 +767,19 @@ def fgraph_to_python( ...@@ -767,6 +767,19 @@ def fgraph_to_python(
assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})" assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
body_assigns.append(f"{assign_comment_str}\n{assign_str}") body_assigns.append(f"{assign_comment_str}\n{assign_str}")
# Handle `Constant`-only outputs (these don't have associated `Apply`
# nodes, so the above isn't applicable)
for out in fgraph.outputs:
if isinstance(out, Constant):
local_input_name = unique_name(out)
if local_input_name not in global_env:
global_env[local_input_name] = type_conversion_fn(
storage_map[out][0],
variable=out,
storage=storage_map[out],
**kwargs,
)
fgraph_input_names = [unique_name(v) for v in fgraph.inputs] fgraph_input_names = [unique_name(v) for v in fgraph.inputs]
fgraph_output_names = [unique_name(v) for v in fgraph.outputs] fgraph_output_names = [unique_name(v) for v in fgraph.outputs]
joined_body_assigns = indent("\n".join(body_assigns), " ") joined_body_assigns = indent("\n".join(body_assigns), " ")
...@@ -778,11 +791,13 @@ def fgraph_to_python( ...@@ -778,11 +791,13 @@ def fgraph_to_python(
else: else:
fgraph_return_src = ", ".join(fgraph_output_names) fgraph_return_src = ", ".join(fgraph_output_names)
fgraph_def_src = f""" fgraph_def_src = dedent(
def {fgraph_name}({", ".join(fgraph_input_names)}): f"""
{joined_body_assigns} def {fgraph_name}({", ".join(fgraph_input_names)}):
return {fgraph_return_src} {indent(joined_body_assigns, " " * 4)}
return {fgraph_return_src}
""" """
).strip()
if local_env is None: if local_env is None:
local_env = locals() local_env = locals()
......
...@@ -13,6 +13,7 @@ from aesara.link.utils import ( ...@@ -13,6 +13,7 @@ from aesara.link.utils import (
unique_name_generator, unique_name_generator,
) )
from aesara.scalar.basic import Add, float64 from aesara.scalar.basic import Add, float64
from aesara.tensor import constant
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.type import scalar, vector from aesara.tensor.type import scalar, vector
from aesara.tensor.type_other import NoneConst from aesara.tensor.type_other import NoneConst
...@@ -163,6 +164,18 @@ def test_fgraph_to_python_multiline_str(): ...@@ -163,6 +164,18 @@ def test_fgraph_to_python_multiline_str():
) )
def test_fgraph_to_python_constant_outputs():
"""Make sure that constant outputs are handled properly."""
y = constant(1)
out_fg = FunctionGraph([], [y], clone=False)
out_py = fgraph_to_python(out_fg, to_python)
assert out_py()[0] is y.data
def test_unique_name_generator(): def test_unique_name_generator():
unique_names = unique_name_generator(["blah"], suffix_sep="_") unique_names = unique_name_generator(["blah"], suffix_sep="_")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论