提交 4b16dc1f authored 作者: James Bergstra's avatar James Bergstra

merge

...@@ -818,85 +818,87 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -818,85 +818,87 @@ class OpWiseCLinker(link.LocalLinker):
# Acquire lock on compilation directory. # Acquire lock on compilation directory.
get_lock() get_lock()
try:
env = self.env env = self.env
order = env.toposort() order = env.toposort()
no_recycling = self.no_recycling no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage(env, order, input_storage, output_storage)
if self.allow_gc:
computed, last_user = link.gc_helper(order)
post_thunk_old_storage = []
else:
post_thunk_old_storage = None
thunks = []
for node in order:
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
try:
e = Env(*graph.clone(node.inputs, node.outputs))
e.toposort = lambda: e.nodes
if any(isinstance(input, graph.Value) for input in node.inputs): input_storage, output_storage, storage_map = link.map_storage(env, order, input_storage, output_storage)
desc = None if self.allow_gc:
else: computed, last_user = link.gc_helper(order)
desc = (node.op, post_thunk_old_storage = []
tuple(input.type for input in node.inputs), else:
tuple(input.type for input in node.inputs), post_thunk_old_storage = None
tuple(output in no_recycling for output in node.outputs),
tuple(node.inputs.count(input) for input in node.inputs))
thunks = []
for node in order:
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
try: try:
cl = self.__cache__.get(desc) e = Env(*graph.clone(node.inputs, node.outputs))
except Exception, exc: e.toposort = lambda: e.nodes
#print >> sys.stderr, "INFO: failed to hash %s: %s. Node will not be cached." % (node, exc)
cl = None if any(isinstance(input, graph.Value) for input in node.inputs):
if cl is None: desc = None
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling]) else:
if desc is not None: desc = (node.op,
try: tuple(input.type for input in node.inputs),
self.__cache__[desc] = cl tuple(input.type for input in node.inputs),
except: tuple(output in no_recycling for output in node.outputs),
pass tuple(node.inputs.count(input) for input in node.inputs))
thunk, node_input_filters, node_output_filters = cl.make_thunk( try:
input_storage = node_input_storage, cl = self.__cache__.get(desc)
output_storage = node_output_storage) except Exception, exc:
thunk.inputs = node_input_storage #print >> sys.stderr, "INFO: failed to hash %s: %s. Node will not be cached." % (node, exc)
thunk.outputs = node_output_storage cl = None
thunks.append(thunk) if cl is None:
except (NotImplementedError, utils.MethodNotDefined): cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling])
if self.fallback_on_perform: if desc is not None:
p = node.op.perform try:
thunk = lambda p = p, i = node_input_storage, o = node_output_storage, n = node: p(n, [x[0] for x in i], o) self.__cache__[desc] = cl
except:
pass
thunk, node_input_filters, node_output_filters = cl.make_thunk(
input_storage = node_input_storage,
output_storage = node_output_storage)
thunk.inputs = node_input_storage thunk.inputs = node_input_storage
thunk.outputs = node_output_storage thunk.outputs = node_output_storage
thunk.perform = p
thunks.append(thunk) thunks.append(thunk)
else: except (NotImplementedError, utils.MethodNotDefined):
raise if self.fallback_on_perform:
p = node.op.perform
if self.allow_gc: thunk = lambda p = p, i = node_input_storage, o = node_output_storage, n = node: p(n, [x[0] for x in i], o)
post_thunk_old_storage.append([storage_map[input] thunk.inputs = node_input_storage
for input in node.inputs thunk.outputs = node_output_storage
if (input in computed) and (input not in env.outputs) and node == last_user[input]]) thunk.perform = p
thunks.append(thunk)
if no_recycling is True: else:
no_recycling = storage_map.values() raise
no_recycling = utils.difference(no_recycling, input_storage)
else: if self.allow_gc:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs] post_thunk_old_storage.append([storage_map[input]
for input in node.inputs
if (input in computed) and (input not in env.outputs) and node == last_user[input]])
if no_recycling is True:
no_recycling = storage_map.values()
no_recycling = utils.difference(no_recycling, input_storage)
else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = link.streamline(env, thunks, order, f = link.streamline(env, thunks, order,
post_thunk_old_storage, post_thunk_old_storage,
no_recycling = no_recycling, no_recycling = no_recycling,
nice_errors = self.nice_errors) nice_errors = self.nice_errors)
f.allow_gc = self.allow_gc f.allow_gc = self.allow_gc
# Release lock on compilation directory. finally:
release_lock() # Release lock on compilation directory.
release_lock()
return f, [link.Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \ return f, [link.Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[link.Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \ [link.Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
......
"""WRITEME""" """WRITEME"""
import sys
from copy import copy from copy import copy
import graph import graph
import utils import utils
...@@ -172,8 +172,9 @@ class Env(utils.object2): ...@@ -172,8 +172,9 @@ class Env(utils.object2):
Updates the list of clients of r with new_clients. Updates the list of clients of r with new_clients.
""" """
if set(r.clients).intersection(set(new_clients)): if set(r.clients).intersection(set(new_clients)):
print 'RCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in r.clients] print >> sys.stderr, 'ERROR: clients intersect!'
print 'NCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in new_clients] print >> sys.stderr, ' RCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in r.clients]
print >> sys.stderr, ' NCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in new_clients]
assert not set(r.clients).intersection(set(new_clients)) assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients r.clients += new_clients
...@@ -187,8 +188,9 @@ class Env(utils.object2): ...@@ -187,8 +188,9 @@ class Env(utils.object2):
for entry in clients_to_remove: for entry in clients_to_remove:
r.clients.remove(entry) r.clients.remove(entry)
if entry in r.clients: if entry in r.clients:
print 'ENTRY', repr(entry), type(entry[0]) print >> sys.stderr, 'ERROR: DUPLICATE CLIENT ENTRY...'
print 'CLIENTS', repr(r.clients) print >> sys.stderr, ' ENTRY', repr(entry), type(entry[0])
print >> sys.stderr, ' CLIENTS', repr(r.clients)
assert entry not in r.clients # an op,i pair should be unique assert entry not in r.clients # an op,i pair should be unique
if not r.clients: if not r.clients:
if prune: if prune:
...@@ -330,10 +332,15 @@ class Env(utils.object2): ...@@ -330,10 +332,15 @@ class Env(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops # because it makes it easier to implement some optimizations for multiple-output ops
return return
for node, i in list(r.clients): #copy the client list for iteration for node, i in list(r.clients): # copy the client list for iteration
assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r) assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason) self.change_input(node, i, new_r, reason=reason)
# sometimes the following is triggered. If you understand why, please explain to James.
# He's curious... -JB20090331
#if len(r.clients) != 0:
# print >> sys.stderr, "WARNING: CLIENTS LEFT AFTER REPLACE", r, r.clients
def replace_all(self, pairs, reason=None): def replace_all(self, pairs, reason=None):
"""WRITEME""" """WRITEME"""
for r, new_r in pairs: for r, new_r in pairs:
......
...@@ -875,6 +875,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -875,6 +875,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
while changed and not max_use_abort: while changed and not max_use_abort:
changed = False changed = False
for node in start_from:
assert node in env.outputs
q = deque(graph.io_toposort(env.inputs, start_from)) q = deque(graph.io_toposort(env.inputs, start_from))
...@@ -914,6 +916,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -914,6 +916,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def _check_chain(r, chain): def _check_chain(r, chain):
"""WRITEME""" """WRITEME"""
chain = list(reversed(chain)) chain = list(reversed(chain))
while chain: while chain:
elem = chain.pop() elem = chain.pop()
...@@ -933,7 +936,10 @@ def _check_chain(r, chain): ...@@ -933,7 +936,10 @@ def _check_chain(r, chain):
return False return False
if chain: if chain:
r = r.owner.inputs[chain.pop()] r = r.owner.inputs[chain.pop()]
#print 'check_chain', _check_chain.n_calls
#_check_chain.n_calls += 1
return r return r
#_check_chain.n_calls = 0
def check_chain(r, *chain): def check_chain(r, *chain):
"""WRITEME""" """WRITEME"""
......
...@@ -102,7 +102,7 @@ class ReplaceValidate(History, Validator): ...@@ -102,7 +102,7 @@ class ReplaceValidate(History, Validator):
env.replace(r, new_r, reason=reason) env.replace(r, new_r, reason=reason)
except Exception, e: except Exception, e:
if 'The type of the replacement must be the same' not in str(e) and 'does not belong to this Env' not in str(e): if 'The type of the replacement must be the same' not in str(e) and 'does not belong to this Env' not in str(e):
print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e print >> sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e, reason
env.revert(chk) # this might fail if the error is in a listener: (env.replace kinda needs better internal error handling) env.revert(chk) # this might fail if the error is in a listener: (env.replace kinda needs better internal error handling)
raise raise
try: try:
......
...@@ -2135,7 +2135,9 @@ class TensorDotGrad(Op): ...@@ -2135,7 +2135,9 @@ class TensorDotGrad(Op):
tensordot_grad = TensorDotGrad tensordot_grad = TensorDotGrad
class TensorDot(Op): class TensorDot(Op):
"""Compute tensor-tensor products over the given axes. See numpy documentation for details. """Compute tensor-tensor products over the given axes.
See numpy documentation for details.
(http://docs.scipy.org/doc/numpy/reference/generated/numpy.tensordot.html)
""" """
......
...@@ -16,6 +16,8 @@ import itertools ...@@ -16,6 +16,8 @@ import itertools
import sys import sys
from .. import compile #to register the optimizer built by this file from .. import compile #to register the optimizer built by this file
from ..compile.debugmode import _debugprint
# Utilities # Utilities
...@@ -669,11 +671,13 @@ class Canonizer(gof.LocalOptimizer): ...@@ -669,11 +671,13 @@ class Canonizer(gof.LocalOptimizer):
def transform(self, node): def transform(self, node):
op = node.op op = node.op
inputs = node.inputs
out = node.outputs[0]
if op not in [self.main, self.inverse, self.reciprocal]: if op not in [self.main, self.inverse, self.reciprocal]:
return False return False
inputs = node.inputs
out = node.outputs[0]
assert len(node.outputs) == 1
# I'm not sure if this is actually needed but the following # I'm not sure if this is actually needed but the following
# block of code puts into "reorg" whether or not we are going # block of code puts into "reorg" whether or not we are going
# to change the structure of the graph. For example if we have # to change the structure of the graph. For example if we have
...@@ -714,6 +718,12 @@ class Canonizer(gof.LocalOptimizer): ...@@ -714,6 +718,12 @@ class Canonizer(gof.LocalOptimizer):
if new.type.broadcastable != out.type.broadcastable: if new.type.broadcastable != out.type.broadcastable:
new = T.fill(out, new) new = T.fill(out, new)
if 0:
print 'BEFORE'
_debugprint(out, ' ', depth=4)
print 'AFTER'
_debugprint(new, ' ', depth=4)
# if our if's above worked, this should be true. OTW investigate. # if our if's above worked, this should be true. OTW investigate.
if new.type != out.type: if new.type != out.type:
print >> sys.stderr, 'CANONIZE FAILED: new, out = ', new, ',', out, 'types', new.type, ',', out.type print >> sys.stderr, 'CANONIZE FAILED: new, out = ', new, ',', out, 'types', new.type, ',', out.type
......
...@@ -62,6 +62,15 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -62,6 +62,15 @@ class test_dimshuffle_lift(unittest.TestCase):
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g)) self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
def test_add_canonizer_problem0():
#observed in a real graph
n_segments = 10
label = lscalar('label')
segment_labels = label + numpy.asarray([0] * n_segments, dtype='int64')
r = segment_labels * 5
f = function([label], r)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论