提交 c5728f55 authored 作者: Frederic's avatar Frederic

Better handling of transfers, better error reporting, and a refcount fix.

上级 ea8153b2
...@@ -7,6 +7,7 @@ from theano import tensor ...@@ -7,6 +7,7 @@ from theano import tensor
from theano.compat.six import StringIO from theano.compat.six import StringIO
from theano.sandbox.cuda.type import CudaNdarrayType from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda import GpuOp from theano.sandbox.cuda import GpuOp
from theano.sandbox.cuda import as_cuda_ndarray_variable
class GpuDot22(GpuOp): class GpuDot22(GpuOp):
...@@ -542,6 +543,8 @@ class GpuConvMM(GpuOp): ...@@ -542,6 +543,8 @@ class GpuConvMM(GpuOp):
self.pad) self.pad)
def make_node(self, img, kern): def make_node(self, img, kern):
img = as_cuda_ndarray_variable(img)
kern = as_cuda_ndarray_variable(kern)
if img.type.ndim != 4: if img.type.ndim != 4:
raise TypeError('img must be 4D tensor') raise TypeError('img must be 4D tensor')
if kern.type.ndim != 4: if kern.type.ndim != 4:
...@@ -575,7 +578,7 @@ class GpuConvMM(GpuOp): ...@@ -575,7 +578,7 @@ class GpuConvMM(GpuOp):
def c_code_cache_version(self): def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files # raise this whenever modifying any of the support_code_files
return (0, 21) return (0, 22)
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of # REMEMBER TO RAISE c_code_cache_version when changing any of
......
...@@ -103,8 +103,18 @@ CudaNdarray* validMM(const CudaNdarray *input, ...@@ -103,8 +103,18 @@ CudaNdarray* validMM(const CudaNdarray *input,
// filters: (number of filters, nInputPlane, rows, columns) // filters: (number of filters, nInputPlane, rows, columns)
int nOutputPlane = CudaNdarray_HOST_DIMS(weight)[0]; int nOutputPlane = CudaNdarray_HOST_DIMS(weight)[0];
long batchSize = CudaNdarray_HOST_DIMS(input)[0]; long batchSize = CudaNdarray_HOST_DIMS(input)[0];
assert(kW == kH); //filters must be square (kW == kH) if (kW != kH){
assert(dW == dH); //stride must be square (dW == dH) PyErr_SetString(PyExc_ValueError,
"GpuConvMM support only square kernel\n"
);
return NULL;
}
// Stride must be square (dW == dH). The kernel-size guard just above already
// rejects kW != kH, so this guard has to compare the strides; repeating the
// kernel test here would leave non-square strides unchecked.
if (dW != dH){
    PyErr_SetString(PyExc_ValueError,
                    "GpuConvMM support only square strides\n"
                    );
    return NULL;
}
long inputHeight = CudaNdarray_HOST_DIMS(input)[2]; long inputHeight = CudaNdarray_HOST_DIMS(input)[2];
long inputWidth = CudaNdarray_HOST_DIMS(input)[3]; long inputWidth = CudaNdarray_HOST_DIMS(input)[3];
long outputWidth = (inputWidth + 2*padding - kW) / dW + 1; long outputWidth = (inputWidth + 2*padding - kW) / dW + 1;
...@@ -171,7 +181,6 @@ CudaNdarray* validMM(const CudaNdarray *input, ...@@ -171,7 +181,6 @@ CudaNdarray* validMM(const CudaNdarray *input,
} }
Py_DECREF(columns); Py_DECREF(columns);
return output; return output;
} }
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论