提交 19487cd3 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

made the failure_callback for the optdb display errors

上级 8577d9ac
...@@ -23,7 +23,7 @@ from opt import \ ...@@ -23,7 +23,7 @@ from opt import \
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \ LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \
OpSub, OpRemove, PatternSub, \ OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, EquilibriumOptimizer, \ NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, EquilibriumOptimizer, \
keep_going, \ keep_going, warn, \
InplaceOptimizer, PureThenInplaceOptimizer InplaceOptimizer, PureThenInplaceOptimizer
from optdb import \ from optdb import \
......
...@@ -820,10 +820,15 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -820,10 +820,15 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def keep_going(exc, nav, repl_pairs): def keep_going(exc, nav, repl_pairs):
"""WRITEME""" """WRITEME"""
print exc, nav, repl_pairs
pass pass
import traceback
def warn(exc, nav, repl_pairs):
"""WRITEME"""
traceback.print_exc()
################# #################
### Utilities ### ### Utilities ###
################# #################
......
...@@ -98,7 +98,7 @@ class EquilibriumDB(DB): ...@@ -98,7 +98,7 @@ class EquilibriumDB(DB):
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
opts = super(EquilibriumDB, self).query(*tags, **kwtags) opts = super(EquilibriumDB, self).query(*tags, **kwtags)
return opt.EquilibriumOptimizer(opts, max_depth = 5, max_use_ratio = 10, failure_callback = opt.keep_going) return opt.EquilibriumOptimizer(opts, max_depth = 5, max_use_ratio = 10, failure_callback = opt.warn)
class SequenceDB(DB): class SequenceDB(DB):
...@@ -115,6 +115,6 @@ class SequenceDB(DB): ...@@ -115,6 +115,6 @@ class SequenceDB(DB):
opts = super(SequenceDB, self).query(*tags, **kwtags) opts = super(SequenceDB, self).query(*tags, **kwtags)
opts = list(opts) opts = list(opts)
opts.sort(key = lambda obj: self.__priority__[obj.name]) opts.sort(key = lambda obj: self.__priority__[obj.name])
return opt.SeqOptimizer(opts, failure_callback = opt.keep_going) return opt.SeqOptimizer(opts, failure_callback = opt.warn)
...@@ -93,7 +93,7 @@ insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer) ...@@ -93,7 +93,7 @@ insert_inplace_optimizer = gof.optimizer(_insert_inplace_optimizer)
inplace_optimizer = gof.InplaceOptimizer( inplace_optimizer = gof.InplaceOptimizer(
gof.SeqOptimizer(out2in(gemm_pattern_1), gof.SeqOptimizer(out2in(gemm_pattern_1),
insert_inplace_optimizer, insert_inplace_optimizer,
failure_callback = gof.keep_going)) failure_callback = gof.warn))
compile.optdb.register('inplace', inplace_optimizer, 99, 'fast_run') compile.optdb.register('inplace', inplace_optimizer, 99, 'fast_run')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论