提交 46423b3e authored 作者: Frederic's avatar Frederic

Merge 2 very similar classes and keep just the old one.

上级 798fc9d1
...@@ -823,68 +823,6 @@ class LocalOptimizer(object): ...@@ -823,68 +823,6 @@ class LocalOptimizer(object):
(' ' * level), self.__class__.__name__, id(self)) (' ' * level), self.__class__.__name__, id(self))
class LocalSeqOptimizer(LocalOptimizer, list):
    """Try a group of local optimizers in sequence.

    The optimizers are attempted in order; as soon as one of them
    performs a replacement, its result is returned and the remaining
    optimizers are skipped.
    """
    # inherit from Optimizer first to get Optimizer.__hash__

    def __init__(self, *opts, **kw):
        """Store the given local optimizers as the list contents.

        A single list/tuple argument is unpacked, so both
        ``LocalSeqOptimizer(a, b)`` and ``LocalSeqOptimizer([a, b])``
        work.  The only recognized keyword is ``failure_callback``.
        """
        if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
            opts = opts[0]
        self[:] = opts
        self.failure_callback = kw.pop('failure_callback', None)

    def tracks(self):
        # Union of the tracks of all contained optimizers.
        tracked = []
        for opt in self:
            sub_tracks = opt.tracks()
            if sub_tracks:
                tracked.extend(sub_tracks)
        return tracked

    def transform(self, node):
        """Try each optimizer on `node` and return the first result.

        Each contained optimizer may return one of:
          - False, meaning it cannot be applied to this `node`;
          - a list of variables to use in place of `node`'s outputs
            in the greater graph;
          - a dict mapping old variables to the new variables that
            should replace them.
        The first truthy result wins; if none applies, returns None.

        :type node: an Apply instance
        """
        for opt in self:
            replacement = opt.transform(node)
            if replacement:
                return replacement

    def add_requirements(self, fgraph):
        """Forward fgraph-requirement registration to every contained
        local optimizer.
        """
        for opt in self:
            opt.add_requirements(fgraph)

    def print_summary(self, stream=sys.stdout, level=0, depth=-1):
        name = getattr(self, 'name', None)
        print >> stream, "%s%s %s id=%i" % (
            (' ' * level), self.__class__.__name__, name, id(self))
        # depth == -1 means "no limit": it never reaches 0 when
        # decremented, so the recursion covers all levels.
        if depth != 0:
            for opt in self:
                opt.print_summary(stream, level=(level + 2),
                                  depth=(depth - 1))
