提交 edad54d2 authored 作者: James Bergstra's avatar James Bergstra

merge

from .. import gof
import sys
class DebugException(Exception):
pass
class DebugLinker(gof.WrapLinker):
def __init__(self,
linkers,
debug_pre = [],
debug_post = [],
copy_originals = False,
check_types = True,
compare_results = True,
compare_fn = lambda x, y: x == y):
gof.WrapLinker.__init__(self,
linkers = linkers,
wrapper = self.wrapper)
self.env = None
self.compare_fn = compare_fn
self.copy_originals = copy_originals
if check_types not in [None, True]:
self.check_types = check_types
if compare_results not in [None, True]:
self.compare_results = compare_results
if not isinstance(debug_pre, (list, tuple)):
debug_pre = [debug_pre]
self.debug_pre = debug_pre
if not isinstance(debug_post, (list, tuple)):
debug_post = [debug_post]
self.debug_post = debug_post
if check_types is not None:
self.debug_post.append(self.check_types)
if compare_results is not None:
self.debug_post.append(self.compare_results)
def accept(self, env, no_recycling = []):
return gof.WrapLinker.accept(self,
env = env,
no_recycling = no_recycling)
def store_value(self, i, node, *thunks):
th1 = thunks[0]
for r, oval in zip(node.outputs, th1.outputs):
r.step = i
r.value = oval[0]
if self.copy_originals:
r.original_value = copy(oval[0])
def check_types(self, i, node, *thunks):
for thunk, linker in zip(thunks, self.linkers):
for r in node.outputs:
try:
r.type.filter(r.value, strict = True)
except TypeError, e:
exc_type, exc_value, exc_trace = sys.exc_info()
exc = DebugException(e, "The output %s was filled with data with the wrong type using linker " \
("%s. This happened at step %i of the program." % (r, linker, i)) + \
"For more info, inspect this exception's 'original_exception', 'debugger', " \
"'output_at_fault', 'step', 'node', 'thunk' and 'linker' fields.")
exc.debugger = self
exc.original_exception = e
exc.output_at_fault = r
exc.step = i
exc.node = node
exc.thunk = thunk
exc.linker = linker
raise DebugException, exc, exc_trace
def compare_results(self, i, node, *thunks):
thunk0 = thunks[0]
linker0 = self.linkers[0]
for thunk, linker in zip(thunks[1:], self.linkers[1:]):
for o, output0, output in zip(node.outputs, thunk0.outputs, thunk.outputs):
if not self.compare_fn(output0[0], output[0]):
exc = DebugException(("The results from %s and %s for output %s are not the same. This happened at step %i." % (linker0, linker, o, step)) + \
"For more info, inspect this exception's 'debugger', 'output', 'output_value1', 'output_value2', " \
"'step', 'node', 'thunk1', 'thunk2', 'linker1' and 'linker2' fields.")
exc.debugger = self
exc.output = o
exc.output_value1 = output0
exc.output_value2 = output
exc.step = i
exc.node = node
exc.thunk1 = thunk0
exc.thunk2 = thunk
exc.linker1 = linker0
exc.linker2 = linker
raise exc
def pre(self, f, inputs, order, thunk_groups):
env = f.env
for r in env.results:
if r.owner is None:
r.step = "value" # this will be overwritten if r is an input
else:
r.step = None
r.value = None
r.original_value = None
if r.owner is None and r not in env.inputs:
r.value = r.data
if self.copy_originals:
r.original_value = copy(r.data)
for idx, (i, r) in enumerate(zip(inputs, env.inputs)):
r.step = "input %i" % idx
r.value = i
if self.copy_originals:
r.original_value = copy(i)
for node, thunk_group in zip(order, thunk_groups):
node.step = None
def wrapper(self, i, node, *thunks):
try:
node.step = i
for f in self.debug_pre:
f(i, node, *thunks)
for thunk in thunks:
thunk()
self.store_value(i, node, *thunks)
for f in self.debug_post:
f(i, node, *thunks)
except Exception, e:
exc_type, exc_value, exc_trace = sys.exc_info()
if isinstance(e, DebugException):
raise
exc = DebugException(e, ("An exception occurred while processing node %s at step %i of the program." % (node, i)) + \
"For more info, inspect this exception's 'original_exception', 'debugger', 'step', 'node' and 'thunks' fields.")
exc.debugger = self
exc.original_exception = e
exc.step = i
exc.node = node
exc.thunks = thunks
raise DebugException, exc, exc_trace
def print_info(i, node, *thunks):
print "step %i, node %s" % (i, node)
def print_from(i, node, *thunks):
print "parents:", ", ".join(str(input.step) for input in node.inputs)
def print_input_shapes(i, node, *thunks):
print "input shapes:", ", ".join(str(input.value.shape) if hasattr(input.value, 'shape') else 'N/A' for input in node.inputs)
def print_input_types(i, node, *thunks):
print "input types:", ", ".join(str(type(input.value)) for input in node.inputs)
def print_sep(i, node, *thunks):
print "==================================="
import numpy
def numpy_compare(a, b, tolerance = 1e-6):
if isinstance(a, numpy.ndarray):
return (abs(a - b) <= tolerance).all()
else:
return a == b
def numpy_debug_linker(pre, post = []):
return DebugLinker([gof.OpWiseCLinker],
pre,
post,
compare_fn = numpy_compare)
from .. import tensor as T
from .. import gof
from copy import copy
class PrinterState(gof.utils.scratchpad):
def __init__(self, props = {}, **more_props):
if isinstance(props, gof.utils.scratchpad):
self.__update__(props)
else:
self.__dict__.update(props)
self.__dict__.update(more_props)
def clone(self, props = {}, **more_props):
return PrinterState(self, **dict(props, **more_props))
class OperatorPrinter:
def __init__(self, operator, precedence, assoc = 'left'):
self.operator = operator
self.precedence = precedence
self.assoc = assoc
def process(self, output, pstate):
pprinter = pstate.pprinter
node = output.owner
if node is None:
raise TypeError("operator %s cannot represent a result with no associated operation" % self.operator)
outer_precedence = getattr(pstate, 'precedence', -999999)
outer_assoc = getattr(pstate, 'assoc', 'none')
if outer_precedence > self.precedence:
parenthesize = True
else:
parenthesize = False
input_strings = []
max_i = len(node.inputs) - 1
for i, input in enumerate(node.inputs):
if self.assoc == 'left' and i != 0 or self.assoc == 'right' and i != max_i:
s = pprinter.process(input, pstate.clone(precedence = self.precedence + 1e-6))
else:
s = pprinter.process(input, pstate.clone(precedence = self.precedence))
input_strings.append(s)
if len(input_strings) == 1:
s = self.operator + input_strings[0]
else:
s = (" %s " % self.operator).join(input_strings)
if parenthesize: return "(%s)" % s
else: return s
class PatternPrinter:
def __init__(self, *patterns):
self.patterns = []
for pattern in patterns:
if isinstance(pattern, str):
self.patterns.append((pattern, ()))
else:
self.patterns.append((pattern[0], pattern[1:]))
def process(self, output, pstate):
pprinter = pstate.pprinter
node = output.owner
if node is None:
raise TypeError("Patterns %s cannot represent a result with no associated operation" % self.patterns)
idx = node.outputs.index(output)
pattern, precedences = self.patterns[idx]
precedences += (1000,) * len(node.inputs)
return pattern % dict((str(i), x)
for i, x in enumerate(pprinter.process(input, pstate.clone(precedence = precedence))
for input, precedence in zip(node.inputs, precedences)))
class FunctionPrinter:
def __init__(self, *names):
self.names = names
def process(self, output, pstate):
pprinter = pstate.pprinter
node = output.owner
if node is None:
raise TypeError("function %s cannot represent a result with no associated operation" % self.names)
idx = node.outputs.index(output)
name = self.names[idx]
return "%s(%s)" % (name, ", ".join([pprinter.process(input, pstate.clone(precedence = -1000))
for input in node.inputs]))
class MemberPrinter:
def __init__(self, *names):
self.names = names
def process(self, output, pstate):
pprinter = pstate.pprinter
node = output.owner
if node is None:
raise TypeError("function %s cannot represent a result with no associated operation" % self.function)
names = self.names
idx = node.outputs.index(output)
name = self.names[idx]
input = node.inputs[0]
return "%s.%s" % (pprinter.process(input, pstate.clone(precedence = 1000)), name)
class IgnorePrinter:
def process(self, output, pstate):
pprinter = pstate.pprinter
node = output.owner
if node is None:
raise TypeError("function %s cannot represent a result with no associated operation" % self.function)
input = node.inputs[0]
return "%s" % pprinter.process(input, pstate)
class DimShufflePrinter:
def __p(self, new_order, pstate, r):
if new_order != () and new_order[0] == 'x':
return "[%s]" % self.__p(new_order[1:], pstate, r)
if list(new_order) == range(r.type.ndim):
return pstate.pprinter.process(r)
if list(new_order) == list(reversed(range(r.type.ndim))):
return "%s.T" % pstate.pprinter.process(r)
return "DimShuffle{%s}(%s)" % (", ".join(map(str, new_order)), pstate.pprinter.process(r))
def process(self, r, pstate):
if r.owner is None:
raise TypeError("Can only print DimShuffle.")
elif isinstance(r.owner.op, T.DimShuffle):
ord = r.owner.op.new_order
return self.__p(ord, pstate, r.owner.inputs[0])
else:
raise TypeError("Can only print DimShuffle.")
class DefaultPrinter:
def __init__(self):
pass
def process(self, r, pstate):
pprinter = pstate.pprinter
node = r.owner
if node is None:
return LeafPrinter().process(r, pstate)
return "%s(%s)" % (str(node.op), ", ".join([pprinter.process(input, pstate.clone(precedence = -1000))
for input in node.inputs]))
class LeafPrinter:
def process(self, r, pstate):
if r.name in greek:
return greek[r.name]
else:
return str(r)
class PPrinter:
def __init__(self):
self.printers = []
def assign(self, condition, printer):
if isinstance(condition, gof.Op):
op = condition
condition = lambda pstate, r: r.owner is not None and r.owner.op == op
self.printers.insert(0, (condition, printer))
def process(self, r, pstate = None):
if pstate is None:
pstate = PrinterState(pprinter = self)
for condition, printer in self.printers:
if condition(pstate, r):
return printer.process(r, pstate)
def clone(self):
cp = copy(self)
cp.printers = list(self.printers)
return cp
def clone_assign(self, condition, printer):
cp = self.clone()
cp.assign(condition, printer)
return cp
def process_graph(self, inputs, outputs):
strings = ["inputs: " + ", ".join(map(str, inputs))]
pprinter = self.clone_assign(lambda pstate, r: r.name is not None and r is not current,
LeafPrinter())
for node in gof.graph.io_toposort(inputs, outputs):
for output in node.outputs:
if output.name is not None or output in outputs:
name = 'outputs[%i]' % outputs.index(output) if output.name is None else output.name
current = output
strings.append("%s = %s" % (name, pprinter.process(output)))
return "\n".join(strings)
special = dict(middle_dot = u"\u00B7",
big_sigma = u"\u03A3")
greek = dict(alpha = u"\u03B1",
beta = u"\u03B2",
gamma = u"\u03B3",
delta = u"\u03B4",
epsilon = u"\u03B5")
ppow = OperatorPrinter('**', 1, 'right')
pneg = OperatorPrinter('-', 0, 'either')
pmul = OperatorPrinter('*', -1, 'either')
pdiv = OperatorPrinter('/', -1, 'left')
padd = OperatorPrinter('+', -2, 'either')
psub = OperatorPrinter('-', -2, 'left')
pdot = OperatorPrinter(special['middle_dot'], -1, 'left')
psum = OperatorPrinter(special['big_sigma']+' ', -2, 'left')
def pprinter():
pp = PPrinter()
pp.assign(lambda pstate, r: True, DefaultPrinter())
pp.assign(T.add, padd)
pp.assign(T.mul, pmul)
pp.assign(T.sub, psub)
pp.assign(T.neg, pneg)
pp.assign(T.div, pdiv)
pp.assign(T.pow, ppow)
pp.assign(T.dot, pdot)
pp.assign(T.Sum(), FunctionPrinter('sum'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.DimShuffle), DimShufflePrinter())
pp.assign(T.tensor_copy, IgnorePrinter())
pp.assign(T.log, FunctionPrinter('log'))
pp.assign(T.tanh, FunctionPrinter('tanh'))
pp.assign(T.transpose_inplace, MemberPrinter('T'))
pp.assign(T._abs, PatternPrinter(('|%(0)s|', -1000)))
return pp
pp = pprinter()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论