提交 45642af3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Allow newlines in __str__ output printed by fgraph_to_python

上级 c028e387
...@@ -760,9 +760,9 @@ def fgraph_to_python( ...@@ -760,9 +760,9 @@ def fgraph_to_python(
node_output_names = [unique_name(v) for v in node.outputs] node_output_names = [unique_name(v) for v in node.outputs]
body_assigns.append( assign_comment_str = f"{indent(str(node), '# ')}"
f"# {node}\n{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})" assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
) body_assigns.append(f"{assign_comment_str}\n{assign_str}")
fgraph_input_names = [unique_name(v) for v in fgraph.inputs] fgraph_input_names = [unique_name(v) for v in fgraph.inputs]
fgraph_output_names = [unique_name(v) for v in fgraph.outputs] fgraph_output_names = [unique_name(v) for v in fgraph.outputs]
......
import inspect
from functools import singledispatch from functools import singledispatch
import numpy as np import numpy as np
...@@ -108,6 +109,56 @@ def test_fgraph_to_python_once(): ...@@ -108,6 +109,56 @@ def test_fgraph_to_python_once():
assert op2.called == 2 assert op2.called == 2
def test_fgraph_to_python_multiline_str():
"""Make sure that multiline `__str__` values are supported by `fgraph_to_python`."""
x = vector("x")
y = vector("y")
class TestOp(Op):
def __init__(self):
super().__init__()
def make_node(self, *args):
return Apply(self, list(args), [x.type() for x in args])
def perform(self, inputs, outputs):
for i, inp in enumerate(inputs):
outputs[i][0] = inp[0]
def __str__(self):
return "Test\nOp()"
@to_python.register(TestOp)
def to_python_TestOp(op, **kwargs):
def func(*args, op=op):
return list(args)
return func
op1 = TestOp()
op2 = TestOp()
q, r = op1(x, y)
outs = op2(q + r, q + r)
out_fg = FunctionGraph([x, y], outs, clone=False)
assert len(out_fg.outputs) == 2
out_py = fgraph_to_python(out_fg, to_python)
out_py_src = inspect.getsource(out_py)
assert (
"""
# Elemwise{add,no_inplace}(Test
# Op().0, Test
# Op().1)
"""
in out_py_src
)
def test_unique_name_generator(): def test_unique_name_generator():
unique_names = unique_name_generator(["blah"], suffix_sep="_") unique_names = unique_name_generator(["blah"], suffix_sep="_")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论