Add checks for input dimensions in GpuDnnTrasnformerGradI

上级 01a10e2d
...@@ -2937,17 +2937,17 @@ class GpuDnnTransformer(DnnBase): ...@@ -2937,17 +2937,17 @@ class GpuDnnTransformer(DnnBase):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def L_op(self, inputs, outputs, grads): def L_op(self, inputs, outputs, grads):
img, theta, grid_dims, desc, alpha, beta = inputs img, theta, output, desc, alpha, beta = inputs
_, grid = outputs _, grid = outputs
dy = grads[0] dy = grads[0]
dimg, dgrid = GpuDnnTransformerGradI(self.dtype)(img, theta, grid, grid_dims, dy, dimg, dgrid = GpuDnnTransformerGradI(self.dtype)(img, theta, grid, dy,
desc, alpha, beta) desc, alpha, beta)
dtheta = GpuDnnTransformerGradT(self.dtype)(dgrid, desc) dtheta = GpuDnnTransformerGradT(self.dtype)(dgrid, desc)
return [dimg, dtheta, return [dimg, dtheta,
theano.gradient.grad_undefined(self, 2, grid_dims), theano.gradient.grad_undefined(self, 2, output),
DisconnectedType()(), DisconnectedType()(),
theano.gradient.grad_undefined(self, 4, alpha), theano.gradient.grad_undefined(self, 4, alpha),
theano.gradient.grad_undefined(self, 5, beta)] theano.gradient.grad_undefined(self, 5, beta)]
...@@ -2971,21 +2971,31 @@ class GpuDnnTransformerGradI(DnnBase): ...@@ -2971,21 +2971,31 @@ class GpuDnnTransformerGradI(DnnBase):
self.dtype = dtype self.dtype = dtype
def make_node(self, img, theta, grid, dy, desc, alpha, beta): def make_node(self, img, theta, grid, dy, desc, alpha, beta):
context_name = infer_context_name(img) context_name = infer_context_name(img, theta, grid, dy, desc)
if img.ndim != 4: if (not isinstance(desc.type, CDataType) or
raise RuntimeError('img must have 4 dimensions.') desc.type.ctype != 'cudnnSpatialTransformerDescriptor_t'):
if theta.ndim != 3: raise ValueError('desc must be cudnnSpatialTransformerDescriptor_t')
raise RuntimeError('theta must have 3 dimensions')
img = as_gpuarray_variable(gpu_contiguous(img), context_name) img = as_gpuarray_variable(gpu_contiguous(img), context_name)
if img.ndim != 4:
raise TypeError('img must have 4 dimensions.')
theta = as_gpuarray_variable(gpu_contiguous(theta), context_name) theta = as_gpuarray_variable(gpu_contiguous(theta), context_name)
if theta.ndim != 3:
raise TypeError('theta must have 3 dimensions')
grid = as_gpuarray_variable(gpu_contiguous(grid), context_name) grid = as_gpuarray_variable(gpu_contiguous(grid), context_name)
if img.ndim != grid.ndim:
raise TypeError('grid should have the same number of dimensions as img')
# Setup grid dimensions from descriptor's input # Setup grid dimensions from descriptor's input
grid_dims = as_tensor_variable(desc.owner.inputs[0]) grid_dims = as_tensor_variable(desc.owner.inputs[0])
dy = as_gpuarray_variable(dy, context_name) dy = as_gpuarray_variable(dy, context_name)
if img.ndim != 4:
raise TypeError('img must have 4 dimensions.')
alpha = as_scalar(alpha) alpha = as_scalar(alpha)
beta = as_scalar(beta) beta = as_scalar(beta)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论