提交 5710f950 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Reintroduce inline printing for short single output Composites

上级 1f2542eb
......@@ -4143,6 +4143,7 @@ class Composite(ScalarInnerGraphOp):
def __init__(self, inputs, outputs, name="Composite"):
self.name = name
self._name = None
# We need to clone the graph as sometimes its nodes already
# contain a reference to an fgraph. As we want the Composite
# to be pickable, we can't have reference to fgraph.
......@@ -4189,7 +4190,26 @@ class Composite(ScalarInnerGraphOp):
super().__init__()
def __str__(self):
return self.name
if self._name is not None:
return self._name
# Rename internal variables
for i, r in enumerate(self.fgraph.inputs):
r.name = f"i{int(i)}"
for i, r in enumerate(self.fgraph.outputs):
r.name = f"o{int(i)}"
io = set(self.fgraph.inputs + self.fgraph.outputs)
for i, r in enumerate(self.fgraph.variables):
if r not in io and len(self.fgraph.clients[r]) > 1:
r.name = f"t{int(i)}"
if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
self._name = "Composite{...}"
else:
outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
self._name = f"Composite{{{outputs_str}}}"
return self._name
def make_new_inplace(self, output_types_preference=None, name=None):
"""
......
......@@ -183,7 +183,7 @@ class TestComposite:
make_function(DualLinker().accept(g))
assert str(g) == (
"FunctionGraph(*1 -> Composite(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
"FunctionGraph(*1 -> Composite{...}(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
)
def test_non_scalar_error(self):
......
......@@ -654,7 +654,7 @@ def test_debugprint_compiled_fn():
Inner graphs:
forall_inplace,cpu,scan_fn} [id A]
← Elemwise{Composite} [id I] (inner_out_sit_sot-0)
← Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0)
├─ TensorConstant{0} [id J]
├─ Subtensor{int64, int64, uint8} [id K]
│ ├─ *2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
......@@ -665,13 +665,13 @@ def test_debugprint_compiled_fn():
│ └─ ScalarConstant{0} [id Q]
└─ TensorConstant{1} [id R]
Elemwise{Composite} [id I]
← Switch [id S]
Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I]
← Switch [id S] 'o0'
├─ LT [id T]
│ ├─ <int64> [id U]
│ └─ <float64> [id V]
├─ <int64> [id W]
└─ <int64> [id U]
│ ├─ i0 [id U]
│ └─ i1 [id V]
├─ i2 [id W]
└─ i0 [id U]
"""
output_str = debugprint(out, file="str", print_op_info=True)
......
......@@ -274,7 +274,7 @@ def test_debugprint():
s = s.getvalue()
exp_res = dedent(
r"""
Elemwise{Composite} 4
Elemwise{Composite{(i2 + (i0 - i1))}} 4
├─ InplaceDimShuffle{x,0} v={0: [0]} 3
│ └─ CGemv{inplace} d={0: [0]} 2
│ ├─ AllocEmpty{dtype='float64'} 1
......@@ -289,12 +289,12 @@ def test_debugprint():
Inner graphs:
Elemwise{Composite}
← add
├─ <float64>
Elemwise{Composite{(i2 + (i0 - i1))}}
← add 'o0'
├─ i2
└─ sub
├─ <float64>
└─ <float64>
├─ i0
└─ i1
"""
).lstrip()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论