提交 4b16dc1f authored 作者: James Bergstra's avatar James Bergstra

merge

...@@ -818,6 +818,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -818,6 +818,7 @@ class OpWiseCLinker(link.LocalLinker):
# Acquire lock on compilation directory. # Acquire lock on compilation directory.
get_lock() get_lock()
try:
env = self.env env = self.env
order = env.toposort() order = env.toposort()
...@@ -895,6 +896,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -895,6 +896,7 @@ class OpWiseCLinker(link.LocalLinker):
f.allow_gc = self.allow_gc f.allow_gc = self.allow_gc
finally:
# Release lock on compilation directory. # Release lock on compilation directory.
release_lock() release_lock()
......
"""WRITEME""" """WRITEME"""
import sys
from copy import copy from copy import copy
import graph import graph
import utils import utils
...@@ -172,8 +172,9 @@ class Env(utils.object2): ...@@ -172,8 +172,9 @@ class Env(utils.object2):
Updates the list of clients of r with new_clients. Updates the list of clients of r with new_clients.
""" """
if set(r.clients).intersection(set(new_clients)): if set(r.clients).intersection(set(new_clients)):
print 'RCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in r.clients] print >> sys.stderr, 'ERROR: clients intersect!'
print 'NCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in new_clients] print >> sys.stderr, ' RCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in r.clients]
print >> sys.stderr, ' NCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in new_clients]
assert not set(r.clients).intersection(set(new_clients)) assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients r.clients += new_clients
...@@ -187,8 +188,9 @@ class Env(utils.object2): ...@@ -187,8 +188,9 @@ class Env(utils.object2):
for entry in clients_to_remove: for entry in clients_to_remove:
r.clients.remove(entry) r.clients.remove(entry)
if entry in r.clients: if entry in r.clients:
print 'ENTRY', repr(entry), type(entry[0]) print >> sys.stderr, 'ERROR: DUPLICATE CLIENT ENTRY...'
print 'CLIENTS', repr(r.clients) print >> sys.stderr, ' ENTRY', repr(entry), type(entry[0])
print >> sys.stderr, ' CLIENTS', repr(r.clients)
assert entry not in r.clients # an op,i pair should be unique assert entry not in r.clients # an op,i pair should be unique
if not r.clients: if not r.clients:
if prune: if prune:
...@@ -330,10 +332,15 @@ class Env(utils.object2): ...@@ -330,10 +332,15 @@ class Env(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops # because it makes it easier to implement some optimizations for multiple-output ops
return return
for node, i in list(r.clients): #copy the client list for iteration for node, i in list(r.clients): # copy the client list for iteration
assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r) assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason) self.change_input(node, i, new_r, reason=reason)
# sometimes the following is triggered. If you understand why, please explain to James.
# He's curious... -JB20090331
#if len(r.clients) != 0:
# print >> sys.stderr, "WARNING: CLIENTS LEFT AFTER REPLACE", r, r.clients
def replace_all(self, pairs, reason=None): def replace_all(self, pairs, reason=None):
"""WRITEME""" """WRITEME"""
for r, new_r in pairs: for r, new_r in pairs:
......
...@@ -875,6 +875,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -875,6 +875,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
while changed and not max_use_abort: while changed and not max_use_abort:
changed = False changed = False
for node in start_from:
assert node in env.outputs
q = deque(graph.io_toposort(env.inputs, start_from)) q = deque(graph.io_toposort(env.inputs, start_from))
...@@ -914,6 +916,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -914,6 +916,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def _check_chain(r, chain): def _check_chain(r, chain):
"""WRITEME""" """WRITEME"""
chain = list(reversed(chain)) chain = list(reversed(chain))
while chain: while chain:
elem = chain.pop() elem = chain.pop()
...@@ -933,7 +936,10 @@ def _check_chain(r, chain): ...@@ -933,7 +936,10 @@ def _check_chain(r, chain):
return False return False
if chain: if chain:
r = r.owner.inputs[chain.pop()] r = r.owner.inputs[chain.pop()]
#print 'check_chain', _check_chain.n_calls
#_check_chain.n_calls += 1
return r return r
#_check_chain.n_calls = 0
def check_chain(r, *chain): def check_chain(r, *chain):
"""WRITEME""" """WRITEME"""
......
...@@ -102,7 +102,7 @@ class ReplaceValidate(History, Validator): ...@@ -102,7 +102,7 @@ class ReplaceValidate(History, Validator):
env.replace(r, new_r, reason=reason) env.replace(r, new_r, reason=reason)
except Exception, e: except Exception, e:
if 'The type of the replacement must be the same' not in str(e) and 'does not belong to this Env' not in str(e): if 'The type of the replacement must be the same' not in str(e) and 'does not belong to this Env' not in str(e):
print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e print >> sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e, reason
env.revert(chk) # this might fail if the error is in a listener: (env.replace kinda needs better internal error handling) env.revert(chk) # this might fail if the error is in a listener: (env.replace kinda needs better internal error handling)
raise raise
try: try:
......
...@@ -2135,7 +2135,9 @@ class TensorDotGrad(Op): ...@@ -2135,7 +2135,9 @@ class TensorDotGrad(Op):
tensordot_grad = TensorDotGrad tensordot_grad = TensorDotGrad
class TensorDot(Op): class TensorDot(Op):
"""Compute tensor-tensor products over the given axes. See numpy documentation for details. """Compute tensor-tensor products over the given axes.
See numpy documentation for details.
(http://docs.scipy.org/doc/numpy/reference/generated/numpy.tensordot.html)
""" """
......
...@@ -16,6 +16,8 @@ import itertools ...@@ -16,6 +16,8 @@ import itertools
import sys import sys
from .. import compile #to register the optimizer built by this file from .. import compile #to register the optimizer built by this file
from ..compile.debugmode import _debugprint
# Utilities # Utilities
...@@ -669,11 +671,13 @@ class Canonizer(gof.LocalOptimizer): ...@@ -669,11 +671,13 @@ class Canonizer(gof.LocalOptimizer):
def transform(self, node): def transform(self, node):
op = node.op op = node.op
inputs = node.inputs
out = node.outputs[0]
if op not in [self.main, self.inverse, self.reciprocal]: if op not in [self.main, self.inverse, self.reciprocal]:
return False return False
inputs = node.inputs
out = node.outputs[0]
assert len(node.outputs) == 1
# I'm not sure if this is actually needed but the following # I'm not sure if this is actually needed but the following
# block of code puts into "reorg" whether or not we are going # block of code puts into "reorg" whether or not we are going
# to change the structure of the graph. For example if we have # to change the structure of the graph. For example if we have
...@@ -714,6 +718,12 @@ class Canonizer(gof.LocalOptimizer): ...@@ -714,6 +718,12 @@ class Canonizer(gof.LocalOptimizer):
if new.type.broadcastable != out.type.broadcastable: if new.type.broadcastable != out.type.broadcastable:
new = T.fill(out, new) new = T.fill(out, new)
if 0:
print 'BEFORE'
_debugprint(out, ' ', depth=4)
print 'AFTER'
_debugprint(new, ' ', depth=4)
# if our if's above worked, this should be true. OTW investigate. # if our if's above worked, this should be true. OTW investigate.
if new.type != out.type: if new.type != out.type:
print >> sys.stderr, 'CANONIZE FAILED: new, out = ', new, ',', out, 'types', new.type, ',', out.type print >> sys.stderr, 'CANONIZE FAILED: new, out = ', new, ',', out, 'types', new.type, ',', out.type
......
...@@ -62,6 +62,15 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -62,6 +62,15 @@ class test_dimshuffle_lift(unittest.TestCase):
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g)) self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
def test_add_canonizer_problem0():
#observed in a real graph
n_segments = 10
label = lscalar('label')
segment_labels = label + numpy.asarray([0] * n_segments, dtype='int64')
r = segment_labels * 5
f = function([label], r)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论