class FromFunctionLocalOptimizer(LocalOptimizer): class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME""" """WRITEME"""
def __init__(self, fn, tracks=None, requirements=()): def __init__(self, fn, tracks=None, requirements=()):
...@@ -934,6 +872,9 @@ class LocalOptGroup(LocalOptimizer): ...@@ -934,6 +872,9 @@ class LocalOptGroup(LocalOptimizer):
"""WRITEME""" """WRITEME"""
def __init__(self, *optimizers): def __init__(self, *optimizers):
if len(optimizers) == 1 and isinstance(optimizers[0], list):
# This happen when created by LocalGroupDB.
optimizers = tuple(optimizers[0])
self.opts = optimizers self.opts = optimizers
self.reentrant = any(getattr(opt, 'reentrant', True) self.reentrant = any(getattr(opt, 'reentrant', True)
for opt in optimizers) for opt in optimizers)
......
...@@ -257,7 +257,10 @@ class SequenceDB(DB): ...@@ -257,7 +257,10 @@ class SequenceDB(DB):
# the order we want. # the order we want.
opts.sort(key=lambda obj: obj.name) opts.sort(key=lambda obj: obj.name)
opts.sort(key=lambda obj: self.__position__[obj.name]) opts.sort(key=lambda obj: self.__position__[obj.name])
ret = self.seq_opt(opts, failure_callback=self.failure_callback) kwargs = {}
if self.failure_callback:
kwargs["failure_callback"] = self.failure_callback
ret = self.seq_opt(opts, **kwargs)
if hasattr(tags[0], 'name'): if hasattr(tags[0], 'name'):
ret.name = tags[0].name ret.name = tags[0].name
return ret return ret
...@@ -280,11 +283,17 @@ class SequenceDB(DB): ...@@ -280,11 +283,17 @@ class SequenceDB(DB):
return sio.getvalue() return sio.getvalue()
class LocalGroupDB(SequenceDB):
    """Generate a local optimizer of type LocalOptGroup instead of a
    global optimizer.

    It supports tracks, so it can be applied only to some Ops.
    """
    # Factory used by SequenceDB.query() to build the returned optimizer.
    seq_opt = opt.LocalOptGroup

    def __init__(self, failure_callback=opt.SeqOptimizer.warn):
        super(LocalGroupDB, self).__init__()
        # LocalOptGroup does not take a failure_callback, so leave it
        # unset; query() only forwards it when it is truthy.
        self.failure_callback = None
class ProxyDB(DB): class ProxyDB(DB):
......
...@@ -1108,9 +1108,9 @@ def local_gpu_softmax_with_bias(node): ...@@ -1108,9 +1108,9 @@ def local_gpu_softmax_with_bias(node):
# Convolution, maxpooling # Convolution, maxpooling
from theano.tensor.nnet import conv from theano.tensor.nnet import conv
# We need a fixed order for the user interface. # We need a fixed order for the user interface.
conv_seqopt = theano.gof.optdb.LocalSequenceDB() conv_groupopt = theano.gof.optdb.LocalGroupDB()
conv_seqopt.__name__ = "nnn" conv_groupopt.__name__ = "gpu_conv_opt"
register_opt('fast_compile', 'fast_run', 'gpu')(conv_seqopt) register_opt('fast_compile', 'fast_run', 'gpu')(conv_groupopt)
def _gpu_conv_to_fftconv(node): def _gpu_conv_to_fftconv(node):
# shared helper function for local_conv_fft_valid and local_conv_fft_full. # shared helper function for local_conv_fft_valid and local_conv_fft_full.
...@@ -1384,16 +1384,16 @@ def local_conv_gemm(node): ...@@ -1384,16 +1384,16 @@ def local_conv_gemm(node):
# fft optimization not enabled by default. Need to be registered # fft optimization not enabled by default. Need to be registered
# before the default convolution optimization. If the user ask fft, as # before the default convolution optimization. If the user ask fft, as
# this isn't the default, it should have higher prio then the default. # this isn't the default, it should have higher prio then the default.
conv_seqopt.register("conv_fft_valid", local_conv_fft_valid, 1) conv_groupopt.register("conv_fft_valid", local_conv_fft_valid, 1)
conv_seqopt.register("conv_fft_full", local_conv_fft_full, 1) conv_groupopt.register("conv_fft_full", local_conv_fft_full, 1)
# default gpu conv optimization # default gpu conv optimization
conv_seqopt.register('local_gpu_conv', local_gpu_conv, 10, conv_groupopt.register('local_gpu_conv', local_gpu_conv, 10,
'fast_compile', 'fast_run', "dnn") 'fast_compile', 'fast_run', "dnn")
# Legacy convolution, after default # Legacy convolution, after default
conv_seqopt.register('local_gpu_conv_legacy', local_gpu_conv_legacy, 11, conv_groupopt.register('local_gpu_conv_legacy', local_gpu_conv_legacy, 11,
'fast_compile', 'fast_run', "dnn") 'fast_compile', 'fast_run', "dnn")
# conv gemm after legacy, as it convert legacy to gemm version # conv gemm after legacy, as it convert legacy to gemm version
conv_seqopt.register('local_conv_gemm', local_conv_gemm, 12, conv_groupopt.register('local_conv_gemm', local_conv_gemm, 12,
'fast_compile', 'fast_run', "dnn") 'fast_compile', 'fast_run', "dnn")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论