提交 eefa42ac authored 作者: Nicolas Ballas's avatar Nicolas Ballas

move optim in specialize device

上级 e980fdd0
...@@ -2455,6 +2455,11 @@ def local_abstractconv_cudnn(node): ...@@ -2455,6 +2455,11 @@ def local_abstractconv_cudnn(node):
inp1 = node.inputs[0] inp1 = node.inputs[0]
inp2 = node.inputs[1] inp2 = node.inputs[1]
if ((not isinstance(node.op, AbstractConv2d) or
not isinstance(node.op, AbstractConv2d_gradWeights) or
not isinstance(node.op, AbstractConv2d_gradInputs))):
return None
if not isinstance(inp1.type, CudaNdarrayType) or \ if not isinstance(inp1.type, CudaNdarrayType) or \
not isinstance(inp2.type, CudaNdarrayType): not isinstance(inp2.type, CudaNdarrayType):
return None return None
......
...@@ -2784,22 +2784,22 @@ def local_abstractconv_gradinputs_gemm(node): ...@@ -2784,22 +2784,22 @@ def local_abstractconv_gradinputs_gemm(node):
# which ones take precedence over others. # which ones take precedence over others.
abstractconv_groupopt = theano.gof.optdb.LocalGroupDB() abstractconv_groupopt = theano.gof.optdb.LocalGroupDB()
abstractconv_groupopt.__name__ = "gpu_abstractconv_opts" abstractconv_groupopt.__name__ = "gpu_abstractconv_opts"
register_opt()(abstractconv_groupopt) register_specialize_device()(abstractconv_groupopt)
# cuDNN is first, but only registered if cuDNN is available. # cuDNN is first, but only registered if cuDNN is available.
conv_groupopt.register('local_abstractconv_dnn', dnn.local_abstractconv_cudnn, 20, conv_groupopt.register('local_abstractconv_dnn', dnn.local_abstractconv_cudnn, 20,
'conv_dnn', 'conv_dnn',
'fast_compile', 'fast_run', 'cudnn') 'gpu_opt', 'cudnn')
# The GEMM-based convolution comes last to catch all remaining cases. # The GEMM-based convolution comes last to catch all remaining cases.
# It can be disabled by excluding 'conv_gemm'. # It can be disabled by excluding 'conv_gemm'.
conv_groupopt.register('local_abstractconv_gemm', local_abstractconv_gemm, 30, conv_groupopt.register('local_abstractconv_gemm', local_abstractconv_gemm, 30,
'conv_gemm', 'conv_gemm',
'fast_compile', 'fast_run') 'gpu_opt')
conv_groupopt.register('local_abstractconv_gradweight_gemm', conv_groupopt.register('local_abstractconv_gradweight_gemm',
local_abstractconv_gradweight_gemm, 30, local_abstractconv_gradweight_gemm, 30,
#'conv_gemm', 'conv_gemm',
'fast_compile', 'fast_run') 'fast_compile', 'fast_run')
conv_groupopt.register('local_abstractconv_gradinputs_gemm', conv_groupopt.register('local_abstractconv_gradinputs_gemm',
local_abstractconv_gradinputs_gemm, 30, local_abstractconv_gradinputs_gemm, 30,
#'conv_gemm', 'conv_gemm',
'fast_compile', 'fast_run') 'gpu_opt')
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论