corrected handling of outputs and orphans

上级 158b51fb
...@@ -90,7 +90,7 @@ class Env(graph.Graph): ...@@ -90,7 +90,7 @@ class Env(graph.Graph):
# Set of all the results that are not an output of an op in the subgraph but # Set of all the results that are not an output of an op in the subgraph but
# are an input of an op in the subgraph. # are an input of an op in the subgraph.
# e.g. z for inputs=(x, y) and outputs=(x + (y - z),) # e.g. z for inputs=(x, y) and outputs=(x + (y - z),)
self._orphans = set() self._orphans = set(outputs)
# Maps results to ops that use them: # Maps results to ops that use them:
# if op.inputs[i] == v then (op, i) in self._clients[v] # if op.inputs[i] == v then (op, i) in self._clients[v]
...@@ -112,6 +112,7 @@ class Env(graph.Graph): ...@@ -112,6 +112,7 @@ class Env(graph.Graph):
def add_output(self, output): def add_output(self, output):
self.outputs.add(output) self.outputs.add(output)
self.orphans.add(output)
self.__import_r__([output]) self.__import_r__([output])
def clients(self, r): def clients(self, r):
...@@ -364,6 +365,7 @@ class Env(graph.Graph): ...@@ -364,6 +365,7 @@ class Env(graph.Graph):
self._ops.add(op) self._ops.add(op)
self._results.update(op.outputs) self._results.update(op.outputs)
self._orphans.difference_update(op.outputs)
for i, input in enumerate(op.inputs): for i, input in enumerate(op.inputs):
self.__add_clients__(input, [(op, i)]) self.__add_clients__(input, [(op, i)])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论