提交 0b0ad512 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Implement FunctionGraph.__contains__ for easy membership determination

上级 a4eb9873
......@@ -329,3 +329,18 @@ class TestFunctionGraph:
fg.add_client(var4, (var3.owner, 0))
fg.check_integrity()
def test_contains(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
assert var1 in fg
assert var3 in fg
assert var3.owner in fg
assert var5 in fg
assert var5.owner in fg
"""
fg.py: fg stands for FunctionGraph
Contains the FunctionGraph class and exception
types that it can raise.
"""
"""A container for specifying and manipulating a graph with distinct inputs and outputs."""
import time
from collections import OrderedDict
from io import StringIO
import theano
from theano import config
from theano.gof import graph, toolbox, utils
from theano.gof import toolbox, utils
from theano.gof.graph import Apply, Constant, Variable
from theano.gof.graph import as_string as graph_as_string
from theano.gof.graph import clone as clone_graph
from theano.gof.graph import clone_get_equiv, io_toposort
from theano.gof.graph import ops as ops_between
from theano.gof.graph import variables as variables_between
from theano.gof.utils import TestValueError, get_variable_trace_string
from theano.misc.ordered_set import OrderedSet
......@@ -113,7 +114,7 @@ class FunctionGraph(utils.object2):
"""
if clone:
inputs, outputs = graph.clone(inputs, outputs)
inputs, outputs = clone_graph(inputs, outputs)
if not isinstance(inputs, list):
raise TypeError("Argument `inputs` should be a list")
......@@ -350,7 +351,7 @@ class FunctionGraph(utils.object2):
self.import_node(var.owner, reason=reason)
elif (
var.owner is None
and not isinstance(var, graph.Constant)
and not isinstance(var, Constant)
and var not in self.inputs
):
from theano.gof.null_type import NullType
......@@ -383,7 +384,7 @@ class FunctionGraph(utils.object2):
# in new nodes, so we use all variables we know of as if they were the input set.
# (the functions in the graph module only use the input set to
# know where to stop going down)
new_nodes = graph.io_toposort(self.variables, apply_node.outputs)
new_nodes = io_toposort(self.variables, apply_node.outputs)
if check:
for node in new_nodes:
......@@ -394,7 +395,7 @@ class FunctionGraph(utils.object2):
raise Exception(f"{var} is already owned by another fgraph")
if (
var.owner is None
and not isinstance(var, graph.Constant)
and not isinstance(var, Constant)
and var not in self.inputs
):
# Standard error message
......@@ -688,7 +689,7 @@ class FunctionGraph(utils.object2):
ords = self.orderings()
order = graph.io_toposort(fg.inputs, fg.outputs, ords)
order = io_toposort(fg.inputs, fg.outputs, ords)
return order
......@@ -743,7 +744,7 @@ class FunctionGraph(utils.object2):
Call this for a diagnosis if things go awry.
"""
nodes = graph.ops(self.inputs, self.outputs)
nodes = ops_between(self.inputs, self.outputs)
if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes)
......@@ -761,7 +762,7 @@ class FunctionGraph(utils.object2):
raise Exception(
f"Inconsistent clients list {(node, i)} in {clients}"
)
variables = set(graph.variables(self.inputs, self.outputs))
variables = set(variables_between(self.inputs, self.outputs))
if set(self.variables) != variables:
missing = variables.difference(self.variables)
excess = self.variables.difference(variables)
......@@ -774,7 +775,7 @@ class FunctionGraph(utils.object2):
if (
variable.owner is None
and variable not in self.inputs
and not isinstance(variable, graph.Constant)
and not isinstance(variable, Constant)
):
raise Exception("Undeclared input.", variable)
if variable.fgraph is not self:
......@@ -799,7 +800,7 @@ class FunctionGraph(utils.object2):
)
def __repr__(self):
return f"FunctionGraph({', '.join(graph.as_string(self.inputs, self.outputs))})"
return f"FunctionGraph({', '.join(graph_as_string(self.inputs, self.outputs))})"
def clone(self, check_integrity=True):
"""
......@@ -824,7 +825,7 @@ class FunctionGraph(utils.object2):
equiv: dict
A dict that map old node to new node.
"""
equiv = graph.clone_get_equiv(self.inputs, self.outputs)
equiv = clone_get_equiv(self.inputs, self.outputs)
if check_integrity:
self.check_integrity()
......@@ -865,3 +866,11 @@ class FunctionGraph(utils.object2):
for feature in self._features:
if hasattr(feature, "unpickle"):
feature.unpickle(self)
def __contains__(self, item):
if isinstance(item, Variable):
return item in self.variables
elif isinstance(item, Apply):
return item in self.apply_nodes
else:
raise TypeError()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论