提交 0b0ad512 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Implement FunctionGraph.__contains__ for easy membership determination

上级 a4eb9873
...@@ -329,3 +329,18 @@ class TestFunctionGraph: ...@@ -329,3 +329,18 @@ class TestFunctionGraph:
fg.add_client(var4, (var3.owner, 0)) fg.add_client(var4, (var3.owner, 0))
fg.check_integrity() fg.check_integrity()
def test_contains(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
assert var1 in fg
assert var3 in fg
assert var3.owner in fg
assert var5 in fg
assert var5.owner in fg
""" """A container for specifying and manipulating a graph with distinct inputs and outputs."""
fg.py: fg stands for FunctionGraph
Contains the FunctionGraph class and exception
types that it can raise.
"""
import time import time
from collections import OrderedDict from collections import OrderedDict
from io import StringIO from io import StringIO
import theano import theano
from theano import config from theano import config
from theano.gof import graph, toolbox, utils from theano.gof import toolbox, utils
from theano.gof.graph import Apply, Constant, Variable
from theano.gof.graph import as_string as graph_as_string
from theano.gof.graph import clone as clone_graph
from theano.gof.graph import clone_get_equiv, io_toposort
from theano.gof.graph import ops as ops_between
from theano.gof.graph import variables as variables_between
from theano.gof.utils import TestValueError, get_variable_trace_string from theano.gof.utils import TestValueError, get_variable_trace_string
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
...@@ -113,7 +114,7 @@ class FunctionGraph(utils.object2): ...@@ -113,7 +114,7 @@ class FunctionGraph(utils.object2):
""" """
if clone: if clone:
inputs, outputs = graph.clone(inputs, outputs) inputs, outputs = clone_graph(inputs, outputs)
if not isinstance(inputs, list): if not isinstance(inputs, list):
raise TypeError("Argument `inputs` should be a list") raise TypeError("Argument `inputs` should be a list")
...@@ -350,7 +351,7 @@ class FunctionGraph(utils.object2): ...@@ -350,7 +351,7 @@ class FunctionGraph(utils.object2):
self.import_node(var.owner, reason=reason) self.import_node(var.owner, reason=reason)
elif ( elif (
var.owner is None var.owner is None
and not isinstance(var, graph.Constant) and not isinstance(var, Constant)
and var not in self.inputs and var not in self.inputs
): ):
from theano.gof.null_type import NullType from theano.gof.null_type import NullType
...@@ -383,7 +384,7 @@ class FunctionGraph(utils.object2): ...@@ -383,7 +384,7 @@ class FunctionGraph(utils.object2):
# in new nodes, so we use all variables we know of as if they were the input set. # in new nodes, so we use all variables we know of as if they were the input set.
# (the functions in the graph module only use the input set to # (the functions in the graph module only use the input set to
# know where to stop going down) # know where to stop going down)
new_nodes = graph.io_toposort(self.variables, apply_node.outputs) new_nodes = io_toposort(self.variables, apply_node.outputs)
if check: if check:
for node in new_nodes: for node in new_nodes:
...@@ -394,7 +395,7 @@ class FunctionGraph(utils.object2): ...@@ -394,7 +395,7 @@ class FunctionGraph(utils.object2):
raise Exception(f"{var} is already owned by another fgraph") raise Exception(f"{var} is already owned by another fgraph")
if ( if (
var.owner is None var.owner is None
and not isinstance(var, graph.Constant) and not isinstance(var, Constant)
and var not in self.inputs and var not in self.inputs
): ):
# Standard error message # Standard error message
...@@ -688,7 +689,7 @@ class FunctionGraph(utils.object2): ...@@ -688,7 +689,7 @@ class FunctionGraph(utils.object2):
ords = self.orderings() ords = self.orderings()
order = graph.io_toposort(fg.inputs, fg.outputs, ords) order = io_toposort(fg.inputs, fg.outputs, ords)
return order return order
...@@ -743,7 +744,7 @@ class FunctionGraph(utils.object2): ...@@ -743,7 +744,7 @@ class FunctionGraph(utils.object2):
Call this for a diagnosis if things go awry. Call this for a diagnosis if things go awry.
""" """
nodes = graph.ops(self.inputs, self.outputs) nodes = ops_between(self.inputs, self.outputs)
if self.apply_nodes != nodes: if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes) missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes) excess = self.apply_nodes.difference(nodes)
...@@ -761,7 +762,7 @@ class FunctionGraph(utils.object2): ...@@ -761,7 +762,7 @@ class FunctionGraph(utils.object2):
raise Exception( raise Exception(
f"Inconsistent clients list {(node, i)} in {clients}" f"Inconsistent clients list {(node, i)} in {clients}"
) )
variables = set(graph.variables(self.inputs, self.outputs)) variables = set(variables_between(self.inputs, self.outputs))
if set(self.variables) != variables: if set(self.variables) != variables:
missing = variables.difference(self.variables) missing = variables.difference(self.variables)
excess = self.variables.difference(variables) excess = self.variables.difference(variables)
...@@ -774,7 +775,7 @@ class FunctionGraph(utils.object2): ...@@ -774,7 +775,7 @@ class FunctionGraph(utils.object2):
if ( if (
variable.owner is None variable.owner is None
and variable not in self.inputs and variable not in self.inputs
and not isinstance(variable, graph.Constant) and not isinstance(variable, Constant)
): ):
raise Exception("Undeclared input.", variable) raise Exception("Undeclared input.", variable)
if variable.fgraph is not self: if variable.fgraph is not self:
...@@ -799,7 +800,7 @@ class FunctionGraph(utils.object2): ...@@ -799,7 +800,7 @@ class FunctionGraph(utils.object2):
) )
def __repr__(self): def __repr__(self):
return f"FunctionGraph({', '.join(graph.as_string(self.inputs, self.outputs))})" return f"FunctionGraph({', '.join(graph_as_string(self.inputs, self.outputs))})"
def clone(self, check_integrity=True): def clone(self, check_integrity=True):
""" """
...@@ -824,7 +825,7 @@ class FunctionGraph(utils.object2): ...@@ -824,7 +825,7 @@ class FunctionGraph(utils.object2):
equiv: dict equiv: dict
A dict that map old node to new node. A dict that map old node to new node.
""" """
equiv = graph.clone_get_equiv(self.inputs, self.outputs) equiv = clone_get_equiv(self.inputs, self.outputs)
if check_integrity: if check_integrity:
self.check_integrity() self.check_integrity()
...@@ -865,3 +866,11 @@ class FunctionGraph(utils.object2): ...@@ -865,3 +866,11 @@ class FunctionGraph(utils.object2):
for feature in self._features: for feature in self._features:
if hasattr(feature, "unpickle"): if hasattr(feature, "unpickle"):
feature.unpickle(self) feature.unpickle(self)
def __contains__(self, item):
if isinstance(item, Variable):
return item in self.variables
elif isinstance(item, Apply):
return item in self.apply_nodes
else:
raise TypeError()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论