Commit 3b622033, authored by James Bergstra

Modification to failure callbacks. inplace_set_subtensor no longer prints

traceback on error. All failure_callback tracebacks now include the optimization responsible for the problem.
Parent: b87a7adf
......@@ -25,7 +25,6 @@ from opt import (Optimizer, optimizer, SeqOptimizer,
LocalOptimizer, local_optimizer, LocalOptGroup,
OpSub, OpRemove, PatternSub,
NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer,
keep_going, warn,
InplaceOptimizer, PureThenInplaceOptimizer,
OpKeyOptimizer)
......
......@@ -311,9 +311,9 @@ class Env(utils.object2):
For every node that uses r as input, makes it use new_r instead.
"""
if r.env is not self:
raise Exception("Cannot replace %s because it does not belong to this Env" % r)
raise Exception("Cannot replace %s because it does not belong to this Env" % r, str(reason))
if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r, r.type, new_r.type)
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r, r.type, new_r.type, str(reason))
if r not in self.results:
# this result isn't in the graph... don't raise an exception here, just return silently
# because it makes it easier to implement some optimizations for multiple-output ops
......
......@@ -14,6 +14,7 @@ from copy import copy
from collections import deque, defaultdict
import destroyhandler as dh
import sys
import traceback
_optimizer_idx = [0]
......@@ -87,6 +88,13 @@ class SeqOptimizer(Optimizer, list):
Takes a list of L{Optimizer} instances and applies them
sequentially.
"""
@staticmethod
def warn(exc, self, optimizer):
    """Default failure_callback for SeqOptimizer.

    Writes a warning to stderr naming the `optimizer` whose apply()
    raised `exc`, then dumps the active exception traceback.
    """
    # Equivalent to the former ``print >> sys.stderr`` statements:
    # same text, same stream, one trailing newline each.
    sys.stderr.write("WARNING: SeqOptimizer apply %s\n" % (optimizer,))
    sys.stderr.write("Traceback:\n")
    traceback.print_exc()
def __init__(self, *opts, **kw):
"""WRITEME"""
......@@ -614,6 +622,27 @@ class NavigatorOptimizer(Optimizer):
"""Abstract class
"""
@staticmethod
def warn(exc, nav, repl_pairs, local_opt):
    """failure_callback for NavigatorOptimizer: print a warning and traceback.

    :param exc: the exception raised during the attempted replacement
    :param nav: the NavigatorOptimizer instance reporting the failure
    :param repl_pairs: the (result, replacement) pairs that failed
    :param local_opt: the LocalOptimizer responsible for the failure
    """
    # Fix: write to stderr instead of stdout, for consistency with
    # warn_inplace and SeqOptimizer.warn -- diagnostic warnings must not
    # pollute the program's standard output.
    sys.stderr.write("WARNING: Optimization failure due to: %s\n" % (local_opt,))
    sys.stderr.write("TRACEBACK:\n")
    traceback.print_exc()
@staticmethod
def warn_inplace(exc, nav, repl_pairs, local_opt):
    """failure_callback for NavigatorOptimizer: ignore InconsistencyErrors,
    print a traceback for any other failure.
    """
    # An InconsistencyError is a tolerated outcome here (per this
    # callback's contract) -- bail out without reporting anything.
    if isinstance(exc, InconsistencyError):
        return
    # NOTE: the double space after the colon reproduces the original
    # print-statement output ("...due to: " plus print's separator).
    sys.stderr.write("WARNING: Optimization failure due to:  %s\n" % (local_opt,))
    sys.stderr.write("TRACEBACK:\n")
    traceback.print_exc()
