Commit 973ca763 authored by affanv14

basic implementation of forward pass

Parent commit: b7b5dd39
...@@ -1131,13 +1131,20 @@ class LocalMetaOptimizer(LocalOptimizer): ...@@ -1131,13 +1131,20 @@ class LocalMetaOptimizer(LocalOptimizer):
""" """
def __init__(self, tracks=None, optimizers=()): def __init__(self, optimizers=()):
self._tracks = tracks self._tracks = [x for o in optimizers for x in o.tracks()]
self.optimizers = list(optimizers) self.optimizers = list(optimizers)
self.verbose = config.metaopt.verbose self.verbose = config.metaopt.verbose
self.track_dict = defaultdict(lambda: [])
for o in optimizers:
for c in o.tracks():
self.track_dict[c].append(o)
def register(self, optimizer): def register(self, optimizer):
self.optimizers.append(optimizer) self.optimizers.append(optimizer)
for c in optimizer.tracks():
self.track_dict[c].append(optimizer)
def tracks(self): def tracks(self):
return self._tracks return self._tracks
...@@ -1178,7 +1185,7 @@ class LocalMetaOptimizer(LocalOptimizer): ...@@ -1178,7 +1185,7 @@ class LocalMetaOptimizer(LocalOptimizer):
print(("%s meta-optimizing %s (%d choices):" % print(("%s meta-optimizing %s (%d choices):" %
(self.__class__.__name__, node, len(self.optimizers)))) (self.__class__.__name__, node, len(self.optimizers))))
timings = [] timings = []
for opt in self.optimizers: for opt in (self.track_dict[type(node.op)] + self.track_dict[node.op]):
outputs = opt.transform(node) outputs = opt.transform(node)
if outputs: if outputs:
try: try:
...@@ -2313,7 +2320,6 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2313,7 +2320,6 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.final_optimizers = [] self.final_optimizers = []
self.cleanup_optimizers = [] self.cleanup_optimizers = []
self.tracks_on_change_inputs = tracks_on_change_inputs self.tracks_on_change_inputs = tracks_on_change_inputs
for opt in optimizers: for opt in optimizers:
if isinstance(opt, LocalOptimizer): if isinstance(opt, LocalOptimizer):
if opt.tracks() is None: if opt.tracks() is None:
......
...@@ -15,6 +15,7 @@ from theano.compile.ops import shape_i ...@@ -15,6 +15,7 @@ from theano.compile.ops import shape_i
from theano.gof import (local_optimizer, EquilibriumDB, TopoOptimizer, from theano.gof import (local_optimizer, EquilibriumDB, TopoOptimizer,
LocalGroupDB, LocalGroupDB,
SequenceDB, Optimizer, DB, toolbox, graph) SequenceDB, Optimizer, DB, toolbox, graph)
from theano.gof.opt import LocalMetaOptimizer
from theano.ifelse import IfElse from theano.ifelse import IfElse
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
...@@ -1776,6 +1777,35 @@ def local_abstractconv3d_gradinputs_gemm(node): ...@@ -1776,6 +1777,35 @@ def local_abstractconv3d_gradinputs_gemm(node):
return [rval] return [rval]
class LocalCudaMetaOptimizer(LocalMetaOptimizer):
    """Base class for CUDA meta-optimizers.

    Overrides the timing hook so that measured time includes the GPU
    work: the first output is synchronized before the clock is read.
    """

    def time_call(self, fn):
        # Run the candidate and block until the device result is ready,
        # so the elapsed wall-clock time covers the actual GPU execution.
        t0 = time.time()
        fn()[0].sync()
        return time.time() - t0
class ConvMetaOptimizer(LocalCudaMetaOptimizer):
    """Meta-optimizer that benchmarks convolution implementations.

    Supplies random shared-variable inputs for a convolution node so the
    registered candidate optimizers can be timed on concrete data.
    """

    def __init__(self, optimizers):
        super(ConvMetaOptimizer, self).__init__(optimizers)

    def provide_inputs(self, node, inputs):
        """Build timing inputs for *node*.

        Returns a dict mapping the node's image and kernel variables to
        shared variables filled with random data of the op's declared
        static shapes.  Returns an empty dict when either shape is
        missing or partially unknown, since no concrete array can be
        constructed in that case.
        """
        result = {}
        img, kern = node.inputs
        # NOTE: renamed from 'vars' to avoid shadowing the builtin.
        variables = (img, kern)
        shapes = (node.op.imshp, node.op.kshp)
        # Bail out unless both shapes are fully static (no None entries).
        if (node.op.imshp is None or node.op.kshp is None or
                any(s is None for shape in shapes for s in shape)):
            return result
        for var, shape in zip(variables, shapes):
            result[var] = theano.shared(
                np.random.random(shape).astype(theano.config.floatX),
                var.name, borrow=True)
        return result
# This deals with any abstract convs that have a transfer somewhere # This deals with any abstract convs that have a transfer somewhere
@register_opt('fast_compile', 'conv_dnn', 'cudnn') @register_opt('fast_compile', 'conv_dnn', 'cudnn')
@op_lifter([AbstractConv2d, @op_lifter([AbstractConv2d,
...@@ -2356,6 +2386,12 @@ register_opt('fast_compile')(abstractconv_groupopt) ...@@ -2356,6 +2386,12 @@ register_opt('fast_compile')(abstractconv_groupopt)
# to avoid a circular dependency problem with dnn # to avoid a circular dependency problem with dnn
from .dnn import (local_abstractconv_cudnn, local_abstractconv_gw_cudnn, from .dnn import (local_abstractconv_cudnn, local_abstractconv_gw_cudnn,
local_abstractconv_gi_cudnn) # noqa: 402 local_abstractconv_gi_cudnn) # noqa: 402
# Register the conv meta-optimizer over the GEMM and cuDNN candidates.
# NOTE(review): position 0 presumably makes it run before the individual
# conv optimizers registered below — confirm against LocalGroupDB.register.
conv_metaopt = ConvMetaOptimizer((local_abstractconv_gemm,
local_abstractconv_cudnn))
abstractconv_groupopt.register('conv_metaopt', conv_metaopt, 0, 'conv_meta')
abstractconv_groupopt.register('local_abstractconv_dnn', abstractconv_groupopt.register('local_abstractconv_dnn',
local_abstractconv_cudnn, 20, local_abstractconv_cudnn, 20,
'conv_dnn', 'conv_dnn',
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed carefully.
Please finish editing this comment first!
Register or sign in to post a comment