提交 5710f950 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Reintroduce inline printing for short single output Composites

上级 1f2542eb
...@@ -4143,6 +4143,7 @@ class Composite(ScalarInnerGraphOp): ...@@ -4143,6 +4143,7 @@ class Composite(ScalarInnerGraphOp):
def __init__(self, inputs, outputs, name="Composite"): def __init__(self, inputs, outputs, name="Composite"):
self.name = name self.name = name
self._name = None
# We need to clone the graph as sometimes its nodes already # We need to clone the graph as sometimes its nodes already
# contain a reference to an fgraph. As we want the Composite # contain a reference to an fgraph. As we want the Composite
# to be pickable, we can't have reference to fgraph. # to be pickable, we can't have reference to fgraph.
...@@ -4189,7 +4190,26 @@ class Composite(ScalarInnerGraphOp): ...@@ -4189,7 +4190,26 @@ class Composite(ScalarInnerGraphOp):
super().__init__() super().__init__()
def __str__(self): def __str__(self):
return self.name if self._name is not None:
return self._name
# Rename internal variables
for i, r in enumerate(self.fgraph.inputs):
r.name = f"i{int(i)}"
for i, r in enumerate(self.fgraph.outputs):
r.name = f"o{int(i)}"
io = set(self.fgraph.inputs + self.fgraph.outputs)
for i, r in enumerate(self.fgraph.variables):
if r not in io and len(self.fgraph.clients[r]) > 1:
r.name = f"t{int(i)}"
if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
self._name = "Composite{...}"
else:
outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
self._name = f"Composite{{{outputs_str}}}"
return self._name
def make_new_inplace(self, output_types_preference=None, name=None): def make_new_inplace(self, output_types_preference=None, name=None):
""" """
......
...@@ -183,7 +183,7 @@ class TestComposite: ...@@ -183,7 +183,7 @@ class TestComposite:
make_function(DualLinker().accept(g)) make_function(DualLinker().accept(g))
assert str(g) == ( assert str(g) == (
"FunctionGraph(*1 -> Composite(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)" "FunctionGraph(*1 -> Composite{...}(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
) )
def test_non_scalar_error(self): def test_non_scalar_error(self):
......
...@@ -654,7 +654,7 @@ def test_debugprint_compiled_fn(): ...@@ -654,7 +654,7 @@ def test_debugprint_compiled_fn():
Inner graphs: Inner graphs:
forall_inplace,cpu,scan_fn} [id A] forall_inplace,cpu,scan_fn} [id A]
← Elemwise{Composite} [id I] (inner_out_sit_sot-0) ← Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0)
├─ TensorConstant{0} [id J] ├─ TensorConstant{0} [id J]
├─ Subtensor{int64, int64, uint8} [id K] ├─ Subtensor{int64, int64, uint8} [id K]
│ ├─ *2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0) │ ├─ *2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
...@@ -665,13 +665,13 @@ def test_debugprint_compiled_fn(): ...@@ -665,13 +665,13 @@ def test_debugprint_compiled_fn():
│ └─ ScalarConstant{0} [id Q] │ └─ ScalarConstant{0} [id Q]
└─ TensorConstant{1} [id R] └─ TensorConstant{1} [id R]
Elemwise{Composite} [id I] Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I]
← Switch [id S] ← Switch [id S] 'o0'
├─ LT [id T] ├─ LT [id T]
│ ├─ <int64> [id U] │ ├─ i0 [id U]
│ └─ <float64> [id V] │ └─ i1 [id V]
├─ <int64> [id W] ├─ i2 [id W]
└─ <int64> [id U] └─ i0 [id U]
""" """
output_str = debugprint(out, file="str", print_op_info=True) output_str = debugprint(out, file="str", print_op_info=True)
......
...@@ -274,7 +274,7 @@ def test_debugprint(): ...@@ -274,7 +274,7 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
exp_res = dedent( exp_res = dedent(
r""" r"""
Elemwise{Composite} 4 Elemwise{Composite{(i2 + (i0 - i1))}} 4
├─ InplaceDimShuffle{x,0} v={0: [0]} 3 ├─ InplaceDimShuffle{x,0} v={0: [0]} 3
│ └─ CGemv{inplace} d={0: [0]} 2 │ └─ CGemv{inplace} d={0: [0]} 2
│ ├─ AllocEmpty{dtype='float64'} 1 │ ├─ AllocEmpty{dtype='float64'} 1
...@@ -289,12 +289,12 @@ def test_debugprint(): ...@@ -289,12 +289,12 @@ def test_debugprint():
Inner graphs: Inner graphs:
Elemwise{Composite} Elemwise{Composite{(i2 + (i0 - i1))}}
← add ← add 'o0'
├─ <float64> ├─ i2
└─ sub └─ sub
├─ <float64> ├─ i0
└─ <float64> └─ i1
""" """
).lstrip() ).lstrip()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论