提交 9baed894 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

updates

上级 9367e589
...@@ -14,7 +14,6 @@ import theano.sandbox.cuda as cuda ...@@ -14,7 +14,6 @@ import theano.sandbox.cuda as cuda
if not cuda.cuda_available: if not cuda.cuda_available:
raise SkipTest('Optional package cuda disabled') raise SkipTest('Optional package cuda disabled')
if theano.config.mode == 'FAST_COMPILE': if theano.config.mode == 'FAST_COMPILE':
mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu') mode_with_gpu = theano.compile.mode.get_mode('FAST_RUN').including('gpu')
mode_without_gpu = theano.compile.mode.get_mode('FAST_RUN').excluding('gpu') mode_without_gpu = theano.compile.mode.get_mode('FAST_RUN').excluding('gpu')
...@@ -37,7 +36,6 @@ class TestConv2d(unittest.TestCase): ...@@ -37,7 +36,6 @@ class TestConv2d(unittest.TestCase):
self.filter_flip = [True, False] self.filter_flip = [True, False]
def get_output_shape(self, inputs_shape, filters_shape, subsample, border_mode): def get_output_shape(self, inputs_shape, filters_shape, subsample, border_mode):
if border_mode == "valid": if border_mode == "valid":
border_mode = (0, 0) border_mode = (0, 0)
if border_mode == "full": if border_mode == "full":
......
""" """
Define abstract conv2d interface Define abstract conv2d interface
""" """
__docformat__ = "restructuredtext en"
import logging import logging
import theano import theano
from theano.tensor import as_tensor_variable from theano.tensor import (as_tensor_variable, patternbroadcast)
from theano.tensor import TensorType
from theano.gof import Apply, Op from theano.gof import Apply, Op
from theano.gof import local_optimizer
from theano.tensor.opt import register_specialize_device
__docformat__ = "restructuredtext en"
_logger = logging.getLogger("theano.tensor.nnet.conv2d") _logger = logging.getLogger("theano.tensor.nnet.conv2d")
...@@ -168,6 +174,7 @@ def conv2d(input, ...@@ -168,6 +174,7 @@ def conv2d(input,
return conv_op(input, filters) return conv_op(input, filters)
class BaseAbstractConv2d(Op): class BaseAbstractConv2d(Op):
""" """
Base class for AbstractConv Base class for AbstractConv
...@@ -187,6 +194,7 @@ class BaseAbstractConv2d(Op): ...@@ -187,6 +194,7 @@ class BaseAbstractConv2d(Op):
element is not known at compile time. element is not known at compile time.
kshp is defined w.r.t the forward conv. kshp is defined w.r.t the forward conv.
:type border_mode: str, int or tuple of two int :type border_mode: str, int or tuple of two int
:param border_mode: Either of the following: :param border_mode: Either of the following:
* ``'valid'``: apply filter wherever it completely overlaps with the * ``'valid'``: apply filter wherever it completely overlaps with the
...@@ -219,6 +227,7 @@ class BaseAbstractConv2d(Op): ...@@ -219,6 +227,7 @@ class BaseAbstractConv2d(Op):
imshp=None, kshp=None, imshp=None, kshp=None,
border_mode="valid", subsample=(1, 1), border_mode="valid", subsample=(1, 1),
filter_flip=True): filter_flip=True):
if isinstance(border_mode, int): if isinstance(border_mode, int):
border_mode = (border_mode, border_mode) border_mode = (border_mode, border_mode)
if isinstance(border_mode, tuple): if isinstance(border_mode, tuple):
...@@ -297,6 +306,7 @@ class AbstractConv2d(BaseAbstractConv2d): ...@@ -297,6 +306,7 @@ class AbstractConv2d(BaseAbstractConv2d):
self.border_mode, self.border_mode,
self.subsample, self.subsample,
self.filter_flip)( self.filter_flip)(
bottom, top, weights.shape[-2:]) bottom, top, weights.shape[-2:])
return d_bottom, d_weights return d_bottom, d_weights
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论