提交 90f90ac3 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

Fix subsampling in GpuCorr3dMM_gradWeigth

上级 b0bacd7b
......@@ -9,6 +9,7 @@ from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda import GpuOp
from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable,
gpu_contiguous)
from theano.tensor import as_tensor_variable
class GpuDot22(GpuOp):
......@@ -1307,8 +1308,8 @@ class GpuCorr3dMM(BaseGpuCorr3dMM):
batchsize or number of filters) may also work around the CUBLAS bug.
"""
def __init__(self, border_mode="valid",
subsample=(1, 1, 1),
pad=(0, 0, 0)):
subsample=(1, 1, 1),
pad=(0, 0, 0)):
"""
:param border_mode: currently supports "valid" only; "full" can be
simulated by setting `pad="full"` (at the cost of performance), or
......@@ -1375,6 +1376,9 @@ class GpuCorr3dMM_gradWeights(BaseGpuCorr3dMM):
def make_node(self, img, topgrad, shape=None):
img = as_cuda_ndarray_variable(img)
topgrad = as_cuda_ndarray_variable(topgrad)
if shape is not None:
shape = as_tensor_variable(shape)
if img.type.ndim != 5:
raise TypeError('img must be 5D tensor')
if topgrad.type.ndim != 5:
......
......@@ -1380,7 +1380,7 @@ def local_convgrad3d_gemm(node):
f = gpu_contiguous(f.dimshuffle(0, 4, 1, 2, 3))
rval = GpuCorr3dMM_gradWeights(subsample=(sx, sy, sz))(x, f,
shape=node.inputs[2])
shape=node.inputs[2][1:4])
# Shuffle from (ic, oc, 0, 1, t) to (oc, 0, 1, t, ic)
return [rval.dimshuffle(0, 2, 3, 4, 1)]
......
......@@ -69,10 +69,8 @@ class TestCorr3DMM(unittest.TestCase):
subsample=(1, 1, 1)):
inputs_val = numpy.random.random(inputs_shape).astype('float32')
dCdH_val = numpy.random.random(dCdH_shape).astype('float32')
filters_val = numpy.random.random(filters_shape).astype('float32')
inputs = shared(inputs_val)
dCdH = shared(dCdH_val)
filters = shared(filters_val)
conv = theano.tensor.nnet.convGrad3D(V=inputs, dCdH=dCdH,
WShape=filters_shape,
......@@ -85,7 +83,7 @@ class TestCorr3DMM(unittest.TestCase):
else:
conv_gemm = GpuCorr3dMM_gradWeights(subsample=subsample)(img,
topgrad,
shape=filters.shape[1:4])
shape=filters_shape[1:4])
conv_gemm = conv_gemm.dimshuffle(0, 2, 3, 4, 1)
f_ref = theano.function([], conv)
f = theano.function([], conv_gemm)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论