@staticmethod
def warn_ignore(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: ignore all errors
"""
pass
def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None):
"""
......@@ -706,7 +735,7 @@ class NavigatorOptimizer(Optimizer):
replacements = lopt.transform(node)
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(e, self, [(x, None) for x in node.outputs])
self.failure_callback(e, self, [(x, None) for x in node.outputs], lopt)
return False
else:
raise
......@@ -722,7 +751,7 @@ class NavigatorOptimizer(Optimizer):
# This is not supposed to happen. The default failure_callback will print a
# traceback as a warning.
if self.failure_callback is not None:
self.failure_callback(e, self, repl_pairs)
self.failure_callback(e, self, repl_pairs, lopt)
return False
else:
raise
......@@ -875,16 +904,6 @@ class EquilibriumOptimizer(NavigatorOptimizer):
print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out"
def keep_going(exc, nav, repl_pairs):
    """Failure callback that silently ignores every optimization failure.

    Superseded in this revision by NavigatorOptimizer.warn_ignore, which
    additionally receives the responsible local optimizer.
    """
    pass
import traceback
def warn(exc, nav, repl_pairs):
    """Failure callback that prints the active exception's traceback.

    Superseded in this revision by the staticmethod variants that also
    report which optimizer caused the failure.
    """
    traceback.print_exc()
#################
### Utilities ###
......
......@@ -117,12 +117,15 @@ class EquilibriumDB(DB):
def query(self, *tags, **kwtags):
opts = super(EquilibriumDB, self).query(*tags, **kwtags)
return opt.EquilibriumOptimizer(opts, max_depth = 5, max_use_ratio = 10, failure_callback = opt.warn)
return opt.EquilibriumOptimizer(opts,
max_depth=5,
max_use_ratio=10,
failure_callback=opt.NavigatorOptimizer.warn)
class SequenceDB(DB):
def __init__(self, failure_callback = opt.warn):
def __init__(self, failure_callback = opt.SeqOptimizer.warn):
super(SequenceDB, self).__init__()
self.__priority__ = {}
self.failure_callback = failure_callback
......
......@@ -14,7 +14,7 @@ from theano.gof.toolbox import ReplaceValidate
from copy import copy
PatternOptimizer = lambda p1, p2, ign=True: OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
OpSubOptimizer = lambda op1, op2, fail=keep_going, ign=True: TopoOptimizer(OpSub(op1, op2), ignore_newtrees=ign, failure_callback = fail)
OpSubOptimizer = lambda op1, op2, fail=NavigatorOptimizer.warn_ignore, ign=True: TopoOptimizer(OpSub(op1, op2), ignore_newtrees=ign, failure_callback = fail)
def as_result(x):
......@@ -89,7 +89,7 @@ class FailureWatch:
# when passed to OpSubOptimizer or PatternOptimizer, counts the number of failures
def __init__(self):
    # Count of optimization failures observed so far.
    self.failures = 0
def __call__(self, exc, nav, pairs):
def __call__(self, exc, nav, pairs, lopt):
    # Used as a failure_callback: only InconsistencyError is an
    # expected/tolerated failure kind -- anything else is a test bug.
    assert isinstance(exc, InconsistencyError)
    self.failures += 1
......
""" Mode that runs all nodes, even those which have been optimized out.
A basic premise of how theano works is that every node that is replaced during optimization should compute the same thing as its replacement.
Normally theano's optimizations work by running such replacements instead of the originals.
This debugging tool does a different thing. It runs the original and the replacement, and then
checks that they both compute the same thing.
If their values are different, the optimization that created the replacement is probably
broken.
""" Provides `OptCheck`
"""
......@@ -480,6 +472,17 @@ class OptCheckFunctionMaker(FunctionMaker):
return fn
class OptCheck(Mode):
"""Evaluation Mode that detects optimization errors.
A basic premise of how theano works is that every node that is replaced during optimization should compute the same thing as its replacement.
Normally such replacements run instead of the originals.
This Mode runs the original and the replacement, and then checks that they both compute the
same thing.
If their values are different, the optimization that created the replacement is probably
broken.
"""
# This function will be used to create a FunctionMaker in
# function_module.function
def function_maker(self, i,o,m, *args, **kwargs):
......
......@@ -4,7 +4,7 @@ import os, sys, traceback
import numpy
from ..gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler,
SeqOptimizer, warn, local_optimizer, LocalOptimizer, OpKeyOptimizer,
SeqOptimizer, local_optimizer, LocalOptimizer, OpKeyOptimizer,
InconsistencyError)
from ..printing import pprint, FunctionPrinter
from .opt import register_specialize, out2in, insert_inplace_optimizer
......
......@@ -284,7 +284,8 @@ def local_inplace_setsubtensor(node):
new_node = new_op(*node.inputs)
return [new_node]
return False
compile.optdb.register('inplace_setsubtensor', TopoOptimizer(local_inplace_setsubtensor), 60, 'fast_run', 'inplace') #DEBUG
compile.optdb.register('inplace_setsubtensor', TopoOptimizer(local_inplace_setsubtensor,
failure_callback=TopoOptimizer.warn_inplace), 60, 'fast_run', 'inplace') #DEBUG
##################
# Reshape opts #
......@@ -833,7 +834,13 @@ def local_greedy_distributor(node):
new_num += num
new_denum += denum
return [local_mul_canonizer.merge_num_denum(new_num, new_denum)]
rval = local_mul_canonizer.merge_num_denum(new_num, new_denum)
if rval.type != out.type:
#WHY DOES THIS HAPPEN?
return False
return [rval]
register_canonicalize(local_greedy_distributor)
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment