提交 29024484 作者: f0k

Slightly cleaned up registration of GPU convolution optimizers and corresponding tags

上级 ea9e3e54
...@@ -1087,7 +1087,7 @@ if cuda_available: ...@@ -1087,7 +1087,7 @@ if cuda_available:
from theano.sandbox.cuda.opt import ( from theano.sandbox.cuda.opt import (
local_optimizer, gpu_optimizer, gpu_seqopt) local_optimizer, gpu_optimizer, gpu_seqopt)
@register_opt('cudnn') #@register_opt('cudnn') # this optimizer is registered in opt.py instead.
@local_optimizer([GpuConv]) @local_optimizer([GpuConv])
def local_conv_dnn(node): def local_conv_dnn(node):
if not dnn_available(): if not dnn_available():
......
...@@ -1105,12 +1105,9 @@ def local_gpu_softmax_with_bias(node): ...@@ -1105,12 +1105,9 @@ def local_gpu_softmax_with_bias(node):
return [host_from_gpu(gpu_sm)] return [host_from_gpu(gpu_sm)]
return False return False
# Convolution, maxpooling
# Convolution
from theano.tensor.nnet import conv from theano.tensor.nnet import conv
# We need a fixed order for the user interface.
conv_groupopt = theano.gof.optdb.LocalGroupDB()
conv_groupopt.__name__ = "gpu_conv_opts"
register_opt('fast_compile', 'fast_run', 'gpu')(conv_groupopt)
def _gpu_conv_to_fftconv(node): def _gpu_conv_to_fftconv(node):
...@@ -1163,22 +1160,8 @@ def local_conv_fft_full(node): ...@@ -1163,22 +1160,8 @@ def local_conv_fft_full(node):
return return
@local_optimizer([GpuConv])
def local_gpu_conv(node):
"""
If cudnn is available, use it. Otherwise, use the gemm version.
"""
if (isinstance(node.op, GpuConv) and
theano.sandbox.cuda.dnn.dnn_available()):
return theano.sandbox.cuda.dnn.local_conv_dnn.transform(node)
# If dnn isn't avail, the local_gpu_conv_legacy wil introduce the
# legacy opt. Then the local_conv_gemm will convert it to gemm
# opt.
@local_optimizer([gpu_from_host, conv.ConvOp]) @local_optimizer([gpu_from_host, conv.ConvOp])
def local_gpu_conv_legacy(node): def local_gpu_conv(node):
""" """
gpu_from_host(conv) -> gpu_conv(gpu_from_host) gpu_from_host(conv) -> gpu_conv(gpu_from_host)
...@@ -1334,19 +1317,31 @@ def local_conv_gemm(node): ...@@ -1334,19 +1317,31 @@ def local_conv_gemm(node):
gpu_contiguous(kern), gpu_contiguous(img))] gpu_contiguous(kern), gpu_contiguous(img))]
# Legacy opt first, as this is the only that move to the GPU. # First we register the optimizer that moves convolutions to the GPU.
# Then fft, as disabled dy default. So if use enable it, it have prio register_opt()(local_gpu_conv)
# Then default, use dnn if avail
# Then default, use gemm if dnn or fft didn't worked. # Then we create a group of optimizers that replace the legacy GpuConv
# Normally, gemm should catch all case, so the legacy should never run. # with other implementations. They are tried in a specific order so we
conv_groupopt.register('local_gpu_conv_legacy', local_gpu_conv_legacy, 0, # can control which ones take precedence over others.
'fast_compile', 'fast_run') conv_groupopt = theano.gof.optdb.LocalGroupDB()
conv_groupopt.register("conv_fft_valid", local_conv_fft_valid, 1) conv_groupopt.__name__ = "gpu_conv_opts"
conv_groupopt.register("conv_fft_full", local_conv_fft_full, 1) register_opt()(conv_groupopt)
# Use dnn if avail, so have the dnn tag to be able to disable it.
conv_groupopt.register('local_gpu_conv', local_gpu_conv, 10, # FFT gets the highest priority (lowest number), but is disabled by default.
# It can be enabled by including 'conv_fft'.
conv_groupopt.register('conv_fft_valid', local_conv_fft_valid, 10,
'conv_fft')
conv_groupopt.register('conv_fft_full', local_conv_fft_full, 10,
'conv_fft')
# cuDNN is the second, but only registered if cuDNN is available.
# It can be disabled by excluding 'conv_dnn' or 'cudnn'.
from . import dnn
if dnn.dnn_available():
conv_groupopt.register('conv_dnn', dnn.local_conv_dnn, 20,
'fast_compile', 'fast_run', 'cudnn') 'fast_compile', 'fast_run', 'cudnn')
conv_groupopt.register('local_conv_gemm', local_conv_gemm, 12, # The GEMM-based convolution comes last to catch all remaining cases.
# It can be disabled by excluding 'conv_gemm'.
conv_groupopt.register('conv_gemm', local_conv_gemm, 30,
'fast_compile', 'fast_run') 'fast_compile', 'fast_run')
...@@ -1500,6 +1495,7 @@ def local_convtransp3d_gemm(node): ...@@ -1500,6 +1495,7 @@ def local_convtransp3d_gemm(node):
gpu_optimizer.register("convtransp3d_gemm", local_convtransp3d_gemm) gpu_optimizer.register("convtransp3d_gemm", local_convtransp3d_gemm)
# Pooling
import theano.tensor.signal.downsample as downsample import theano.tensor.signal.downsample as downsample
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论