提交 f34471c4 authored 作者: Frederic's avatar Frederic

Add string name to some optimizers

上级 9456574b
......@@ -555,6 +555,9 @@ class MergeOptimizer(Optimizer):
# clear blacklist
fgraph.merge_feature.blacklist = []
def __str__(self):
return self.__class__.__name__
merge_optimizer = MergeOptimizer()
......@@ -1302,6 +1305,10 @@ class TopoOptimizer(NavigatorOptimizer):
raise
self.detach_updater(fgraph, u)
def __str__(self):
return getattr(self, '__name__',
'<TopoOptimizer instance>')
class OpKeyOptimizer(NavigatorOptimizer):
"""WRITEME"""
......
......@@ -47,29 +47,43 @@ theano.configparser.AddConfigVar('on_shape_error',
# Utilities
def out2in(*local_opts):
def out2in(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts),
else:
local_opts, = local_opts
return opt.TopoOptimizer(local_opts,
order='out_to_in',
failure_callback=TopoOptimizer.warn_inplace)
if not name:
name = local_opts.__name__
ret = opt.TopoOptimizer(local_opts,
order='out_to_in',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if name:
ret.__name__ = name
return ret
def in2out(*local_opts, **kwargs):
"""WRITEME """
name = (kwargs and kwargs.pop('name', None))
if len(local_opts) > 1:
# Don't wrap it uselessly if their is only 1 optimization.
local_opts = opt.LocalOptGroup(*local_opts),
else:
local_opts, = local_opts
return opt.TopoOptimizer(local_opts,
order='in_to_out',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if not name:
#import pdb;pdb.set_trace()
name = local_opts.__name__
ret = opt.TopoOptimizer(local_opts,
order='in_to_out',
failure_callback=TopoOptimizer.warn_inplace,
**kwargs)
if name:
ret.__name__ = name
return ret
def _fill_chain(new_out, orig_inputs):
......@@ -3717,7 +3731,8 @@ register_specialize(local_add_specialize)
# mul_to_neg = out2in(gof.LocalOptGroup(local_mul_to_neg))
mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut,
local_fill_sink))
local_fill_sink),
name='mul_canonizer_groups')
def check_for_x_over_absX(numerators, denominators):
......@@ -3859,7 +3874,8 @@ def add_calculate(num, denum, aslist=False, out_type=None):
local_add_canonizer = Canonizer(T.add, T.sub, T.neg, add_calculate)
add_canonizer = in2out(gof.LocalOptGroup(local_add_canonizer, local_fill_cut,
local_fill_sink))
local_fill_sink),
name='add_canonizer_group')
register_canonicalize(local_add_canonizer, name='local_add_canonizer')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论