提交 208b48ef authored 作者: Yann N. Dauphin's avatar Yann N. Dauphin

improved input validation

上级 01a30774
......@@ -2169,6 +2169,17 @@ class GpuDownsampleFactorMaxGradGrad(GpuOp):
self.ignore_border = ignore_border
def make_node(self, x, z, gx):
x = as_cuda_ndarray_variable(x)
z = as_cuda_ndarray_variable(z)
gx = as_cuda_ndarray_variable(gx)
if x.type.ndim != 4:
raise TypeError('x must be 4D tensor')
if z.type.ndim != 4:
raise TypeError('z must be 4D tensor')
if gx.type.ndim != 4:
raise TypeError('gx must be 4D tensor')
return Apply(self, [x, z, gx], [x.type()])
def c_code_cache_version(self):
......@@ -2262,8 +2273,6 @@ class GpuDownsampleFactorMaxGradGrad(GpuOp):
""" % locals()
def c_support_code_apply(self, node, nodename):
ignore_border = int(self.ignore_border)
return """
// ds0 is the downsampling factor in rows, ds1 in columns
template<int ds0, int ds1>
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论