提交 4b16dc1f authored 作者: James Bergstra's avatar James Bergstra

merge

......@@ -818,85 +818,87 @@ class OpWiseCLinker(link.LocalLinker):
# Acquire lock on compilation directory.
get_lock()
try:
env = self.env
order = env.toposort()
no_recycling = self.no_recycling
input_storage, output_storage, storage_map = link.map_storage(env, order, input_storage, output_storage)
if self.allow_gc:
computed, last_user = link.gc_helper(order)
post_thunk_old_storage = []
else:
post_thunk_old_storage = None
thunks = []
for node in order:
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
try:
e = Env(*graph.clone(node.inputs, node.outputs))
e.toposort = lambda: e.nodes
env = self.env
order = env.toposort()
no_recycling = self.no_recycling
if any(isinstance(input, graph.Value) for input in node.inputs):
desc = None
else:
desc = (node.op,
tuple(input.type for input in node.inputs),
tuple(input.type for input in node.inputs),
tuple(output in no_recycling for output in node.outputs),
tuple(node.inputs.count(input) for input in node.inputs))
input_storage, output_storage, storage_map = link.map_storage(env, order, input_storage, output_storage)
if self.allow_gc:
computed, last_user = link.gc_helper(order)
post_thunk_old_storage = []
else:
post_thunk_old_storage = None
thunks = []
for node in order:
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
try:
cl = self.__cache__.get(desc)
except Exception, exc:
#print >> sys.stderr, "INFO: failed to hash %s: %s. Node will not be cached." % (node, exc)
cl = None
if cl is None:
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling])
if desc is not None:
try:
self.__cache__[desc] = cl
except:
pass
thunk, node_input_filters, node_output_filters = cl.make_thunk(
input_storage = node_input_storage,
output_storage = node_output_storage)
thunk.inputs = node_input_storage
thunk.outputs = node_output_storage
thunks.append(thunk)
except (NotImplementedError, utils.MethodNotDefined):
if self.fallback_on_perform:
p = node.op.perform
thunk = lambda p = p, i = node_input_storage, o = node_output_storage, n = node: p(n, [x[0] for x in i], o)
e = Env(*graph.clone(node.inputs, node.outputs))
e.toposort = lambda: e.nodes
if any(isinstance(input, graph.Value) for input in node.inputs):
desc = None
else:
desc = (node.op,
tuple(input.type for input in node.inputs),
tuple(input.type for input in node.inputs),
tuple(output in no_recycling for output in node.outputs),
tuple(node.inputs.count(input) for input in node.inputs))
try:
cl = self.__cache__.get(desc)
except Exception, exc:
#print >> sys.stderr, "INFO: failed to hash %s: %s. Node will not be cached." % (node, exc)
cl = None
if cl is None:
cl = CLinker().accept(e, [r for r, r2 in zip(e.outputs, node.outputs) if r2 in no_recycling])
if desc is not None:
try:
self.__cache__[desc] = cl
except:
pass
thunk, node_input_filters, node_output_filters = cl.make_thunk(
input_storage = node_input_storage,
output_storage = node_output_storage)
thunk.inputs = node_input_storage
thunk.outputs = node_output_storage
thunk.perform = p
thunks.append(thunk)
else:
raise
if self.allow_gc:
post_thunk_old_storage.append([storage_map[input]
for input in node.inputs
if (input in computed) and (input not in env.outputs) and node == last_user[input]])
if no_recycling is True:
no_recycling = storage_map.values()
no_recycling = utils.difference(no_recycling, input_storage)
else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
except (NotImplementedError, utils.MethodNotDefined):
if self.fallback_on_perform:
p = node.op.perform
thunk = lambda p = p, i = node_input_storage, o = node_output_storage, n = node: p(n, [x[0] for x in i], o)
thunk.inputs = node_input_storage
thunk.outputs = node_output_storage
thunk.perform = p
thunks.append(thunk)
else:
raise
if self.allow_gc:
post_thunk_old_storage.append([storage_map[input]
for input in node.inputs
if (input in computed) and (input not in env.outputs) and node == last_user[input]])
if no_recycling is True:
no_recycling = storage_map.values()
no_recycling = utils.difference(no_recycling, input_storage)
else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = link.streamline(env, thunks, order,
post_thunk_old_storage,
no_recycling = no_recycling,
nice_errors = self.nice_errors)
f = link.streamline(env, thunks, order,
post_thunk_old_storage,
no_recycling = no_recycling,
nice_errors = self.nice_errors)
f.allow_gc = self.allow_gc
f.allow_gc = self.allow_gc
# Release lock on compilation directory.
release_lock()
finally:
# Release lock on compilation directory.
release_lock()
return f, [link.Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[link.Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
......
"""WRITEME"""
import sys
from copy import copy
import graph
import utils
......@@ -172,8 +172,9 @@ class Env(utils.object2):
Updates the list of clients of r with new_clients.
"""
if set(r.clients).intersection(set(new_clients)):
print 'RCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in r.clients]
print 'NCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in new_clients]
print >> sys.stderr, 'ERROR: clients intersect!'
print >> sys.stderr, ' RCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in r.clients]
print >> sys.stderr, ' NCLIENTS of', r, [(n,i, type(n), id(n)) for n,i in new_clients]
assert not set(r.clients).intersection(set(new_clients))
r.clients += new_clients
......@@ -187,8 +188,9 @@ class Env(utils.object2):
for entry in clients_to_remove:
r.clients.remove(entry)
if entry in r.clients:
print 'ENTRY', repr(entry), type(entry[0])
print 'CLIENTS', repr(r.clients)
print >> sys.stderr, 'ERROR: DUPLICATE CLIENT ENTRY...'
print >> sys.stderr, ' ENTRY', repr(entry), type(entry[0])
print >> sys.stderr, ' CLIENTS', repr(r.clients)
assert entry not in r.clients # an op,i pair should be unique
if not r.clients:
if prune:
......@@ -330,10 +332,15 @@ class Env(utils.object2):
# because it makes it easier to implement some optimizations for multiple-output ops
return
for node, i in list(r.clients): #copy the client list for iteration
for node, i in list(r.clients): # copy the client list for iteration
assert (node == 'output' and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason)
# sometimes the following is triggered. If you understand why, please explain to James.
# He's curious... -JB20090331
#if len(r.clients) != 0:
# print >> sys.stderr, "WARNING: CLIENTS LEFT AFTER REPLACE", r, r.clients
def replace_all(self, pairs, reason=None):
"""WRITEME"""
for r, new_r in pairs:
......
......@@ -875,6 +875,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
while changed and not max_use_abort:
changed = False
for node in start_from:
assert node in env.outputs
q = deque(graph.io_toposort(env.inputs, start_from))
......@@ -914,6 +916,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def _check_chain(r, chain):
"""WRITEME"""
chain = list(reversed(chain))
while chain:
elem = chain.pop()
......@@ -933,7 +936,10 @@ def _check_chain(r, chain):
return False
if chain:
r = r.owner.inputs[chain.pop()]
#print 'check_chain', _check_chain.n_calls
#_check_chain.n_calls += 1
return r
#_check_chain.n_calls = 0
def check_chain(r, *chain):
"""WRITEME"""
......
......@@ -102,7 +102,7 @@ class ReplaceValidate(History, Validator):
env.replace(r, new_r, reason=reason)
except Exception, e:
if 'The type of the replacement must be the same' not in str(e) and 'does not belong to this Env' not in str(e):
print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e
print >> sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e, reason
env.revert(chk) # this might fail if the error is in a listener: (env.replace kinda needs better internal error handling)
raise
try:
......
......@@ -2135,7 +2135,9 @@ class TensorDotGrad(Op):
tensordot_grad = TensorDotGrad
class TensorDot(Op):
"""Compute tensor-tensor products over the given axes. See numpy documentation for details.
"""Compute tensor-tensor products over the given axes.
See numpy documentation for details.
(http://docs.scipy.org/doc/numpy/reference/generated/numpy.tensordot.html)
"""
......
......@@ -16,6 +16,8 @@ import itertools
import sys
from .. import compile #to register the optimizer built by this file
from ..compile.debugmode import _debugprint
# Utilities
......@@ -669,11 +671,13 @@ class Canonizer(gof.LocalOptimizer):
def transform(self, node):
op = node.op
inputs = node.inputs
out = node.outputs[0]
if op not in [self.main, self.inverse, self.reciprocal]:
return False
inputs = node.inputs
out = node.outputs[0]
assert len(node.outputs) == 1
# I'm not sure if this is actually needed but the following
# block of code puts into "reorg" whether or not we are going
# to change the structure of the graph. For example if we have
......@@ -714,6 +718,12 @@ class Canonizer(gof.LocalOptimizer):
if new.type.broadcastable != out.type.broadcastable:
new = T.fill(out, new)
if 0:
print 'BEFORE'
_debugprint(out, ' ', depth=4)
print 'AFTER'
_debugprint(new, ' ', depth=4)
# if our if's above worked, this should be true. OTW investigate.
if new.type != out.type:
print >> sys.stderr, 'CANONIZE FAILED: new, out = ', new, ',', out, 'types', new.type, ',', out.type
......
......@@ -62,6 +62,15 @@ class test_dimshuffle_lift(unittest.TestCase):
self.failUnless(str(g) == "[add(add(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]", str(g))
def test_add_canonizer_problem0():
#observed in a real graph
n_segments = 10
label = lscalar('label')
segment_labels = label + numpy.asarray([0] * n_segments, dtype='int64')
r = segment_labels * 5
f = function([label], r)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论