提交 a35f07e3 authored 作者: Frederic's avatar Frederic

make sure the input var are of the right type.

上级 5fc89c03
......@@ -5,6 +5,7 @@ import theano
from theano import config, gof
from theano.sandbox.cuda.nvcc_compiler import NVCC_compiler
from theano.sandbox.gpuarray.type import GpuArrayType
from theano.sandbox.gpuarray.basic_ops import as_gpuarray_variable
class GpuConv(gof.Op):
......@@ -126,7 +127,8 @@ class GpuConv(gof.Op):
raise TypeError('img must be 4D tensor')
if kern.type.ndim != 4:
raise TypeError('kern must be 4D tensor')
img = as_gpuarray_variable(img)
kern = as_gpuarray_variable(kern)
broadcastable = [img.type.broadcastable[0], kern.type.broadcastable[0],
False, False]
out = GpuArrayType(img.dtype, broadcastable)()
......@@ -195,8 +197,8 @@ class GpuConv(gof.Op):
# these files
files = ['conv_kernel.cu', 'conv_full_kernel.cu', 'conv.cu']
codes = ["CUdeviceptr (*cuda_get_ptr_raw)(gpudata *g);",
"float* cuda_get_ptr(PyGpuArrayObject * o){return (float*) cuda_get_ptr_raw(o->ga.data);}",
"const float* cuda_get_ptr(const PyGpuArrayObject * o){return (float*) cuda_get_ptr_raw(o->ga.data);}"]
"float* cuda_get_ptr(PyGpuArrayObject * o){return (float*) (cuda_get_ptr_raw(o->ga.data) + o->ga.offset);}",
"const float* cuda_get_ptr(const PyGpuArrayObject * o){return (float*) (cuda_get_ptr_raw(o->ga.data) + o->ga.offset);}"]
codes += [open(os.path.join(os.path.split(__file__)[0], f)).read()
for f in files]
return reduce(str.__add__, codes)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论