提交 ff72a0f7 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed the optimizers in situations of failure

上级 8dd464d2
...@@ -491,7 +491,14 @@ class NavigatorOptimizer(Optimizer): ...@@ -491,7 +491,14 @@ class NavigatorOptimizer(Optimizer):
env.remove_feature(u) env.remove_feature(u)
def process_node(self, env, node): def process_node(self, env, node):
try:
replacements = self.local_opt.transform(node) replacements = self.local_opt.transform(node)
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(e, self, [(x, None) for x in node.outputs])
return
else:
raise
if replacements is False or replacements is None: if replacements is False or replacements is None:
return return
repl_pairs = zip(node.outputs, replacements) repl_pairs = zip(node.outputs, replacements)
......
...@@ -337,12 +337,15 @@ def tensor(*args, **kwargs): ...@@ -337,12 +337,15 @@ def tensor(*args, **kwargs):
return Tensor(*args, **kwargs).make_result() return Tensor(*args, **kwargs).make_result()
def _multi(*fns): def _multi(*fns):
def f2(f, names): def f2(f, *names):
if isinstance(names, int): if isinstance(names, int):
if names == 1: if names == 1:
return f() return f()
else: else:
return [f() for i in xrange(names)] return [f() for i in xrange(names)]
if isinstance(names, tuple):
if len(names) == 1:
names = names[0]
if len(names) == 1: if len(names) == 1:
return f(names) return f(names)
else: else:
......
# TODO: intelligent merge for mul/add
# TODO: 0*x -> 0
import gof import gof
from gof import opt from gof import opt
from elemwise import Elemwise, DimShuffle from elemwise import Elemwise, DimShuffle
...@@ -13,10 +17,14 @@ import sys ...@@ -13,10 +17,14 @@ import sys
# Utilities # Utilities
def out2in(*local_opts): def out2in(*local_opts):
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), order = 'out_to_in') return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts),
order = 'out_to_in',
failure_callback = lambda exc,opt,pairs: None)
def in2out(*local_opts): def in2out(*local_opts):
return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts), order = 'in_to_out') return opt.TopoOptimizer(opt.LocalOptGroup(*local_opts),
order = 'in_to_out',
failure_callback = lambda exc,opt,pairs: None)
# gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c) # gemm: (d,a,b,c,s) -> d = d*s + a*dot(b,c)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论