提交 6abbbc40 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

update

上级 93b4fb57
...@@ -11,8 +11,6 @@ from theano.gof import Apply, Op ...@@ -11,8 +11,6 @@ from theano.gof import Apply, Op
import numpy import numpy
try: try:
# TODO: move these back out to global scope when they no longer
# cause an atexit error
from scipy.signal.signaltools import _valfrommode, _bvalfromboundary from scipy.signal.signaltools import _valfrommode, _bvalfromboundary
from scipy.signal.sigtools import _convolve2d from scipy.signal.sigtools import _convolve2d
imported_scipy_signal = True imported_scipy_signal = True
...@@ -441,18 +439,15 @@ class BaseAbstractConv2d(Op): ...@@ -441,18 +439,15 @@ class BaseAbstractConv2d(Op):
# This may change in the future. # This may change in the future.
return False return False
def corr2d(self, img, kern, mode="valid"): def conv2d(self, img, kern, mode="valid"):
""" """
Basic slow python implementatation for DebugMode Basic slow python implementatation for DebugMode
""" """
if not imported_scipy_signal: if not imported_scipy_signal:
raise theano.gof.utils.MethodNotDefined( raise NotImplementedError(
"c_headers", type(self), self.__class__.__name__, "AbstractConv perform requires the python package"
"Need the python package for scipy.signal to be installed " " for scipy.signal to be installed.")
"for the python implementation. You can use the C"
" implementation instead.")
if not (mode in ('valid', 'full')): if not (mode in ('valid', 'full')):
raise ValueError( raise ValueError(
'invalid mode {}, which must be either ' 'invalid mode {}, which must be either '
...@@ -525,7 +520,7 @@ class AbstractConv2d(BaseAbstractConv2d): ...@@ -525,7 +520,7 @@ class AbstractConv2d(BaseAbstractConv2d):
img = new_img img = new_img
if not self.filter_flip: if not self.filter_flip:
kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
conv_out = self.corr2d(img, kern, mode) conv_out = self.conv2d(img, kern, mode="valid")
conv_out = conv_out[:, :, ::self.subsample[0], ::self.subsample[1]] conv_out = conv_out[:, :, ::self.subsample[0], ::self.subsample[1]]
o[0] = node.outputs[0].type.filter(conv_out) o[0] = node.outputs[0].type.filter(conv_out)
...@@ -654,7 +649,7 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d): ...@@ -654,7 +649,7 @@ class AbstractConv2d_gradWeights(BaseAbstractConv2d):
topgrad = topgrad.transpose(1, 0, 2, 3)[:, :, ::-1, ::-1] topgrad = topgrad.transpose(1, 0, 2, 3)[:, :, ::-1, ::-1]
img = img.transpose(1, 0, 2, 3) img = img.transpose(1, 0, 2, 3)
kern = self.corr2d(img, topgrad, mode) kern = self.conv2d(img, topgrad, mode="valid")
if self.filter_flip: if self.filter_flip:
kern = kern.transpose(1, 0, 2, 3)[:, :, ::-1, ::-1] kern = kern.transpose(1, 0, 2, 3)[:, :, ::-1, ::-1]
else: else:
...@@ -762,7 +757,6 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d): ...@@ -762,7 +757,6 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
pad_h, pad_w = (kern.shape[2] // 2, kern.shape[3] // 2) pad_h, pad_w = (kern.shape[2] // 2, kern.shape[3] // 2)
elif isinstance(mode, tuple): elif isinstance(mode, tuple):
pad_h, pad_w = map(int, self.border_mode) pad_h, pad_w = map(int, self.border_mode)
mode = "valid"
if self.subsample[0] > 1 or self.subsample[1] > 1: if self.subsample[0] > 1 or self.subsample[1] > 1:
new_shape = (topgrad.shape[0], topgrad.shape[1], new_shape = (topgrad.shape[0], topgrad.shape[1],
shape[0] + 2 * pad_h - kern.shape[2] + 1, shape[0] + 2 * pad_h - kern.shape[2] + 1,
...@@ -773,7 +767,7 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d): ...@@ -773,7 +767,7 @@ class AbstractConv2d_gradInputs(BaseAbstractConv2d):
kern = kern.transpose(1, 0, 2, 3) kern = kern.transpose(1, 0, 2, 3)
if self.filter_flip: if self.filter_flip:
topgrad = topgrad[:, :, ::-1, ::-1] topgrad = topgrad[:, :, ::-1, ::-1]
img = self.corr2d(topgrad, kern, mode="full") img = self.conv2d(topgrad, kern, mode="full")
if self.filter_flip: if self.filter_flip:
img = img[:, :, ::-1, ::-1] img = img[:, :, ::-1, ::-1]
if pad_h > 0 or pad_w > 0: if pad_h > 0 or pad_w > 0:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论