Commit f34471c4, authored by Frederic

Add string name to some optimizers

Parent: 9456574b
@@ -555,6 +555,9 @@ class MergeOptimizer(Optimizer):
# clear blacklist # clear blacklist
fgraph.merge_feature.blacklist = [] fgraph.merge_feature.blacklist = []
def __str__(self):
    """Print this optimizer as its class name (e.g. 'MergeOptimizer')."""
    return type(self).__name__
# Ready-to-use module-level instance of MergeOptimizer.
merge_optimizer = MergeOptimizer()
@@ -1302,6 +1305,10 @@ class TopoOptimizer(NavigatorOptimizer):
raise raise
self.detach_updater(fgraph, u) self.detach_updater(fgraph, u)
def __str__(self):
    """Print this optimizer as its assigned ``__name__`` when one was set
    (e.g. via the ``name`` keyword of ``in2out``/``out2in``); otherwise
    fall back to a generic placeholder.
    """
    fallback = '<TopoOptimizer instance>'
    return getattr(self, '__name__', fallback)
class OpKeyOptimizer(NavigatorOptimizer): class OpKeyOptimizer(NavigatorOptimizer):
"""WRITEME""" """WRITEME"""
......
@@ -47,29 +47,43 @@ theano.configparser.AddConfigVar('on_shape_error',
# Utilities # Utilities
def out2in(*local_opts, **kwargs):
    """Build a ``TopoOptimizer`` that applies ``local_opts`` from the graph
    outputs toward the inputs (``order='out_to_in'``).

    Any remaining keyword arguments are forwarded to ``opt.TopoOptimizer``.
    The optional ``name`` keyword is popped and stored on the returned
    optimizer as ``__name__`` (read back by ``TopoOptimizer.__str__``).

    NOTE(review): reconstructed from a side-by-side diff view; the nesting
    of the ``if not name`` default under the single-optimizer branch should
    be confirmed against the repository.
    """
    # Equivalent to the old ``(kwargs and kwargs.pop('name', None))`` —
    # both yield a falsy value when no name was given.
    name = kwargs.pop('name', None)
    if len(local_opts) > 1:
        local_opts = opt.LocalOptGroup(*local_opts)
    else:
        # Don't wrap it uselessly if there is only 1 optimization.
        local_opts, = local_opts
        if not name:
            # Default to the single local optimizer's own name.
            name = local_opts.__name__
    ret = opt.TopoOptimizer(local_opts,
                            order='out_to_in',
                            failure_callback=TopoOptimizer.warn_inplace,
                            **kwargs)
    if name:
        ret.__name__ = name
    return ret
def in2out(*local_opts, **kwargs):
    """Build a ``TopoOptimizer`` that applies ``local_opts`` from the graph
    inputs toward the outputs (``order='in_to_out'``).

    Any remaining keyword arguments are forwarded to ``opt.TopoOptimizer``.
    The optional ``name`` keyword is popped and stored on the returned
    optimizer as ``__name__`` (read back by ``TopoOptimizer.__str__``).

    NOTE(review): reconstructed from a side-by-side diff view; the nesting
    of the ``if not name`` default under the single-optimizer branch should
    be confirmed against the repository.
    """
    # Equivalent to the old ``(kwargs and kwargs.pop('name', None))`` —
    # both yield a falsy value when no name was given.
    name = kwargs.pop('name', None)
    if len(local_opts) > 1:
        local_opts = opt.LocalOptGroup(*local_opts)
    else:
        # Don't wrap it uselessly if there is only 1 optimization.
        local_opts, = local_opts
        if not name:
            # Default to the single local optimizer's own name.
            name = local_opts.__name__
    ret = opt.TopoOptimizer(local_opts,
                            order='in_to_out',
                            failure_callback=TopoOptimizer.warn_inplace,
                            **kwargs)
    if name:
        ret.__name__ = name
    return ret
def _fill_chain(new_out, orig_inputs): def _fill_chain(new_out, orig_inputs):
@@ -3717,7 +3731,8 @@ register_specialize(local_add_specialize)
# mul_to_neg = out2in(gof.LocalOptGroup(local_mul_to_neg))
# Group the mul canonicalization with the fill cut/sink rewrites; the
# explicit name is stored as the optimizer's __name__ (used by
# TopoOptimizer.__str__).
mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut,
                                         local_fill_sink),
                       name='mul_canonizer_groups')
def check_for_x_over_absX(numerators, denominators): def check_for_x_over_absX(numerators, denominators):
@@ -3859,7 +3874,8 @@ def add_calculate(num, denum, aslist=False, out_type=None):
local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate)
# Group the add canonicalization with the fill cut/sink rewrites; the
# explicit name is stored as the optimizer's __name__ (used by
# TopoOptimizer.__str__).
add_canonizer = in2out(gof.LocalOptGroup(local_add_canonizer, local_fill_cut,
                                         local_fill_sink),
                       name='add_canonizer_group')
register_canonicalize(local_add_canonizer, name='local_add_canonizer')
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment