提交 b18862fa authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Extract the handle code to a reusable class.

上级 3d00f2b4
......@@ -12,33 +12,10 @@ from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable,
from theano.sandbox.cuda.blas import GpuConv
from theano.compat import PY3
class GpuDnnConvBase(GpuOp):
__props__ = ('border_mode', 'conv_mode')
def __init__(self, border_mode, conv_mode='conv'):
assert border_mode in ('valid', 'full')
self.border_mode = border_mode
assert conv_mode in ('conv', 'cross')
self.conv_mode = conv_mode
def __setstate__(self, props):
self.__dict__.update(props)
if not hasattr(self, 'conv_mode'):
self.conv_mode = 'conv'
def make_node(self, img, kern):
if img.type.ndim != 4:
raise TypeError('img must be 4D tensor')
if kern.type.ndim != 4:
raise TypeError('kern must be 4D tensor')
broadcastable = (img.type.broadcastable[0],
kern.type.broadcastable[0],
False, False)
return Apply(self, [img, kern], [CudaNdarrayType(broadcastable)()])
class DnnBase(GpuOp):
"""
Creates a handle for cudnn and pulls in the cudnn libraries and headers.
"""
def c_headers(self):
return ['cudnn.h', 'cudnn_helper.h']
......@@ -67,6 +44,33 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
}
}""" % (error_out,)]
class GpuDnnConvBase(DnnBase):
__props__ = ('border_mode', 'conv_mode')
def __init__(self, border_mode, conv_mode='conv'):
assert border_mode in ('valid', 'full')
self.border_mode = border_mode
assert conv_mode in ('conv', 'cross')
self.conv_mode = conv_mode
def __setstate__(self, props):
self.__dict__.update(props)
if not hasattr(self, 'conv_mode'):
self.conv_mode = 'conv'
def make_node(self, img, kern):
if img.type.ndim != 4:
raise TypeError('img must be 4D tensor')
if kern.type.ndim != 4:
raise TypeError('kern must be 4D tensor')
broadcastable = (img.type.broadcastable[0],
kern.type.broadcastable[0],
False, False)
return Apply(self, [img, kern], [CudaNdarrayType(broadcastable)()])
def c_support_code_struct(self, node, struct_id):
types = ['cudnn' + d.capitalize() + 'Descriptor_t'
for d in self.descriptors]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论