提交 527b324e authored 作者: affanv14's avatar affanv14

add conv3d2d optimizer

上级 f499af62
...@@ -24,7 +24,7 @@ from theano.scalar.basic import log, neg, true_div ...@@ -24,7 +24,7 @@ from theano.scalar.basic import log, neg, true_div
from theano.scalar.basic_scipy import Erfinv, Erfcinv from theano.scalar.basic_scipy import Erfinv, Erfcinv
from theano.scan_module import scan_utils, scan_op, scan_opt from theano.scan_module import scan_utils, scan_op, scan_opt
from theano.tensor.nnet import bn from theano.tensor.nnet import bn, conv3d2d
from theano.tensor.nnet.conv import ConvOp from theano.tensor.nnet.conv import ConvOp
from theano.tensor.nnet.blocksparse import SparseBlockGemv, SparseBlockOuter from theano.tensor.nnet.blocksparse import SparseBlockGemv, SparseBlockOuter
from theano.tensor.nnet.abstract_conv import (BaseAbstractConv, from theano.tensor.nnet.abstract_conv import (BaseAbstractConv,
...@@ -1739,6 +1739,32 @@ def local_abstractconv3d_gemm(node): ...@@ -1739,6 +1739,32 @@ def local_abstractconv3d_gemm(node):
return [rval] return [rval]
@local_optimizer([AbstractConv3d])
def local_abstractconv3d2d(node):
if not isinstance(node.op, AbstractConv3d):
return None
img, kern = node.inputs
if (not isinstance(img.type, GpuArrayType) or
not isinstance(kern.type, GpuArrayType)):
return None
ctx = infer_context_name(img, kern)
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
if subsample == (1, 1, 1) and filter_dilation == (1, 1, 1):
rval = conv3d2d.conv3d(gpu_contiguous(img.dimshuffle(0, 2, 1, 3, 4)),
gpu_contiguous(kern.dimshuffle(0, 2, 1, 3, 4)),
border_mode=border_mode)
rval = as_gpuarray_variable(rval.dimshuffle(0, 2, 1, 3, 4),
context_name=ctx)
return [rval]
else:
return None
@local_optimizer([AbstractConv2d_gradWeights]) @local_optimizer([AbstractConv2d_gradWeights])
def local_abstractconv_gradweights_gemm(node): def local_abstractconv_gradweights_gemm(node):
if not isinstance(node.op, AbstractConv2d_gradWeights): if not isinstance(node.op, AbstractConv2d_gradWeights):
...@@ -2578,6 +2604,7 @@ conv_metaopt.register([local_abstractconv_gemm_alternative]) ...@@ -2578,6 +2604,7 @@ conv_metaopt.register([local_abstractconv_gemm_alternative])
conv_metaopt.register([local_abstractconv_gemm_gradweights_alt]) conv_metaopt.register([local_abstractconv_gemm_gradweights_alt])
conv_metaopt.register([local_abstractconv_gradinputs_gemm_alt]) conv_metaopt.register([local_abstractconv_gradinputs_gemm_alt])
conv_metaopt.register([local_abstractconv_cudnn_alternative]) conv_metaopt.register([local_abstractconv_cudnn_alternative])
conv_metaopt.register([local_abstractconv3d2d])
abstractconv_groupopt.register('conv_metaopt', conv_metaopt, 'conv_meta', position=0) abstractconv_groupopt.register('conv_metaopt', conv_metaopt, 'conv_meta', position=0)
# Register cuDNN batch normalization implementation # Register cuDNN batch normalization implementation
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论