提交 3b622033 authored 作者: James Bergstra's avatar James Bergstra

Modification to failure callbacks. inplace_set_subtensor no longer prints

traceback on error. All failure_callback tracebacks now include the optimization responsible for the problem.
上级 b87a7adf
...@@ -25,7 +25,6 @@ from opt import (Optimizer, optimizer, SeqOptimizer, ...@@ -25,7 +25,6 @@ from opt import (Optimizer, optimizer, SeqOptimizer,
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOptimizer, local_optimizer, LocalOptGroup,
OpSub, OpRemove, PatternSub, OpSub, OpRemove, PatternSub,
NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer, NavigatorOptimizer, TopoOptimizer, EquilibriumOptimizer,
keep_going, warn,
InplaceOptimizer, PureThenInplaceOptimizer, InplaceOptimizer, PureThenInplaceOptimizer,
OpKeyOptimizer) OpKeyOptimizer)
......
...@@ -311,9 +311,9 @@ class Env(utils.object2): ...@@ -311,9 +311,9 @@ class Env(utils.object2):
For every node that uses r as input, makes it use new_r instead. For every node that uses r as input, makes it use new_r instead.
""" """
if r.env is not self: if r.env is not self:
raise Exception("Cannot replace %s because it does not belong to this Env" % r) raise Exception("Cannot replace %s because it does not belong to this Env" % r, str(reason))
if not r.type == new_r.type: if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r, r.type, new_r.type) raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r, r.type, new_r.type, str(reason))
if r not in self.results: if r not in self.results:
# this result isn't in the graph... don't raise an exception here, just return silently # this result isn't in the graph... don't raise an exception here, just return silently
# because it makes it easier to implement some optimizations for multiple-output ops # because it makes it easier to implement some optimizations for multiple-output ops
......
...@@ -14,6 +14,7 @@ from copy import copy ...@@ -14,6 +14,7 @@ from copy import copy
from collections import deque, defaultdict from collections import deque, defaultdict
import destroyhandler as dh import destroyhandler as dh
import sys import sys
import traceback
_optimizer_idx = [0] _optimizer_idx = [0]
...@@ -87,6 +88,13 @@ class SeqOptimizer(Optimizer, list): ...@@ -87,6 +88,13 @@ class SeqOptimizer(Optimizer, list):
Takes a list of L{Optimizer} instances and applies them Takes a list of L{Optimizer} instances and applies them
sequentially. sequentially.
""" """
@staticmethod
def warn(exc, self, optimizer):
"""Default failure_callback for SeqOptimizer
"""
print >> sys.stderr, "WARNING: SeqOptimizer apply", optimizer
print >> sys.stderr, "Traceback:"
traceback.print_exc()
def __init__(self, *opts, **kw): def __init__(self, *opts, **kw):
"""WRITEME""" """WRITEME"""
...@@ -614,6 +622,27 @@ class NavigatorOptimizer(Optimizer): ...@@ -614,6 +622,27 @@ class NavigatorOptimizer(Optimizer):
"""Abstract class """Abstract class
""" """
@staticmethod
def warn(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: print traceback
"""
print "WARNING: Optimization failure due to: ", local_opt
print "TRACEBACK:"
traceback.print_exc()
@staticmethod
def warn_inplace(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: ignore InconsistencyErrors, print traceback
"""
if isinstance(exc, InconsistencyError):
return
print >> sys.stderr, "WARNING: Optimization failure due to: ", local_opt
print >> sys.stderr, "TRACEBACK:"
traceback.print_exc()
@staticmethod
def warn_ignore(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: ignore all errors
"""
pass
def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None): def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None):
""" """
...@@ -706,7 +735,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -706,7 +735,7 @@ class NavigatorOptimizer(Optimizer):
replacements = lopt.transform(node) replacements = lopt.transform(node)
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, [(x, None) for x in node.outputs]) self.failure_callback(e, self, [(x, None) for x in node.outputs], lopt)
return False return False
else: else:
raise raise
...@@ -722,7 +751,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -722,7 +751,7 @@ class NavigatorOptimizer(Optimizer):
# This is not supposed to happen. The default failure_callback will print a # This is not supposed to happen. The default failure_callback will print a
# traceback as a warning. # traceback as a warning.
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, repl_pairs) self.failure_callback(e, self, repl_pairs, lopt)
return False return False
else: else:
raise raise
...@@ -875,16 +904,6 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -875,16 +904,6 @@ class EquilibriumOptimizer(NavigatorOptimizer):
print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out" print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out"
def keep_going(exc, nav, repl_pairs):
"""WRITEME"""
pass
import traceback
def warn(exc, nav, repl_pairs):
"""WRITEME"""
traceback.print_exc()
################# #################
### Utilities ### ### Utilities ###
......
...@@ -117,12 +117,15 @@ class EquilibriumDB(DB): ...@@ -117,12 +117,15 @@ class EquilibriumDB(DB):
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
opts = super(EquilibriumDB, self).query(*tags, **kwtags) opts = super(EquilibriumDB, self).query(*tags, **kwtags)
return opt.EquilibriumOptimizer(opts, max_depth = 5, max_use_ratio = 10, failure_callback = opt.warn) return opt.EquilibriumOptimizer(opts,
max_depth=5,
max_use_ratio=10,
failure_callback=opt.NavigatorOptimizer.warn)
class SequenceDB(DB): class SequenceDB(DB):
def __init__(self, failure_callback = opt.warn): def __init__(self, failure_callback = opt.SeqOptimizer.warn):
super(SequenceDB, self).__init__() super(SequenceDB, self).__init__()
self.__priority__ = {} self.__priority__ = {}
self.failure_callback = failure_callback self.failure_callback = failure_callback
......
...@@ -14,7 +14,7 @@ from theano.gof.toolbox import ReplaceValidate ...@@ -14,7 +14,7 @@ from theano.gof.toolbox import ReplaceValidate
from copy import copy from copy import copy
PatternOptimizer = lambda p1, p2, ign=True: OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign) PatternOptimizer = lambda p1, p2, ign=True: OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
OpSubOptimizer = lambda op1, op2, fail=keep_going, ign=True: TopoOptimizer(OpSub(op1, op2), ignore_newtrees=ign, failure_callback = fail) OpSubOptimizer = lambda op1, op2, fail=NavigatorOptimizer.warn_ignore, ign=True: TopoOptimizer(OpSub(op1, op2), ignore_newtrees=ign, failure_callback = fail)
def as_result(x): def as_result(x):
...@@ -89,7 +89,7 @@ class FailureWatch: ...@@ -89,7 +89,7 @@ class FailureWatch:
# when passed to OpSubOptimizer or PatternOptimizer, counts the number of failures # when passed to OpSubOptimizer or PatternOptimizer, counts the number of failures
def __init__(self): def __init__(self):
self.failures = 0 self.failures = 0
def __call__(self, exc, nav, pairs): def __call__(self, exc, nav, pairs, lopt):
assert isinstance(exc, InconsistencyError) assert isinstance(exc, InconsistencyError)
self.failures += 1 self.failures += 1
......
""" Mode that runs all nodes, even those which have been optimized out. """ Provides `OptCheck`
A basic premise of how theano works is that every node that is replaced during optimization should compute the same thing as its replacement.
Normally theano's optimizations work by running such replacements instead of the originals.
This debugging tool does a different thing. It runs the original and the replacement, and then
checks that they both compute the same thing.
If their values are different, the optimization that created the replacement is probably
broken.
""" """
...@@ -480,6 +472,17 @@ class OptCheckFunctionMaker(FunctionMaker): ...@@ -480,6 +472,17 @@ class OptCheckFunctionMaker(FunctionMaker):
return fn return fn
class OptCheck(Mode): class OptCheck(Mode):
"""Evaluation Mode that detects optimization errors.
A basic premise of how theano works is that every node that is replaced during optimization should compute the same thing as its replacement.
Normally such replacements run instead of the originals.
This Mode runs the original and the replacement, and then checks that they both compute the
same thing.
If their values are different, the optimization that created the replacement is probably
broken.
"""
# This function will be used to create a FunctionMaker in # This function will be used to create a FunctionMaker in
# function_module.function # function_module.function
def function_maker(self, i,o,m, *args, **kwargs): def function_maker(self, i,o,m, *args, **kwargs):
......
...@@ -4,7 +4,7 @@ import os, sys, traceback ...@@ -4,7 +4,7 @@ import os, sys, traceback
import numpy import numpy
from ..gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler, from ..gof import (utils, Op, Apply, view_roots, PatternSub, DestroyHandler,
SeqOptimizer, warn, local_optimizer, LocalOptimizer, OpKeyOptimizer, SeqOptimizer, local_optimizer, LocalOptimizer, OpKeyOptimizer,
InconsistencyError) InconsistencyError)
from ..printing import pprint, FunctionPrinter from ..printing import pprint, FunctionPrinter
from .opt import register_specialize, out2in, insert_inplace_optimizer from .opt import register_specialize, out2in, insert_inplace_optimizer
......
...@@ -284,7 +284,8 @@ def local_inplace_setsubtensor(node): ...@@ -284,7 +284,8 @@ def local_inplace_setsubtensor(node):
new_node = new_op(*node.inputs) new_node = new_op(*node.inputs)
return [new_node] return [new_node]
return False return False
compile.optdb.register('inplace_setsubtensor', TopoOptimizer(local_inplace_setsubtensor), 60, 'fast_run', 'inplace') #DEBUG compile.optdb.register('inplace_setsubtensor', TopoOptimizer(local_inplace_setsubtensor,
failure_callback=TopoOptimizer.warn_inplace), 60, 'fast_run', 'inplace') #DEBUG
################## ##################
# Reshape opts # # Reshape opts #
...@@ -833,7 +834,13 @@ def local_greedy_distributor(node): ...@@ -833,7 +834,13 @@ def local_greedy_distributor(node):
new_num += num new_num += num
new_denum += denum new_denum += denum
return [local_mul_canonizer.merge_num_denum(new_num, new_denum)] rval = local_mul_canonizer.merge_num_denum(new_num, new_denum)
if rval.type != out.type:
#WHY DOES THIS HAPPEN?
return False
return [rval]
register_canonicalize(local_greedy_distributor) register_canonicalize(local_greedy_distributor)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论