提交 e78e6cd7 authored 作者: nouiz's avatar nouiz

Merge pull request #921 from goodfeli/doc_topo

NEWS: documented, sped up (50%) and renamed _dfs_toposort (now _contains_cycles). Rename FunctionGraph.nodes -> apply_nodes. The old nodes attribute still works, but warns if used. Made a parent class called Node for the Apply and Variable classes, which are both FunctionGraph nodes
......@@ -62,7 +62,7 @@ purpose, you would set the ``view_map`` field as follows:
What this means is that the first output (position 0) is a view of the
first input (position 0). Even though the interface allows a list of
inputs that are a view of a given output, this feature is currently
inputs that are viewed by a given output, this feature is currently
unsupported. Here are more examples:
......
......@@ -675,7 +675,7 @@ def _optcheck_fgraph(input_specs, output_specs, accept_inplace=False):
features=[equivalence_tracker])
if not accept_inplace:
for node in fgraph.nodes:
for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None):
raise TypeError("Graph must not contain inplace operations",
node)
......
......@@ -131,7 +131,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
inputs, outputs = gof.graph.clone(orig_inputs, orig_outputs)
fgraph = gof.fg.FunctionGraph(inputs, outputs)
for node in fgraph.nodes:
for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None):
if not accept_inplace:
raise TypeError("Graph must not contain inplace operations", node, node.op)
......
......@@ -520,7 +520,7 @@ class ProfileMode(Mode):
print "Profile of Theano functions memory:"
print "(This check only the output of each apply node. It don't check the temporary memory used by the op in the apply node.)"
nb_skipped = 0
for fgraph,nodes_mem in fct_memory.iteritems():
for fgraph, nodes_mem in fct_memory.iteritems():
size_sum=sum([sum(val) for key,val in nodes_mem.iteritems()])
if size_sum < min_memory_size:
nb_skipped += 1
......
......@@ -711,7 +711,7 @@ if 0: # old code still to be ported from ProfileMode
var_mem[out]=v
print
print "Profile of Theano functions memory:"
for fgraph,nodes_mem in fct_memory.iteritems():
for fgraph, nodes_mem in fct_memory.iteritems():
print "Theano fct:", [fct for fct in fct_call.keys() if fct.maker.fgraph is fgraph][0].name
size_sum=sum([sum(val) for key,val in nodes_mem.iteritems()])
print " Max without gc, inplace and view (KB)",size_sum/1024
......
"""WRITEME"""
"""
Classes and functions for validating graphs that contain view
and inplace operations.
"""
import sys
if sys.version_info[:2] >= (2,5):
from collections import defaultdict
......@@ -13,127 +16,183 @@ from theano.gof.python25 import deque
from fg import InconsistencyError
class ProtocolError(Exception):
"""WRITEME"""
pass
class DestroyHandler(object):
"""WRITEME"""
def __init__(self, do_imports_on_attach=True):
self.map = {}
self.do_imports_on_attach=do_imports_on_attach
def on_attach(self, fgraph):
dh = self.map.setdefault(fgraph, DestroyHandlerHelper2(do_imports_on_attach=self.do_imports_on_attach))
dh.on_attach(fgraph)
def on_detach(self, fgraph):
self.map[fgraph].on_detach(fgraph)
def on_import(self, fgraph, op):
self.map[fgraph].on_import(fgraph, op)
def on_prune(self, fgraph, op):
self.map[fgraph].on_prune(fgraph, op)
def on_change_input(self, fgraph, node, i, r, new_r):
self.map[fgraph].on_change_input(fgraph, node, i, r, new_r)
def validate(self, fgraph):
self.map[fgraph].validate(fgraph)
def orderings(self, fgraph):
return self.map[fgraph].orderings(fgraph)
def _dfs_toposort(i, r_out, orderings):
"""Raised when FunctionGraph calls DestroyHandler callbacks in
an invalid way, for example, pruning or changing a node that has
never been imported.
"""
i - list of inputs
o - list of outputs
orderings - dict of additions to the normal inputs and outputs
pass
Returns nothing. Raises exception for graph with cycles
def _contains_cycle(fgraph, orderings):
"""
#this is hard-coded reimplementation of functions from graph.py
# reason: go faster, prepare for port to C.
assert isinstance(r_out, (tuple, list, deque))
fgraph - the FunctionGraph to check for cycles
# TODO: For more speed - use a defaultdict for the orderings
orderings - dictionary specifying extra dependencies besides
those encoded in Variable.owner / Apply.inputs
If orderings[my_apply] == dependencies,
iset = set(i)
then my_apply is an Apply instance,
dependencies is a set of Apply instances,
and every member of dependencies must be executed
before my_apply.
if 0:
def expand(obj):
rval = []
if obj not in iset:
if isinstance(obj, graph.Variable):
if obj.owner:
rval = [obj.owner]
if isinstance(obj, graph.Apply):
rval = list(obj.inputs)
rval.extend(orderings.get(obj, []))
else:
assert not orderings.get(obj, [])
return rval
The dependencies are typically used to prevent
inplace apply nodes from destroying their input before
other apply nodes with the same input access it.
expand_cache = {}
# reachable, clients = stack_search( deque(r_out), deps, 'dfs', True)
start=deque(r_out)
rval_set = set()
rval_set.add(id(None))
rval_list = list()
expand_inv = {}
sources = deque()
while start:
l = start.pop()# this makes the search dfs
if id(l) not in rval_set:
rval_list.append(l)
rval_set.add(id(l))
if l in iset:
assert not orderings.get(l, [])
expand_l = []
else:
try:
if l.owner:
expand_l = [l.owner]
else:
expand_l = []
except AttributeError:
expand_l = list(l.inputs)
expand_l.extend(orderings.get(l, []))
if expand_l:
for r in expand_l:
expand_inv.setdefault(r, []).append(l)
start.extend(expand_l)
else:
sources.append(l)
expand_cache[l] = expand_l
assert len(rval_list) == len(rval_set)-1
Returns True if the graph contains a cycle, False otherwise.
"""
rset = set()
rlist = []
while sources:
node = sources.popleft()
if node not in rset:
rlist.append(node)
rset.add(node)
for client in expand_inv.get(node, []):
expand_cache[client] = [a for a in expand_cache[client] if a is not node]
if not expand_cache[client]:
sources.append(client)
# These are lists of Variable instances
inputs = fgraph.inputs
outputs = fgraph.outputs
if len(rlist) != len(rval_list):
raise ValueError('graph contains cycles')
#return [o for o in rlist if isinstance(o, graph.Apply)]
# this is hard-coded reimplementation of functions from graph.py
# reason: go faster, prepare for port to C.
# specifically, it could be replaced with a wrapper
# around graph.io_toposort that returns True iff io_toposort raises
# a ValueError containing the substring 'cycle'.
# This implementation is optimized for the destroyhandler and runs
# slightly faster than io_toposort.
# this is performance-critical code. it is the largest single-function
# bottleneck when compiling large graphs.
assert isinstance(outputs, (tuple, list, deque))
# TODO: For more speed - use a defaultdict for the orderings
# (defaultdict runs faster than dict in the case where the key
# is not in the dictionary, at least in CPython)
iset = set(inputs)
# IG: I tried converting parent_counts to use an id for the key,
# so that the dict would do reference counting on its keys.
# This caused a slowdown.
# Separate benchmark tests showed that calling id is about
# half as expensive as a dictionary access, and that the
# dictionary also runs slower when storing ids than when
# storing objects.
# dict mapping an Apply or Variable instance to the number
# of its parents (including parents imposed by orderings)
# that haven't been visited yet
parent_counts = {}
# dict mapping an Apply or Variable instance to its children
node_to_children = {}
# visitable: A container holding all Variable and Apply instances
# that can currently be visited according to the graph topology
# (ie, whose parents have already been visited)
# TODO: visitable is a fifo_queue. could this run faster if we
# implement it as a stack rather than a deque?
# TODO: visitable need not be a fifo_queue, any kind of container
# that we can throw things into and take things out of quickly will
# work. is there another kind of container that could run faster?
# we don't care about the traversal order here as much as we do
# in io_toposort because we aren't trying to generate an ordering
# on the nodes
visitable = deque()
# IG: visitable could in principle be initialized to fgraph.inputs
# + fgraph.orphans... if there were an fgraph.orphans structure.
# I tried making one and maintaining it caused a huge slowdown.
# This may be because I made it a list, so it would have a
# deterministic iteration order, in hopes of using it to speed
# up toposort as well.
# I think since we need to scan through all variables and nodes
# to make parent_counts anyway, it's cheap enough to always
# detect orphans at cycle detection / toposort time
# Pass through all the nodes to build visitable, parent_count, and
# node_to_children
for var in fgraph.variables:
# this is faster than calling get_parents
owner = var.owner
if owner:
parents = [ owner ]
else:
parents = []
# variables don't appear in orderings, so we don't need to worry
# about that here
if parents:
for parent in parents:
# insert node in node_to_children[r]
# (if r is not already in node_to_children,
# initialize it to [])
node_to_children.setdefault(parent, []).append(var)
parent_counts[var] = len(parents)
else:
visitable.append(var)
parent_counts[var] = 0
for a_n in fgraph.apply_nodes:
parents = list(a_n.inputs)
# This is faster than conditionally extending
# IG: I tried using a shared empty_list = [] constructed
# outside of the for loop to avoid constructing multiple
# lists, but this was not any faster.
parents.extend(orderings.get(a_n,[]))
if parents:
for parent in parents:
# insert node in node_to_children[r]
# (if r is not already in node_to_children,
# initialize it to [])
node_to_children.setdefault(parent, []).append(a_n)
parent_counts[a_n] = len(parents)
else:
# an Apply with no inputs would be a weird case, but I'm
# not sure we forbid it
visitable.append(a_n)
parent_counts[a_n] = 0
# at this point,
# parent_counts.keys() == fgraph.apply_nodes + fgraph.variables
# Now we actually check for cycles
# As long as there are nodes that can be visited while respecting
# the topology, we keep visiting nodes
# If we run out of visitable nodes and we haven't visited all nodes,
# then there was a cycle. It blocked the traversal because some
# node couldn't be visited until one of its descendants had been
# visited too.
# This is a standard cycle detection algorithm.
visited = 0
while visitable:
# Since each node is inserted into the visitable queue exactly
# once, it comes out of the queue exactly once
# That means we can decrement its children's unvisited parent count
# and increment the visited node count without double-counting
node = visitable.popleft()
visited += 1
for client in node_to_children.get(node,[]):
parent_counts[client] -= 1
# If all of a node's parents have been visited,
# it may now be visited too
if not parent_counts[client]:
visitable.append(client)
return visited != len(parent_counts)
def getroot(r, view_i):
"""
TODO: what is view_i ? based on add_impact's docstring, IG is guessing
it might be a dictionary mapping variables to views, but what is
a view? In these old docstrings I'm not sure if "view" always
means "view variable" or if it also sometimes means "viewing
pattern."
For views: Return the non-view variable which is ultimately viewed by r.
For non-views: return self.
"""
......@@ -149,6 +208,12 @@ def add_impact(r, view_o, impact):
:param impact: is a set of variables that are views of r
:param droot: a dictionary mapping views -> r
TODO: this docstring is hideously wrong, the function doesn't return anything.
has droot been renamed to view_o?
does it add things to the impact argument instead of returning them?
IG thinks so, based on reading the code. It looks like get_impact
does what this docstring said this function does.
"""
for v in view_o.get(r,[]):
impact.add(v)
......@@ -176,51 +241,462 @@ def fast_inplace_check(inputs):
and i not in protected_inputs]
return inputs
class DestroyHandlerHelper2(toolbox.Bookkeeper):
if 0:
# old, non-incremental version of the DestroyHandler
class DestroyHandler(toolbox.Bookkeeper):
"""
The DestroyHandler class detects when a graph is impossible to evaluate because of
aliasing and destructive operations.
Several data structures are used to do this.
When an Op uses its view_map property to declare that an output may be aliased
to an input, then if that output is destroyed, the input is also considered to be
destroyed. The view_maps of several Ops can feed into one another and form a directed graph.
The consequence of destroying any variable in such a graph is that all variables in the graph
must be considered to be destroyed, because they could all be referring to the same
underlying storage. In the current implementation, that graph is a tree, and the root of
that tree is called the foundation. The `droot` property of this class maps from every
graph variable to its foundation. The `impact` property maps backward from the foundation
to all of the variables that depend on it. When any variable is destroyed, this class marks
the foundation of that variable as being destroyed, with the `root_destroyer` property.
"""
droot = {}
"""
destroyed view + nonview variables -> foundation
"""
impact = {}
"""
destroyed nonview variable -> it + all views of it
"""
root_destroyer = {}
"""
root -> destroyer apply
"""
def __init__(self, do_imports_on_attach=True):
self.fgraph = None
self.do_imports_on_attach = do_imports_on_attach
def on_attach(self, fgraph):
"""
When attaching to a new fgraph, check that
1) This DestroyHandler wasn't already attached to some fgraph
(its data structures are only set up to serve one)
2) The FunctionGraph doesn't already have a DestroyHandler.
This would result in it validating everything twice, causing
compilation to be slower.
TODO: WRITEME: what does this do besides the checks?
"""
####### Do the checking ###########
already_there = False
if self.fgraph is fgraph:
already_there = True
if self.fgraph not in [None, fgraph]:
raise Exception("A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)")
for attr in ('destroyers', 'destroy_handler'):
if hasattr(fgraph, attr):
already_there = True
if already_there:
# FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")
####### end of checking ############
def get_destroyers_of(r):
droot, impact, root_destroyer = self.refresh_droot_impact()
try:
return [root_destroyer[droot[r]]]
except Exception:
return []
fgraph.destroyers = get_destroyers_of
fgraph.destroy_handler = self
self.fgraph = fgraph
self.destroyers = set() #set of Apply instances with non-null destroy_map
self.view_i = {} # variable -> variable used in calculation
self.view_o = {} # variable -> set of variables that use this one as a direct input
#clients: how many times does an apply use a given variable
self.clients = {} # variable -> apply -> ninputs
self.stale_droot = True
# IG: It's unclear if this is meant to be included in deployed code. It looks like
# it is unnecessary if FunctionGraph is working correctly, so I am commenting uses
# of it (for speed) but leaving the commented code in place so it is easy to restore
# for debugging purposes.
# Note: is there anything like the C preprocessor for python? It would be useful to
# just ifdef these things out
# self.debug_all_apps = set()
if self.do_imports_on_attach:
toolbox.Bookkeeper.on_attach(self, fgraph)
def refresh_droot_impact(self):
if self.stale_droot:
self.droot, self.impact, self.root_destroyer = self._build_droot_impact()
self.stale_droot = False
return self.droot, self.impact, self.root_destroyer
def _build_droot_impact(self):
droot = {} # destroyed view + nonview variables -> foundation
impact = {} # destroyed nonview variable -> it + all views of it
root_destroyer = {} # root -> destroyer apply
for app in self.destroyers:
for output_idx, input_idx_list in app.op.destroy_map.items():
if len(input_idx_list) != 1:
raise NotImplementedError()
input_idx = input_idx_list[0]
input = app.inputs[input_idx]
input_root = getroot(input, self.view_i)
if input_root in droot:
raise InconsistencyError("Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
#input_impact = set([input_root])
#add_impact(input_root, self.view_o, input_impact)
input_impact = get_impact(input_root, self.view_o)
for v in input_impact:
assert v not in droot
droot[v] = input_root
impact[input_root] = input_impact
impact[input_root].add(input_root)
return droot, impact, root_destroyer
def on_detach(self, fgraph):
if fgraph is not self.fgraph:
raise Exception("detaching wrong fgraph", fgraph)
del self.destroyers
del self.view_i
del self.view_o
del self.clients
del self.stale_droot
assert self.fgraph.destroyer_handler is self
delattr(self.fgraph, 'destroyers')
delattr(self.fgraph, 'destroy_handler')
self.fgraph = None
def on_import(self, fgraph, app):
"""Add Apply instance to set which must be computed"""
#if app in self.debug_all_apps: raise ProtocolError("double import")
#self.debug_all_apps.add(app)
#print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', {}):
self.destroyers.add(app)
# add this symbol to the forward and backward maps
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
if len(i_idx_list) > 1:
raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op))
o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]]
self.view_i[o] = i
self.view_o.setdefault(i,set()).add(o)
# update self.clients
for i, input in enumerate(app.inputs):
self.clients.setdefault(input, {}).setdefault(app,0)
self.clients[input][app] += 1
for i, output in enumerate(app.outputs):
self.clients.setdefault(output, {})
self.stale_droot = True
def on_prune(self, fgraph, app):
"""Remove Apply instance from set which must be computed"""
#if app not in self.debug_all_apps: raise ProtocolError("prune without import")
#self.debug_all_apps.remove(app)
#UPDATE self.clients
for i, input in enumerate(set(app.inputs)):
del self.clients[input][app]
if getattr(app.op, 'destroy_map', {}):
self.destroyers.remove(app)
# Note: leaving empty client dictionaries in the struct.
# Why? It's a pain to remove them. I think they aren't doing any harm, they will be
# deleted on_detach().
#UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs
raise NotImplementedError()
o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]]
del self.view_i[o]
self.view_o[i].remove(o)
if not self.view_o[i]:
del self.view_o[i]
self.stale_droot = True
def on_change_input(self, fgraph, app, i, old_r, new_r):
"""app.inputs[i] changed from old_r to new_r """
if app == 'output':
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# considered 'outputs' of the graph.
pass
else:
#if app not in self.debug_all_apps: raise ProtocolError("change without import")
#UPDATE self.clients
self.clients[old_r][app] -= 1
if self.clients[old_r][app] == 0:
del self.clients[old_r][app]
self.clients.setdefault(new_r,{}).setdefault(app,0)
self.clients[new_r][app] += 1
#UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs
raise NotImplementedError()
i_idx = i_idx_list[0]
output = app.outputs[o_idx]
if i_idx == i:
if app.inputs[i_idx] is not new_r:
raise ProtocolError("wrong new_r on change")
self.view_i[output] = new_r
self.view_o[old_r].remove(output)
if not self.view_o[old_r]:
del self.view_o[old_r]
self.view_o.setdefault(new_r,set()).add(output)
self.stale_droot = True
def validate(self, fgraph):
"""Return None
Raise InconsistencyError when
a) orderings() raises an error
b) orderings cannot be topologically sorted.
"""
if self.destroyers:
ords = self.orderings(fgraph)
if _contains_cycle(fgraph, ords):
raise InconsistencyError("Dependency graph contains cycles")
else:
#James's Conjecture:
#If there are no destructive ops, then there can be no cycles.
pass
return True
def orderings(self, fgraph):
"""Return orderings induced by destructive operations.
Raise InconsistencyError when
a) attempting to destroy an indestructible variable, or
b) attempting to destroy a value multiple times, or
c) an Apply destroys (illegally) one of its own inputs by aliasing
"""
rval = {}
if self.destroyers:
# BUILD DATA STRUCTURES
# CHECK for multiple destructions during construction of variables
droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants
illegal_destroy = [r for r in droot if \
getattr(r.tag,'indestructible', False) or \
isinstance(r, graph.Constant)]
if illegal_destroy:
#print 'destroying illegally'
raise InconsistencyError("Attempting to destroy indestructible variables: %s" %
illegal_destroy)
# add destroyed variable clients as computational dependencies
for app in self.destroyers:
# for each destroyed input...
for output_idx, input_idx_list in app.op.destroy_map.items():
destroyed_idx = input_idx_list[0]
destroyed_variable = app.inputs[destroyed_idx]
root = droot[destroyed_variable]
root_impact = impact[root]
# we generally want to put all clients of things which depend on root
# as pre-requisites of app.
# But, app is itself one such client!
# App will always be a client of the node we're destroying
# (destroyed_variable, but the tricky thing is when it is also a client of
# *another variable* viewing on the root. Generally this is illegal, (e.g.,
# add_inplace(x, x.T). In some special cases though, the in-place op will
# actually be able to work properly with multiple destroyed inputs (e.g,
# add_inplace(x, x). An Op that can still work in this case should declare
# so via the 'destroyhandler_tolerate_same' attribute or
# 'destroyhandler_tolerate_aliased' attribute.
#
# destroyhandler_tolerate_same should be a list of pairs of the form
# [(idx0, idx1), (idx0, idx2), ...]
# The first element of each pair is the input index of a destroyed
# variable.
# The second element of each pair is the index of a different input where
# we will permit exactly the same variable to appear.
# For example, add_inplace.tolerate_same might be [(0,1)] if the destroyed
# input is also allowed to appear as the second argument.
#
# destroyhandler_tolerate_aliased is the same sort of list of
# pairs.
# op.destroyhandler_tolerate_aliased = [(idx0, idx1)] tells the
# destroyhandler to IGNORE an aliasing between a destroyed
# input idx0 and another input idx1.
# This is generally a bad idea, but it is safe in some
# cases, such as
# - the op reads from the aliased idx1 before modifying idx0
# - the idx0 and idx1 are guaranteed not to overlap (e.g.
# they are pointed at different rows of a matrix).
#
#CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
tolerated = set(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx)
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
ignored = set(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx)
#print 'tolerated', tolerated
#print 'ignored', ignored
for i, input in enumerate(app.inputs):
if i in ignored:
continue
if input in root_impact \
and (i not in tolerated or input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)"
% (app, destroyed_idx, i))
# add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input
root_clients = set()
for r in root_impact:
assert not [a for a,c in self.clients[r].items() if not c]
root_clients.update([a for a,c in self.clients[r].items() if c])
root_clients.remove(app)
if root_clients:
rval[app] = root_clients
return rval
class DestroyHandler(toolbox.Bookkeeper):
"""
The DestroyHandlerHelper2 class detects when a graph is impossible to evaluate because of
aliasing and destructive operations.
The DestroyHandler class detects when a graph is impossible to evaluate
because of aliasing and destructive operations.
Several data structures are used to do this.
When an Op uses its view_map property to declare that an output may be aliased
to an input, then if that output is destroyed, the input is also considered to be
destroyed. The view_maps of several Ops can feed into one another and form a directed graph.
The consequence of destroying any variable in such a graph is that all variables in the graph
must be considered to be destroyed, because they could all be referring to the same
underlying storage. In the current implementation, that graph is a tree, and the root of
that tree is called the foundation. The `droot` property of this class maps from every
graph variable to its foundation. The `impact` property maps backward from the foundation
to all of the variables that depend on it. When any variable is destroyed, this class marks
the foundation of that variable as being destroyed, with the `root_destroyer` property.
An Op can use its view_map property to declare that an output may be
aliased to an input. If that output is destroyed, the input is also
considered to be destroyed. The view_maps of several Ops can feed into
one another and form a directed graph. The consequence of destroying any
variable in such a graph is that all variables in the graph must be
considered to be destroyed, because they could all be referring to the
same underlying storage.
In the current implementation, that graph is a tree, and the root of that
tree is called the foundation.
TODO: why "in the current implementation" ? is there another implementation
planned?
TODO: why is the graph a tree? isn't it possible that one variable could
be aliased to many variables? for example, don't switch and ifelse
have to do this?
The original DestroyHandler (if 0'ed out above) computed several data
structures from scratch each time it was asked to validate the graph.
Because this happens potentially thousands of times and each graph to
validate is extremely similar to the previous one, computing the
data structures from scratch repeatedly was wasteful and resulted in
high compile times for large graphs.
This implementation computes the data structures once at initialization
and then incrementally updates them.
It is a work in progress. The following data structures have been
converted to use the incremental strategy:
<none>
The following data structures remain to be converted:
<unknown>
"""
"""maps every variable in the graph to its "foundation" (deepest
ancestor in view chain)
TODO: change name to var_to_vroot"""
droot = {}
"""
destroyed view + nonview variables -> foundation
"""
"""maps a variable to all variables that are indirect or direct views of it
(including itself)
essentially the inverse of droot
TODO: do all variables appear in this dict, or only those that are foundations?
TODO: do only destroyed variables go in here? one old docstring said so
TODO: rename to x_to_views after reverse engineering what x is"""
impact = {}
"""
destroyed nonview variable -> it + all views of it
"""
"""if a var is destroyed, then this dict will map
droot[var] to the apply node that destroyed var
TODO: rename to vroot_to_destroyer"""
root_destroyer = {}
"""
root -> destroyer apply
"""
def __init__(self, do_imports_on_attach=True):
self.fgraph = None
self.do_imports_on_attach = do_imports_on_attach
def on_attach(self, fgraph):
#boilerplate from old implementation
"""
When attaching to a new fgraph, check that
1) This DestroyHandler wasn't already attached to some fgraph
(its data structures are only set up to serve one)
2) The FunctionGraph doesn't already have a DestroyHandler.
This would result in it validating everything twice, causing
compilation to be slower.
Give the FunctionGraph instance:
1) A new method "destroyers(var)"
TODO: what does this do exactly?
2) A new attribute, "destroy_handler"
TODO: WRITEME: what does this do besides the checks?
"""
####### Do the checking ###########
already_there = False
if self.fgraph is fgraph:
already_there = True
if self.fgraph is not None:
raise Exception("A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)")
for attr in ('destroyers', 'destroy_handler'):
if hasattr(fgraph, attr):
raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")
already_there = True
if already_there:
# FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")
####### Annotate the FunctionGraph ############
def get_destroyers_of(r):
droot, impact, root_destroyer = self.refresh_droot_impact()
......@@ -245,39 +721,38 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
toolbox.Bookkeeper.on_attach(self, fgraph)
def refresh_droot_impact(self):
"""
Makes sure self.droot, self.impact, and self.root_destroyer are
up to date, and returns them.
(see docstrings for these properties above)
"""
if self.stale_droot:
self.droot, self.impact, self.root_destroyer = self._build_droot_impact()
droot = {} # destroyed view + nonview variables -> foundation
impact = {} # destroyed nonview variable -> it + all views of it
root_destroyer = {} # root -> destroyer apply
for app in self.destroyers:
for output_idx, input_idx_list in app.op.destroy_map.items():
if len(input_idx_list) != 1:
raise NotImplementedError()
input_idx = input_idx_list[0]
input = app.inputs[input_idx]
input_root = getroot(input, self.view_i)
if input_root in droot:
raise InconsistencyError("Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
input_impact = get_impact(input_root, self.view_o)
for v in input_impact:
assert v not in droot
droot[v] = input_root
impact[input_root] = input_impact
impact[input_root].add(input_root)
self.droot, self.impact, self.root_destroyer = droot, impact, root_destroyer
self.stale_droot = False
return self.droot, self.impact, self.root_destroyer
def _build_droot_impact(self):
droot = {} # destroyed view + nonview variables -> foundation
impact = {} # destroyed nonview variable -> it + all views of it
root_destroyer = {} # root -> destroyer apply
for app in self.destroyers:
for output_idx, input_idx_list in app.op.destroy_map.items():
if len(input_idx_list) != 1:
raise NotImplementedError()
input_idx = input_idx_list[0]
input = app.inputs[input_idx]
input_root = getroot(input, self.view_i)
if input_root in droot:
raise InconsistencyError("Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
#input_impact = set([input_root])
#add_impact(input_root, self.view_o, input_impact)
input_impact = get_impact(input_root, self.view_o)
for v in input_impact:
assert v not in droot
droot[v] = input_root
impact[input_root] = input_impact
impact[input_root].add(input_root)
return droot, impact, root_destroyer
def on_detach(self, fgraph):
if fgraph is not self.fgraph:
raise Exception("detaching wrong fgraph", fgraph)
......@@ -399,24 +874,12 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
b) orderings cannot be topologically sorted.
"""
#print '\nVALIDATE'
if self.destroyers:
try:
ords = self.orderings(fgraph)
except Exception, e:
#print 'orderings failed with:', type(e), e.args
raise
#print 'orderings:', ords
try:
### graph.io_toposort(fgraph.inputs, fgraph.outputs, ords)
_dfs_toposort(fgraph.inputs, fgraph.outputs, ords)
except ValueError, e:
#print 'not passing.', ords
if 'cycles' in str(e):
raise InconsistencyError("Dependency graph contains cycles")
else:
raise
#print 'passing...', ords
ords = self.orderings(fgraph)
if _contains_cycle(fgraph, ords):
raise InconsistencyError("Dependency graph contains cycles")
else:
#James's Conjecture:
#If there are no destructive ops, then there can be no cycles.
......@@ -439,17 +902,12 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
# CHECK for multiple destructions during construction of variables
droot, impact, __ignore = self.refresh_droot_impact()
#print "droot", droot
#print "impact", impact
#print "view_i", self.view_i
#print "view_o", self.view_o
# check for destruction of constants
illegal_destroy = [r for r in droot if \
getattr(r.tag,'indestructible', False) or \
isinstance(r, graph.Constant)]
if illegal_destroy:
#print 'destroying illegally'
raise InconsistencyError("Attempting to destroy indestructible variables: %s" %
illegal_destroy)
......
......@@ -79,10 +79,10 @@ class FunctionGraph(utils.object2):
# so I probably am) this should be a set.
self._features = []
# All nodes in the subgraph defined by inputs and outputs are cached in nodes
self.nodes = set()
# All apply nodes in the subgraph defined by inputs and outputs are cached in this field
self.apply_nodes = set()
# Ditto for variables
# Ditto for variable nodes
self.variables = set()
self.inputs = list(inputs)
......@@ -151,13 +151,13 @@ class FunctionGraph(utils.object2):
nodes and variables. If there are no features, this should set
them back to what they were originally.
"""
for node in self.nodes:
del node.fgraph
del node.deps
for apply_node in self.apply_nodes:
del apply_node.fgraph
del apply_node.deps
for variable in self.variables:
del variable.fgraph
del variable.clients
self.nodes = set()
self.apply_nodes = set()
self.variables = set()
self.inputs = None
self.outputs = None
......@@ -215,11 +215,11 @@ class FunctionGraph(utils.object2):
if NullType is None:
from null_type import NullType
# Imports the owners of the variables
r_owner_done = set(self.nodes)
for node in [r.owner for r in variables if r.owner is not None]:
if node not in r_owner_done:
r_owner_done.add(node)
self.__import__(node)
r_owner_done = set(self.apply_nodes)
for apply_node in [r.owner for r in variables if r.owner is not None]:
if apply_node not in r_owner_done:
r_owner_done.add(apply_node)
self.__import__(apply_node)
for r in variables:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
if isinstance(r.type,NullType):
......@@ -229,7 +229,9 @@ class FunctionGraph(utils.object2):
self.__setup_r__(r)
self.variables.add(r)
def __import__(self, node, check = True):
def __import__(self, apply_node, check = True):
node = apply_node
# We import the nodes in topological order. We only are interested
# in new nodes, so we use all variables we know of as if they were the input set.
# (the functions in the graph module only use the input set to
......@@ -311,9 +313,9 @@ class FunctionGraph(utils.object2):
r)
for node in new_nodes:
assert node not in self.nodes
assert node not in self.apply_nodes
self.__setup_node__(node)
self.nodes.add(node)
self.apply_nodes.add(node)
for output in node.outputs:
self.__setup_r__(output)
self.variables.add(output)
......@@ -336,8 +338,9 @@ class FunctionGraph(utils.object2):
if not r.clients and r in self.variables:
self.variables.remove(r)
def __prune__(self, node):
if node not in self.nodes:
def __prune__(self, apply_node):
node = apply_node
if node not in self.apply_nodes:
raise Exception("%s does not belong to this FunctionGraph and cannot be pruned." % node)
assert node.fgraph is self
# If node's outputs have no clients, removes it from the graph
......@@ -348,7 +351,7 @@ class FunctionGraph(utils.object2):
# Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output):
return
self.nodes.remove(node)
self.apply_nodes.remove(node)
self.variables.difference_update(node.outputs)
self.execute_callbacks('on_prune', node)
......@@ -446,21 +449,29 @@ class FunctionGraph(utils.object2):
Adds a gof.toolbox.Feature to this function_graph
and triggers its on_attach callback
"""
# Filter out literally identical features
if feature in self._features:
return # the feature is already present
#it would be nice if we could require a specific class instead of
#a "workalike" so we could do actual error checking
#if not isinstance(feature, toolbox.Feature):
# raise TypeError("Expected gof.toolbox.Feature instance, got "+\
# str(type(feature)))
# Filter out functionally identical features.
# Features may use their on_attach method to raise
# toolbox.AlreadyThere if they detect that some
# installed feature does the same thing already
attach = getattr(feature, 'on_attach', None)
if attach is not None:
try:
attach(self)
except toolbox.AlreadyThere:
return
#it would be nice if we could require a specific class instead of
#a "workalike" so we could do actual error checking
#if not isinstance(feature, toolbox.Feature):
# raise TypeError("Expected gof.toolbox.Feature instance, got "+\
# str(type(feature)))
# Add the feature
self._features.append(feature)
def remove_feature(self, feature):
......@@ -490,6 +501,9 @@ class FunctionGraph(utils.object2):
try:
fn = getattr(feature, name)
except AttributeError:
# this is safe because there is no work done inside the
# try; the AttributeError reall must come from feature.${name}
# not existing
continue
#####HORRIBLE OPTIONAL ARGUMENT HACK
......@@ -532,12 +546,12 @@ class FunctionGraph(utils.object2):
{node: predecessors} where predecessors is a list of nodes
that should be computed before the key node.
"""
if len(self.nodes) < 2:
if len(self.apply_nodes) < 2:
# optimization
# when there are 0 or 1 nodes, no sorting is necessary
# This special case happens a lot because the OpWiseCLinker produces
# 1-element graphs.
return list(self.nodes)
return list(self.apply_nodes)
fg = self
ords = self.orderings()
order = graph.io_toposort(fg.inputs, fg.outputs, ords)
......@@ -569,26 +583,31 @@ class FunctionGraph(utils.object2):
"""WRITEME Same as len(self.clients(r))."""
return len(self.clients(r))
# def edge(self, r):
# return r in self.inputs or r in self.orphans
def nodes_getter(self):
    """Deprecated read accessor for the old ``nodes`` attribute.

    Warns that ``nodes`` was renamed and forwards to ``apply_nodes``.
    """
    warnings.warn("FunctionGraph.nodes is deprecated, it has been renamed "
                  "'apply_nodes'",
                  stacklevel=2)
    return self.apply_nodes
def nodes_setter(self, value):
    """Deprecated write accessor for the old ``nodes`` attribute.

    Warns that ``nodes`` was renamed and assigns to ``apply_nodes``.
    """
    warnings.warn("FunctionGraph.nodes is deprecated, it has been renamed "
                  "'apply_nodes'",
                  stacklevel=2)
    self.apply_nodes = value
def nodes_deleter(self):
    """Deprecated deleter for the old ``nodes`` attribute.

    Warns that ``nodes`` was renamed and deletes ``apply_nodes``.
    """
    warnings.warn("FunctionGraph.nodes is deprecated, it has been renamed "
                  "'apply_nodes'",
                  stacklevel=2)
    del self.apply_nodes
# def follow(self, r):
# node = r.owner
# if self.edge(r):
# return None
# else:
# if node is None:
# raise Exception("what the fuck")
# return node.inputs
nodes = property(nodes_getter, nodes_setter, nodes_deleter)
def check_integrity(self):
"""WRITEME
Call this for a diagnosis if things go awry.
"""
nodes = graph.ops(self.inputs, self.outputs)
if self.nodes != nodes:
missing = nodes.difference(self.nodes)
excess = self.nodes.difference(nodes)
if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes)
raise Exception("The nodes are inappropriately cached. missing, in excess: ", missing, excess)
for node in nodes:
if node.fgraph is not self:
......
......@@ -21,7 +21,25 @@ is_same_graph_with_merge = None
equal_computations = None
class Apply(utils.object2):
class Node(utils.object2):
    """Base class for the vertices of a Theano expression graph.

    A graph contains two kinds of Node: Variable and Apply.  Edges are
    not represented explicitly; instead each Node records its parents
    (Variable.owner / Apply.inputs) and its children
    (Variable.clients / Apply.outputs).
    """

    def get_parents(self):
        """Return a new list holding this node's parents.

        The result must be a copy: mutating the returned list must not
        change the graph structure.  Subclasses override this.
        """
        raise NotImplementedError()
class Apply(Node):
"""
An :term:`Apply` instance is a node in an expression graph which represents the application
of an `Op` to some input `Variable` nodes, producing some output `Variable` nodes.
......@@ -202,6 +220,9 @@ class Apply(utils.object2):
new_node.inputs = new_inputs
return new_node
def get_parents(self):
    """Return a copy of this Apply node's input Variables.

    An Apply node's graph parents are exactly its inputs; a fresh list
    is returned so callers cannot mutate the graph through it.
    """
    parents = list(self.inputs)
    return parents
