提交 0d0cc128 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

renamed FunctionGraph.nodes -> FunctionGraph.apply_nodes

上级 e5a29ece
......@@ -675,7 +675,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
features=[equivalence_tracker])
if not accept_inplace:
for node in fgraph.nodes:
for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None):
raise TypeError("Graph must not contain inplace operations",
node)
......
......@@ -131,7 +131,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs)
fgraph = gof.fg.FunctionGraph(inputs, outputs)
for node in fgraph.nodes:
for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None):
if not accept_inplace:
raise TypeError("Graph must not contain inplace operations", node, node.op)
......
......@@ -520,7 +520,7 @@ class ProfileMode(Mode):
print "Profile of Theano functions memory:"
print "(This check only the output of each apply node. It don't check the temporary memory used by the op in the apply node.)"
nb_skipped = 0
for fgraph,nodes_mem in fct_memory.iteritems():
for fgraph, nodes_mem in fct_memory.iteritems():
size_sum=sum([sum(val) for key,val in nodes_mem.iteritems()])
if size_sum < min_memory_size:
nb_skipped += 1
......
......@@ -711,7 +711,7 @@ if 0: # old code still to be ported from ProfileMode
var_mem[out]=v
print
print "Profile of Theano functions memory:"
for fgraph,nodes_mem in fct_memory.iteritems():
for fgraph, nodes_mem in fct_memory.iteritems():
print "Theano fct:", [fct for fct in fct_call.keys() if fct.maker.fgraph is fgraph][0].name
size_sum=sum([sum(val) for key,val in nodes_mem.iteritems()])
print " Max without gc, inplace and view (KB)",size_sum/1024
......
......@@ -133,6 +133,7 @@ def _contains_cycle(inputs, outputs, orderings):
# into fifo_queue
# TODO: does the order of the roots in the fifo_queue matter?
while lifo_queue:
# using pop rather than pop_left makes this queue LIFO
# using a LIFO queue makes the search DFS
......
......@@ -79,10 +79,10 @@ class FunctionGraph(utils.object2):
# so I probably am) this should be a set.
self._features = []
# All nodes in the subgraph defined by inputs and outputs are cached in nodes
self.nodes = set()
# All apply nodes in the subgraph defined by inputs and outputs are cached in this field
self.apply_nodes = set()
# Ditto for variables
# Ditto for variable nodes
self.variables = set()
self.inputs = list(inputs)
......@@ -151,13 +151,13 @@ class FunctionGraph(utils.object2):
nodes and variables. If there are no features, this should set
them back to what they were originally.
"""
for node in self.nodes:
del node.fgraph
del node.deps
for apply_node in self.apply_nodes:
del apply_node.fgraph
del apply_node.deps
for variable in self.variables:
del variable.fgraph
del variable.clients
self.nodes = set()
self.apply_nodes = set()
self.variables = set()
self.inputs = None
self.outputs = None
......@@ -215,11 +215,11 @@ class FunctionGraph(utils.object2):
if NullType is None:
from null_type import NullType
# Imports the owners of the variables
r_owner_done = set(self.nodes)
for node in [r.owner for r in variables if r.owner is not None]:
if node not in r_owner_done:
r_owner_done.add(node)
self.__import__(node)
r_owner_done = set(self.apply_nodes)
for apply_node in [r.owner for r in variables if r.owner is not None]:
if apply_node not in r_owner_done:
r_owner_done.add(apply_node)
self.__import__(apply_node)
for r in variables:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
if isinstance(r.type,NullType):
......@@ -229,7 +229,9 @@ class FunctionGraph(utils.object2):
self.__setup_r__(r)
self.variables.add(r)
def __import__(self, node, check = True):
def __import__(self, apply_node, check = True):
node = apply_node
# We import the nodes in topological order. We only are interested
# in new nodes, so we use all variables we know of as if they were the input set.
# (the functions in the graph module only use the input set to
......@@ -311,9 +313,9 @@ class FunctionGraph(utils.object2):
r)
for node in new_nodes:
assert node not in self.nodes
assert node not in self.apply_nodes
self.__setup_node__(node)
self.nodes.add(node)
self.apply_nodes.add(node)
for output in node.outputs:
self.__setup_r__(output)
self.variables.add(output)
......@@ -336,8 +338,9 @@ class FunctionGraph(utils.object2):
if not r.clients and r in self.variables:
self.variables.remove(r)
def __prune__(self, node):
if node not in self.nodes:
def __prune__(self, apply_node):
node = apply_node
if node not in self.apply_nodes:
raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node)
assert node.fgraph is self
# If node's outputs have no clients, removes it from the graph
......@@ -348,7 +351,7 @@ class FunctionGraph(utils.object2):
# Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output):
return
self.nodes.remove(node)
self.apply_nodes.remove(node)
self.variables.difference_update(node.outputs)
self.execute_callbacks('on_prune', node)
......@@ -532,12 +535,12 @@ class FunctionGraph(utils.object2):
{node: predecessors} where predecessors is a list of nodes
that should be computed before the key node.
"""
if len(self.nodes) < 2:
if len(self.apply_nodes) < 2:
# optimization
# when there are 0 or 1 nodes, no sorting is necessary
# This special case happens a lot because the OpWiseCLinker produces
# 1-element graphs.
return list(self.nodes)
return list(self.apply_nodes)
fg = self
ords = self.orderings()
order = graph.io_toposort(fg.inputs, fg.outputs, ords)
......@@ -569,26 +572,31 @@ class FunctionGraph(utils.object2):
"""WRITEME Same as len(self.clients(r))."""
return len(self.clients(r))
# def edge(self, r):
# return r in self.inputs or r in self.orphans
def nodes_getter(self):
warnings.warn("FunctionGraph.nodes is deprecated, it has been renamed 'apply_nodes'",
stacklevel=2)
return self.apply_nodes
def nodes_setter(self, value):
warnings.warn("FunctionGraph.nodes is deprecated, it has been renamed 'apply_nodes'",
stacklevel=2)
self.apply_nodes = value
def nodes_deleter(self):
warnings.warn("FunctionGraph.nodes is deprecated, it has been renamed 'apply_nodes'",
stacklevel=2)
del self.apply_nodes
# def follow(self, r):
# node = r.owner
# if self.edge(r):
# return None
# else:
# if node is None:
# raise Exception("what the fuck")
# return node.inputs
nodes = property(nodes_getter, nodes_setter, nodes_deleter)
def check_integrity(self):
"""WRITEME
Call this for a diagnosis if things go awry.
"""
nodes = graph.ops(self.inputs, self.outputs)
if self.nodes != nodes:
missing = nodes.difference(self.nodes)
excess = self.nodes.difference(nodes)
if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes)
raise Exception("The nodes are inappropriately cached. missing, in excess: ", missing, excess)
for node in nodes:
if node.fgraph is not self:
......
......@@ -162,7 +162,7 @@ class SeqOptimizer(Optimizer, list):
l = []
if fgraph.profile:
validate_before = fgraph.profile.validate_time
nb_node_before = len(fgraph.nodes)
nb_node_before = len(fgraph.apply_nodes)
sub_profs = []
for optimizer in self:
try:
......@@ -184,7 +184,7 @@ class SeqOptimizer(Optimizer, list):
print "SeqOptimizer",
if hasattr(self,"name"): print self.name,
elif hasattr(self,"__name__"): print self.__name__,
print " time %.3fs for %d/%d nodes before/after optimization"%(sum(l),nb_node_before,len(fgraph.nodes))
print " time %.3fs for %d/%d nodes before/after optimization"%(sum(l),nb_node_before,len(fgraph.apply_nodes))
print " time %.3fs for validate " % (
fgraph.profile.validate_time - validate_before)
ll=[]
......@@ -208,7 +208,7 @@ class SeqOptimizer(Optimizer, list):
else:
validate_time = None
return (self, l, validate_time, nb_node_before,
len(fgraph.nodes), sub_profs)
len(fgraph.apply_nodes), sub_profs)
def __eq__(self, other):
#added to override the list's __eq__ implementation
......@@ -1503,7 +1503,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
max_use_abort = True
opt_name = (getattr(lopt, "name", None)
or getattr(lopt, "__name__", ""))
if node not in fgraph.nodes:
if node not in fgraph.apply_nodes:
# go to next node
break
finally:
......
......@@ -71,9 +71,9 @@ if 0:
def apply(self, fgraph):
tasks = defaultdict(list)
if self.max_use_ratio is not None:
max_uses = self.max_use_ratio * len(fgraph.nodes)
max_uses = self.max_use_ratio * len(fgraph.apply_nodes)
runs = defaultdict(int)
else:
runs = None
......@@ -91,10 +91,10 @@ if 0:
self.backtrack(new_r.owner, tasks)
# # == NOT IDEAL == #
# for node in fgraph.nodes:
# for node in fgraph.apply_nodes:
# importer(node)
for node in fgraph.toposort():
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
......@@ -124,7 +124,7 @@ if 0:
# if isinstance(in1, basestring):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
......
......@@ -219,7 +219,7 @@ class ReplaceValidate(History, Validator):
"""
chk = fgraph.replace_all_validate(replacements, reason)
for rm in remove:
if rm in fgraph.nodes or rm in fgraph.variables:
if rm in fgraph.apply_nodes or rm in fgraph.variables:
fgraph.revert(chk)
if warn:
out = sys.stderr
......
......@@ -1002,7 +1002,7 @@ def test_many_arg_elemwise():
#assert that the test was done on the gpu.
if mode is mode_with_gpu:
assert any([isinstance(node.op, cuda.GpuElemwise)
for node in f.maker.fgraph.nodes])
for node in f.maker.fgraph.apply_nodes])
#test the optijmization local_gpu_elemwise_1
f = theano.function(
......@@ -1013,7 +1013,7 @@ def test_many_arg_elemwise():
#assert that the test was done on the gpu.
if mode is mode_with_gpu:
assert any([isinstance(node.op, cuda.GpuElemwise)
for node in f.maker.fgraph.nodes])
for node in f.maker.fgraph.apply_nodes])
assert numpy.allclose(out, outputs[-1])
results_gpu, results_cpu = outputs
......
......@@ -2667,7 +2667,7 @@ class Composite(ScalarOp):
def init_fgraph(self):
fgraph = FunctionGraph(*gof.graph.clone(self.inputs, self.outputs))
gof.MergeOptimizer().optimize(fgraph)
for node in fgraph.nodes:
for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp):
raise ValueError("The fgraph to Composite must be exclusively"
" composed of ScalarOp instances.")
......
......@@ -1382,7 +1382,7 @@ class GemmOptimizer(Optimizer):
(theano.scalar.Add, theano.scalar.Sub,
theano.scalar.Neg, theano.scalar.Mul))):
continue
if not node in fgraph.nodes:
if not node in fgraph.apply_nodes:
# This mean that we already removed this node from
# the graph
continue
......
......@@ -176,7 +176,7 @@ def inplace_elemwise_optimizer_op(OP):
# We execute `validate` after this number of change.
check_each_change = config.tensor.insert_inplace_optimizer_validate_nb
if check_each_change == -1:
if len(fgraph.nodes) > 500:
if len(fgraph.apply_nodes) > 500:
check_each_change = 10
else:
check_each_change = 1
......@@ -4596,7 +4596,7 @@ class FusionOptimizer(Optimizer):
did_something = False
for node in nodelist:
# Don't try to fuse node that have already been fused.
if node in fgraph.nodes:
if node in fgraph.apply_nodes:
new_outputs = self.optimizer(node)
if new_outputs:
assert len(new_outputs) == len(node.outputs)
......
......@@ -478,7 +478,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
mode='FAST_RUN',
on_unused_input='ignore')
nb_gemm = 0
for node in f.maker.fgraph.nodes:
for node in f.maker.fgraph.apply_nodes:
if node.op == T.dot:
raise Failure('dot not changed to gemm_inplace in graph')
if node.op == _dot22:
......@@ -488,7 +488,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
assert nb_gemm == expected_nb_gemm, (nb_gemm, expected_nb_gemm)
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
allow_input_downcast=True, on_unused_input='ignore')
for node in g.maker.fgraph.nodes:
for node in g.maker.fgraph.apply_nodes:
if node.op == gemm_inplace:
raise Exception('gemm_inplace in original graph')
......@@ -561,14 +561,14 @@ def test_gemm_opt_double_gemm():
try:
f = inplace_func([Param(ii, mutable=True) for ii in i], o,
mode='FAST_RUN', on_unused_input='ignore')
for node in f.maker.fgraph.nodes:
for node in f.maker.fgraph.apply_nodes:
if node.op == T.dot:
raise Failure('dot in graph')
if node.op == _dot22:
raise Failure('_dot22 in graph')
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
on_unused_input='ignore')
#for node in g.maker.fgraph.nodes:
#for node in g.maker.fgraph.apply_nodes:
# if node.op == gemm_inplace: raise Failure('gemm_inplace in graph')
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
......@@ -760,11 +760,11 @@ def test_gemm_opt_vector_stuff():
u, v = T.vector(), T.vector()
f = inplace_func([a, u, v], a + T.dot(u, v), mode='FAST_RUN')
if gemm_inplace in [n.op for n in f.maker.fgraph.nodes]:
if gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]:
raise Failure('gemm_inplace in graph')
f = inplace_func([a, u, X, Y], a * u + T.dot(X, Y), mode='FAST_RUN')
if (gemm_inplace in [n.op for n in f.maker.fgraph.nodes]):
if (gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]):
raise Failure('gemm_inplace in graph')
......@@ -823,16 +823,16 @@ def test_inplace0():
f = inplace_func([Z, b, R, S],
[Z * (Z + b * T.dot(R, S).T)], mode='FAST_RUN')
if (gemm_inplace in [n.op for n in f.maker.fgraph.nodes]):
if (gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]):
print pp(f.maker.fgraph.outputs[0])
raise Failure('gemm_inplace in graph')
assert gemm_no_inplace in [n.op for n in f.maker.fgraph.nodes]
assert gemm_no_inplace in [n.op for n in f.maker.fgraph.apply_nodes]
# gemm_inplace should be inserted here, to work in-place on Z*c
f = inplace_func([X, Y, Z, a, b, R, S, c],
[Z * (c * Z + a * T.dot(X, Y) + b * T.dot(R, S).T)],
mode='FAST_RUN')
if (not gemm_inplace in [n.op for n in f.maker.fgraph.nodes]):
if (not gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]):
theano.printing.debugprint(f)
raise Failure('no gemm_inplace in graph')
......@@ -844,7 +844,7 @@ def test_inplace1():
[Z + Z + T.dot(X, Y)], mode='FAST_RUN')
#theano.printing.debugprint(f)
# it doesn't work inplace because we didn't mark Z as mutable input
assert [n.op for n in f.maker.fgraph.nodes] == [gemm_no_inplace]
assert [n.op for n in f.maker.fgraph.apply_nodes] == [gemm_no_inplace]
def test_dot22():
......
......@@ -590,7 +590,7 @@ def test_naacl_model(iters_per_unsup=3, iters_per_sup=3,
#print input_pretraining_gradients[4].owner.inputs[1].owner.inputs
#sys.exit()
#print "PROGRAM LEN %i HASH %i"% (len(m.pretraining_update.maker.fgraph.nodes), reduce(lambda a, b: hash(a) ^ hash(b),prog_str))
#print "PROGRAM LEN %i HASH %i"% (len(m.pretraining_update.maker.fgraph.apply_nodes), reduce(lambda a, b: hash(a) ^ hash(b),prog_str))
rng = N.random.RandomState(unittest_tools.fetch_seed(23904))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论