提交 04e8afee authored 作者: Frederic's avatar Frederic

Make import_r as prune_r, this would make the code faster during opt.

上级 ff35387f
...@@ -131,7 +131,8 @@ class FunctionGraph(utils.object2): ...@@ -131,7 +131,8 @@ class FunctionGraph(utils.object2):
self.__setup_r__(input) self.__setup_r__(input)
self.variables.add(input) self.variables.add(input)
self.__import_r__(outputs, reason="init") for output in outputs:
self.__import_r__(output, reason="init")
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
output.clients.append(('output', i)) output.clients.append(('output', i))
...@@ -246,23 +247,23 @@ class FunctionGraph(utils.object2): ...@@ -246,23 +247,23 @@ class FunctionGraph(utils.object2):
return False return False
### import ### ### import ###
def __import_r__(self, variables, reason): def __import_r__(self, variable, reason):
global NullType global NullType
if NullType is None: if NullType is None:
from null_type import NullType from null_type import NullType
# Imports the owners of the variables # Imports the owners of the variables
for apply_node in [r.owner for r in variables if r.owner is not None]: if variable.owner and variable.owner not in self.apply_nodes:
if apply_node not in self.apply_nodes: self.__import__(variable.owner, reason=reason)
self.__import__(apply_node, reason=reason) if (variable.owner is None and
for r in variables: not isinstance(variable, graph.Constant) and
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs: variable not in self.inputs):
if isinstance(r.type, NullType): if isinstance(variable.type, NullType):
raise TypeError("Computation graph contains a NaN. " + raise TypeError("Computation graph contains a NaN. " +
r.type.why_null) variable.type.why_null)
raise MissingInputError("Undeclared input", r) raise MissingInputError("Undeclared input", variable)
if not getattr(r, 'fgraph', None) is self: if not getattr(variable, 'fgraph', None) is self:
self.__setup_r__(r) self.__setup_r__(variable)
self.variables.add(r) self.variables.add(variable)
def __import__(self, apply_node, check=True, reason=None): def __import__(self, apply_node, check=True, reason=None):
# We import the nodes in topological order. We only are interested # We import the nodes in topological order. We only are interested
...@@ -445,7 +446,7 @@ class FunctionGraph(utils.object2): ...@@ -445,7 +446,7 @@ class FunctionGraph(utils.object2):
if r is new_r: if r is new_r:
return return
self.__import_r__([new_r], reason=reason) self.__import_r__(new_r, reason=reason)
self.__add_clients__(new_r, [(node, i)]) self.__add_clients__(new_r, [(node, i)])
prune = self.__remove_clients__(r, [(node, i)], False) prune = self.__remove_clients__(r, [(node, i)], False)
# Precondition: the substitution is semantically valid # Precondition: the substitution is semantically valid
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论