#convenience properties
nin = property(lambda self: len(self.inputs), doc='same as len(self.inputs)')
"""property: Number of inputs"""
......@@ -210,7 +231,7 @@ class Apply(utils.object2):
"""property: Number of outputs"""
class Variable(utils.object2):
class Variable(Node):
"""
A :term:`Variable` is a node in an expression graph that represents a variable.
......@@ -364,6 +385,11 @@ class Variable(utils.object2):
raise NotImplementedError('Subclasses of Variable must provide __ge__',
self.__class__.__name__)
def get_parents(self):
    """Return this Variable's parents: its owner Apply node, if any.

    A Variable without an owner (such as a graph input) has no
    parents, so an empty list is returned.
    """
    owner = self.owner
    return [owner] if owner is not None else []
def env_getter(self):
warnings.warn("Variable.env is deprecated, it has been renamed 'fgraph'",
stacklevel=2)
......@@ -726,13 +752,26 @@ def general_toposort(r_out, deps, debug_print=False):
return rlist
def io_toposort(i, o, orderings=None):
def io_toposort(inputs, outputs, orderings=None):
"""WRITEME
inputs: a list or tuple of Variable instances
outputs: a list or tuple of Variable instances
orderings: a dictionary
key: Apply instance
value: list of Apply instance
it is important that the value be
a container with a deterministic iteration
order. no sets allowed!
"""
if orderings is None:
orderings = {}
#the inputs are used only here in the function that decides what 'predecessors' to explore
iset = set(i)
iset = set(inputs)
def deps(obj):
rval = []
......@@ -747,7 +786,7 @@ def io_toposort(i, o, orderings=None):
assert not orderings.get(obj, [])
return rval
topo = general_toposort(o, deps)
topo = general_toposort(outputs, deps)
return [o for o in topo if isinstance(o, Apply)]
......
......@@ -162,7 +162,7 @@ class SeqOptimizer(Optimizer, list):
l = []
if fgraph.profile:
validate_before = fgraph.profile.validate_time
nb_node_before = len(fgraph.nodes)
nb_node_before = len(fgraph.apply_nodes)
sub_profs = []
for optimizer in self:
try:
......@@ -184,7 +184,7 @@ class SeqOptimizer(Optimizer, list):
print "SeqOptimizer",
if hasattr(self,"name"): print self.name,
elif hasattr(self,"__name__"): print self.__name__,
print " time %.3fs for %d/%d nodes before/after optimization"%(sum(l),nb_node_before,len(fgraph.nodes))
print " time %.3fs for %d/%d nodes before/after optimization"%(sum(l),nb_node_before,len(fgraph.apply_nodes))
print " time %.3fs for validate " % (
fgraph.profile.validate_time - validate_before)
ll=[]
......@@ -208,7 +208,7 @@ class SeqOptimizer(Optimizer, list):
else:
validate_time = None
return (self, l, validate_time, nb_node_before,
len(fgraph.nodes), sub_profs)
len(fgraph.apply_nodes), sub_profs)
def __eq__(self, other):
#added to override the list's __eq__ implementation
......@@ -1503,7 +1503,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
max_use_abort = True
opt_name = (getattr(lopt, "name", None)
or getattr(lopt, "__name__", ""))
if node not in fgraph.nodes:
if node not in fgraph.apply_nodes:
# go to next node
break
finally:
......
......@@ -71,9 +71,9 @@ if 0:
def apply(self, fgraph):
tasks = defaultdict(list)
if self.max_use_ratio is not None:
max_uses = self.max_use_ratio * len(fgraph.nodes)
max_uses = self.max_use_ratio * len(fgraph.apply_nodes)
runs = defaultdict(int)
else:
runs = None
......@@ -91,10 +91,10 @@ if 0:
self.backtrack(new_r.owner, tasks)
# # == NOT IDEAL == #
# for node in fgraph.nodes:
# for node in fgraph.apply_nodes:
# importer(node)
for node in fgraph.toposort():
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
......@@ -124,7 +124,7 @@ if 0:
# if isinstance(in1, basestring):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
......
......@@ -7,6 +7,9 @@ import graph
class AlreadyThere(Exception):
    """Signals that attaching a Feature would be redundant.

    A Feature's on_attach callback raises this when the FunctionGraph
    it is being attached to already carries a functionally identical
    feature; FunctionGraph.attach_feature then cancels the attach.
    """
......@@ -32,13 +35,18 @@ class Feature(object):
def on_attach(self, function_graph):
    """
    Called by FunctionGraph.attach_feature, the method that attaches
    the feature to the FunctionGraph. Since this is called after the
    FunctionGraph is initially populated, this is where you should
    run checks on the initial contents of the FunctionGraph.

    The on_attach method may raise the AlreadyThere exception to cancel
    the attach operation if it detects that another Feature instance
    implementing the same functionality is already attached to the
    FunctionGraph.

    The feature has great freedom in what it can do with the
    function_graph: it may, for example, add methods to it dynamically.
    """
def on_detach(self, function_graph):
......@@ -219,7 +227,7 @@ class ReplaceValidate(History, Validator):
"""
chk = fgraph.replace_all_validate(replacements, reason)
for rm in remove:
if rm in fgraph.nodes or rm in fgraph.variables:
if rm in fgraph.apply_nodes or rm in fgraph.variables:
fgraph.revert(chk)
if warn:
out = sys.stderr
......
......@@ -1002,7 +1002,7 @@ def test_many_arg_elemwise():
#assert that the test was done on the gpu.
if mode is mode_with_gpu:
assert any([isinstance(node.op, cuda.GpuElemwise)
for node in f.maker.fgraph.nodes])
for node in f.maker.fgraph.apply_nodes])
#test the optijmization local_gpu_elemwise_1
f = theano.function(
......@@ -1013,7 +1013,7 @@ def test_many_arg_elemwise():
#assert that the test was done on the gpu.
if mode is mode_with_gpu:
assert any([isinstance(node.op, cuda.GpuElemwise)
for node in f.maker.fgraph.nodes])
for node in f.maker.fgraph.apply_nodes])
assert numpy.allclose(out, outputs[-1])
results_gpu, results_cpu = outputs
......
......@@ -2667,7 +2667,7 @@ class Composite(ScalarOp):
def init_fgraph(self):
fgraph = FunctionGraph(*gof.graph.clone(self.inputs, self.outputs))
gof.MergeOptimizer().optimize(fgraph)
for node in fgraph.nodes:
for node in fgraph.apply_nodes:
if not isinstance(node.op, ScalarOp):
raise ValueError("The fgraph to Composite must be exclusively"
" composed of ScalarOp instances.")
......
......@@ -1382,7 +1382,7 @@ class GemmOptimizer(Optimizer):
(theano.scalar.Add, theano.scalar.Sub,
theano.scalar.Neg, theano.scalar.Mul))):
continue
if not node in fgraph.nodes:
if not node in fgraph.apply_nodes:
# This mean that we already removed this node from
# the graph
continue
......
......@@ -176,7 +176,7 @@ def inplace_elemwise_optimizer_op(OP):
# We execute `validate` after this number of change.
check_each_change = config.tensor.insert_inplace_optimizer_validate_nb
if check_each_change == -1:
if len(fgraph.nodes) > 500:
if len(fgraph.apply_nodes) > 500:
check_each_change = 10
else:
check_each_change = 1
......@@ -4596,7 +4596,7 @@ class FusionOptimizer(Optimizer):
did_something = False
for node in nodelist:
# Don't try to fuse node that have already been fused.
if node in fgraph.nodes:
if node in fgraph.apply_nodes:
new_outputs = self.optimizer(node)
if new_outputs:
assert len(new_outputs) == len(node.outputs)
......
......@@ -478,7 +478,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
mode='FAST_RUN',
on_unused_input='ignore')
nb_gemm = 0
for node in f.maker.fgraph.nodes:
for node in f.maker.fgraph.apply_nodes:
if node.op == T.dot:
raise Failure('dot not changed to gemm_inplace in graph')
if node.op == _dot22:
......@@ -488,7 +488,7 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
assert nb_gemm == expected_nb_gemm, (nb_gemm, expected_nb_gemm)
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
allow_input_downcast=True, on_unused_input='ignore')
for node in g.maker.fgraph.nodes:
for node in g.maker.fgraph.apply_nodes:
if node.op == gemm_inplace:
raise Exception('gemm_inplace in original graph')
......@@ -561,14 +561,14 @@ def test_gemm_opt_double_gemm():
try:
f = inplace_func([Param(ii, mutable=True) for ii in i], o,
mode='FAST_RUN', on_unused_input='ignore')
for node in f.maker.fgraph.nodes:
for node in f.maker.fgraph.apply_nodes:
if node.op == T.dot:
raise Failure('dot in graph')
if node.op == _dot22:
raise Failure('_dot22 in graph')
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
on_unused_input='ignore')
#for node in g.maker.fgraph.nodes:
#for node in g.maker.fgraph.apply_nodes:
# if node.op == gemm_inplace: raise Failure('gemm_inplace in graph')
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
......@@ -760,11 +760,11 @@ def test_gemm_opt_vector_stuff():
u, v = T.vector(), T.vector()
f = inplace_func([a, u, v], a + T.dot(u, v), mode='FAST_RUN')
if gemm_inplace in [n.op for n in f.maker.fgraph.nodes]:
if gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]:
raise Failure('gemm_inplace in graph')
f = inplace_func([a, u, X, Y], a * u + T.dot(X, Y), mode='FAST_RUN')
if (gemm_inplace in [n.op for n in f.maker.fgraph.nodes]):
if (gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]):
raise Failure('gemm_inplace in graph')
......@@ -823,16 +823,16 @@ def test_inplace0():
f = inplace_func([Z, b, R, S],
[Z * (Z + b * T.dot(R, S).T)], mode='FAST_RUN')
if (gemm_inplace in [n.op for n in f.maker.fgraph.nodes]):
if (gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]):
print pp(f.maker.fgraph.outputs[0])
raise Failure('gemm_inplace in graph')
assert gemm_no_inplace in [n.op for n in f.maker.fgraph.nodes]
assert gemm_no_inplace in [n.op for n in f.maker.fgraph.apply_nodes]
# gemm_inplace should be inserted here, to work in-place on Z*c
f = inplace_func([X, Y, Z, a, b, R, S, c],
[Z * (c * Z + a * T.dot(X, Y) + b * T.dot(R, S).T)],
mode='FAST_RUN')
if (not gemm_inplace in [n.op for n in f.maker.fgraph.nodes]):
if (not gemm_inplace in [n.op for n in f.maker.fgraph.apply_nodes]):
theano.printing.debugprint(f)
raise Failure('no gemm_inplace in graph')
......@@ -844,7 +844,7 @@ def test_inplace1():
[Z + Z + T.dot(X, Y)], mode='FAST_RUN')
#theano.printing.debugprint(f)
# it doesn't work inplace because we didn't mark Z as mutable input
assert [n.op for n in f.maker.fgraph.nodes] == [gemm_no_inplace]
assert [n.op for n in f.maker.fgraph.apply_nodes] == [gemm_no_inplace]
def test_dot22():
......
......@@ -590,7 +590,7 @@ def test_naacl_model(iters_per_unsup=3, iters_per_sup=3,
#print input_pretraining_gradients[4].owner.inputs[1].owner.inputs
#sys.exit()
#print "PROGRAM LEN %i HASH %i"% (len(m.pretraining_update.maker.fgraph.nodes), reduce(lambda a, b: hash(a) ^ hash(b),prog_str))
#print "PROGRAM LEN %i HASH %i"% (len(m.pretraining_update.maker.fgraph.apply_nodes), reduce(lambda a, b: hash(a) ^ hash(b),prog_str))
rng = N.random.RandomState(unittest_tools.fetch_seed(23904))
......
......@@ -2791,14 +2791,14 @@ def test_local_pow_specialize_device_more_aggressive_on_cpu():
f = function([v], v ** (15), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
assert len(nodes) == 1
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.nodes) == 6
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) == 6
assert isinstance(nodes[0].scalar_op, theano.scalar.Composite)
assert numpy.allclose(f(val), val ** 15)
f = function([v], v ** (-15), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
assert len(nodes) == 2
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.nodes) == 6
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) == 6
assert isinstance(nodes[0].scalar_op, theano.scalar.Composite)
assert isinstance(nodes[-1].scalar_op, theano.scalar.basic.Inv)
assert numpy.allclose(f(val_no0), val_no0 ** (-15))
......@@ -2806,14 +2806,14 @@ def test_local_pow_specialize_device_more_aggressive_on_cpu():
f = function([v], v ** (16), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
assert len(nodes) == 1
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.nodes) == 4
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) == 4
assert isinstance(nodes[0].scalar_op, theano.scalar.Composite)
assert numpy.allclose(f(val), val ** 16)
f = function([v], v ** (-16), mode=mode)
nodes = [node.op for node in f.maker.fgraph.toposort()]
assert len(nodes) == 2
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.nodes) == 4
assert len(f.maker.fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes) == 4
assert isinstance(nodes[0].scalar_op, theano.scalar.Composite)
assert isinstance(nodes[-1].scalar_op, theano.scalar.basic.Inv)
assert numpy.allclose(f(val_no0), val_no0 ** (-16))
......@@ -3204,21 +3204,21 @@ class T_local_erfc(unittest.TestCase):
f = theano.function([x], T.log(T.erfc(x)), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.fgraph.nodes) == 23, len(f.maker.fgraph.nodes)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
assert all(numpy.isfinite(f(val)))
f = theano.function([x], T.log(T.erfc(-x)), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.fgraph.nodes) == 24, len(f.maker.fgraph.nodes)
assert len(f.maker.fgraph.apply_nodes) == 24, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
assert all(numpy.isfinite(f(-val)))
f = theano.function([x], T.log(T.erfc(x)), mode=mode_fusion)
assert len(f.maker.fgraph.nodes) == 1, len(f.maker.fgraph.nodes)
assert len(f.maker.fgraph.apply_nodes) == 1, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
assert len(f.maker.fgraph.toposort()[0].fgraph.toposort()[
0].op.scalar_op.fgraph.nodes)==2,len(f.maker.fgraph.toposort()[0].fgraph.toposort()[0].op.scalar_op.fgraph.nodes)
0].op.scalar_op.fgraph.apply_nodes)==2,len(f.maker.fgraph.toposort()[0].fgraph.toposort()[0].op.scalar_op.fgraph.apply_nodes)
#TODO: fix this problem
if theano.config.floatX=="float32" and theano.config.mode in ["DebugMode", "DEBUG_MODE"]:
raise KnownFailureTest(
......@@ -3247,7 +3247,7 @@ class T_local_erfc(unittest.TestCase):
f = theano.function([x], T.grad(T.log(T.erfc(x)).sum(), x), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.fgraph.nodes) == 23, len(f.maker.fgraph.nodes)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes)
assert all(numpy.isfinite(f(val)))
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
......@@ -3255,14 +3255,14 @@ class T_local_erfc(unittest.TestCase):
f = theano.function([x], T.mul(T.exp(T.neg(T.sqr(x))), -
10.12837917) / T.erfc(x), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.fgraph.nodes) == 23, len(f.maker.fgraph.nodes)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
assert all(numpy.isfinite(f(val)))
#test that we work without the mul
f = theano.function([x], T.exp(T.neg(T.sqr(x))) / T.erfc(x), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.fgraph.nodes) == 23, len(f.maker.fgraph.nodes)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
assert all(numpy.isfinite(f(val)))
......@@ -3270,14 +3270,14 @@ class T_local_erfc(unittest.TestCase):
f = theano.function([x, y], T.exp(T.neg(T.sqr(x))) / T.erfc(
y), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.fgraph.nodes) == 5, len(f.maker.fgraph.nodes)
assert len(f.maker.fgraph.apply_nodes) == 5, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
f(val, val - 3)
#test that we work without the sqr and neg
f = theano.function([x], T.exp(T.mul(-1, x, x)) / T.erfc(x), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.fgraph.nodes) == 22, len(f.maker.fgraph.nodes)
assert len(f.maker.fgraph.apply_nodes) == 22, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
assert all(numpy.isfinite(f(val)))
......@@ -3285,13 +3285,13 @@ class T_local_erfc(unittest.TestCase):
f = theano.function([x], T.grad(T.log(T.erfc(2 * x)).sum(),
x), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.fgraph.nodes) == 23, len(f.maker.fgraph.nodes)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes)
assert numpy.isfinite(f(val)).all()
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
f = theano.function([x], T.grad(T.log(T.erfc(x)).sum(), x),
mode=mode_fusion)
assert len(f.maker.fgraph.nodes) == 1, len(f.maker.fgraph.nodes)
assert len(f.maker.fgraph.apply_nodes) == 1, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
#TODO: fix this problem
......@@ -3413,18 +3413,18 @@ class T_local_sum(unittest.TestCase):
a = T.tensor3()
input = numpy.arange(3 * 3 * 3, dtype=config.floatX).reshape(3, 3, 3)
f = theano.function([a], a.sum(), mode=self.mode)
assert len(f.maker.fgraph.nodes) == 1
assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.sum())
f = theano.function([a], a.sum([0, 1, 2]), mode=self.mode)
assert len(f.maker.fgraph.nodes) == 1
assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.sum())
backup = config.warn.sum_sum_bug
config.warn.sum_sum_bug = False
try:
f = theano.function([a], a.sum(0).sum(0).sum(0), mode=self.mode)
assert len(f.maker.fgraph.nodes) == 1
assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.sum())
finally:
config.warn.sum_sum_bug = backup
......@@ -3440,20 +3440,20 @@ class T_local_sum(unittest.TestCase):
for d, dd in dims:
f = theano.function([a], a.sum(d).sum(dd), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum(dd))
assert len(f.maker.fgraph.nodes) == 1
assert len(f.maker.fgraph.apply_nodes) == 1
for d, dd in dims:
f = theano.function([a], a.sum(d).sum(dd).
sum(0), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum(dd).sum(0))
assert len(f.maker.fgraph.nodes) == 1
assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]:
f = theano.function([a], a.sum(d).sum(None), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum())
assert len(f.maker.fgraph.nodes) == 1
assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]:
f = theano.function([a], a.sum(None).sum(), mode=self.mode)
assert numpy.allclose(f(input), input.sum())
assert len(f.maker.fgraph.nodes) == 1
assert len(f.maker.fgraph.apply_nodes) == 1
finally:
config.warn.sum_sum_bug = backup
......@@ -3468,23 +3468,23 @@ class T_local_sum(unittest.TestCase):
f = theano.function([a], t_like(a).sum(None), mode=mode)
assert numpy.allclose(f(input), n_like(input).sum())
assert len(f.maker.fgraph.nodes) == nb_nodes[0]
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0]
f = theano.function([a], t_like(a).sum([0, 1, 2]), mode=mode)
assert numpy.allclose(f(input), n_like(input).sum())
assert len(f.maker.fgraph.nodes) == nb_nodes[0]
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0]
for d in range(3):
f = theano.function([a], t_like(a).sum(d), mode=mode)
assert numpy.allclose(f(input), n_like(input).sum(d))
assert len(f.maker.fgraph.nodes) == nb_nodes[1]
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[1]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, T.Sum) for node in topo])
for i in range(3):
f = theano.function([a], t_like(a).sum(i), mode=mode)
assert numpy.allclose(f(input), n_like(input).sum(i))
assert len(f.maker.fgraph.nodes) == nb_nodes[2]
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[2]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, T.Sum) for node in topo])
......@@ -3497,7 +3497,7 @@ class T_local_sum(unittest.TestCase):
sum(d).sum(dd), mode=mode)
assert numpy.allclose(f(input),
n_like(input).sum(d).sum(dd))
assert len(f.maker.fgraph.nodes) == nb_nodes[3]
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[3]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == T.alloc
assert not any([isinstance(node.op,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论