提交 390a8d67 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug in Composite when multiple outputs are identical

上级 60c39df1
......@@ -4146,6 +4146,21 @@ class Composite(ScalarOp, HasInnerGraph):
"The fgraph to Composite must be exclusively"
" composed of ScalarOp instances."
)
# Clone identical outputs that have been merged
if len(set(fgraph.outputs)) != len(self.outputs):
old_outputs = fgraph.outputs
new_outputs = []
for output in old_outputs:
if output not in new_outputs:
new_outputs.append(output)
else:
node = output.owner
output_idx = node.outputs.index(output)
new_output = node.clone().outputs[output_idx]
new_outputs.append(new_output)
fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False)
self._fgraph = fgraph
return self._fgraph
......
......@@ -156,6 +156,17 @@ class TestComposite:
fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
def test_identical_outputs(self):
x, y, z = floats("xyz")
e0 = x + y + z
e1 = x + y + z
e2 = x / y
C = Composite([x, y, z], [e0, e1, e2])
c = C.make_node(x, y, z)
g = FunctionGraph([x, y, z], c.outputs)
fn = make_function(DualLinker().accept(g))
assert fn(1.0, 2.0, 3.0) == [6.0, 6.0, 0.5]
def test_composite_printing(self):
x, y, z = floats("xyz")
e0 = x + y + z
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论