Add checks for input dimensions in GpuDnnTrasnformerGradI

上级 01a10e2d
......@@ -2937,17 +2937,17 @@ class GpuDnnTransformer(DnnBase):
return Apply(self, inputs, outputs)
def L_op(self, inputs, outputs, grads):
img, theta, grid_dims, desc, alpha, beta = inputs
img, theta, output, desc, alpha, beta = inputs
_, grid = outputs
dy = grads[0]
dimg, dgrid = GpuDnnTransformerGradI(self.dtype)(img, theta, grid, grid_dims, dy,
dimg, dgrid = GpuDnnTransformerGradI(self.dtype)(img, theta, grid, dy,
desc, alpha, beta)
dtheta = GpuDnnTransformerGradT(self.dtype)(dgrid, desc)
return [dimg, dtheta,
theano.gradient.grad_undefined(self, 2, grid_dims),
theano.gradient.grad_undefined(self, 2, output),
DisconnectedType()(),
theano.gradient.grad_undefined(self, 4, alpha),
theano.gradient.grad_undefined(self, 5, beta)]
......@@ -2971,21 +2971,31 @@ class GpuDnnTransformerGradI(DnnBase):
self.dtype = dtype
def make_node(self, img, theta, grid, dy, desc, alpha, beta):
context_name = infer_context_name(img)
context_name = infer_context_name(img, theta, grid, dy, desc)
if img.ndim != 4:
raise RuntimeError('img must have 4 dimensions.')
if theta.ndim != 3:
raise RuntimeError('theta must have 3 dimensions')
if (not isinstance(desc.type, CDataType) or
desc.type.ctype != 'cudnnSpatialTransformerDescriptor_t'):
raise ValueError('desc must be cudnnSpatialTransformerDescriptor_t')
img = as_gpuarray_variable(gpu_contiguous(img), context_name)
if img.ndim != 4:
raise TypeError('img must have 4 dimensions.')
theta = as_gpuarray_variable(gpu_contiguous(theta), context_name)
if theta.ndim != 3:
raise TypeError('theta must have 3 dimensions')
grid = as_gpuarray_variable(gpu_contiguous(grid), context_name)
if img.ndim != grid.ndim:
raise TypeError('grid should have the same number of dimensions as img')
# Setup grid dimensions from descriptor's input
grid_dims = as_tensor_variable(desc.owner.inputs[0])
dy = as_gpuarray_variable(dy, context_name)
if img.ndim != 4:
raise TypeError('img must have 4 dimensions.')
alpha = as_scalar(alpha)
beta = as_scalar(beta)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论