提交 f669bd8e authored 作者: João Victor Risso's avatar João Victor Risso

Rename grid_dims to out_dims in spatial transformer ops

上级 eaa12937
......@@ -2772,13 +2772,13 @@ class GpuDnnTransformerDesc(COp):
assert cudnn.cudnnDataType_t.has_alias(precision)
self.precision = precision
def make_node(self, dimensions):
dimensions = as_tensor_variable(dimensions)
assert dimensions.dtype in theano.tensor.basic.integer_dtypes
assert dimensions.ndim == 1
dimensions = theano.tensor.basic.cast(dimensions, 'int64')
def make_node(self, out_dims):
out_dims = as_tensor_variable(out_dims)
assert out_dims.dtype in theano.tensor.basic.integer_dtypes
assert out_dims.ndim == 1
out_dims = theano.tensor.basic.cast(out_dims, 'int64')
node = Apply(self, [dimensions],
node = Apply(self, [out_dims],
[CDataType("cudnnSpatialTransformerDescriptor_t",
freefunc="cudnnDestroySpatialTransformerDescriptor")()])
# DebugMode cannot compare the values of CDataType variables, so by
......@@ -2814,26 +2814,26 @@ class GpuDnnTransformerGrid(DnnBase):
assert theta.dtype in ('float16', 'float32', 'float64')
assert theta.ndim == 3
# Setup grid dimensions using input from descriptor
grid_dims = as_tensor_variable(desc.owner.inputs[0])
assert grid_dims.dtype in theano.tensor.basic.integer_dtypes
assert grid_dims.ndim == 1
# Setup output dimensions using input from descriptor
out_dims = as_tensor_variable(desc.owner.inputs[0])
assert out_dims.dtype in theano.tensor.basic.integer_dtypes
assert out_dims.ndim == 1
# Ensure 64-bit ints are passed to the C code
grid_dims = theano.tensor.basic.cast(grid_dims, 'int64')
out_dims = theano.tensor.basic.cast(out_dims, 'int64')
grid = GpuArrayType(dtype=theta.dtype,
broadcastable=(theta.type.ndim + 1) * (False,),
context_name=context_name)()
inputs = [theta, grid_dims, desc]
inputs = [theta, out_dims, desc]
outputs = [grid]
return Apply(self, inputs, outputs)
def grad(self, inputs, grads):
theta, grid_dims, desc = inputs
theta, out_dims, desc = inputs
dgrid = grads[0]
dtheta = GpuDnnTransformerGradT()(dgrid, desc)
return [dtheta, grad_not_implemented(self, 1, grid_dims), DisconnectedType()()]
return [dtheta, grad_not_implemented(self, 1, out_dims), DisconnectedType()()]
def connection_pattern(self, node):
# not connected to desc
......@@ -3005,13 +3005,13 @@ def dnn_spatialtf(img, theta, scale_width=1, scale_height=1, precision=theano.co
Also, the only grid sampler method available is the bilinear interpolation.
"""
grid_dims = (img.shape[0], img.shape[1],
out_dims = (img.shape[0], img.shape[1],
img.shape[2] * scale_height,
img.shape[3] * scale_width)
grid_dims = tuple([as_scalar(v).astype('int32') for v in grid_dims])
out_dims = tuple([as_scalar(v).astype('int64') for v in out_dims])
# Create spatial transformer descriptor
desc = GpuDnnTransformerDesc(precision)(grid_dims)
desc = GpuDnnTransformerDesc(precision)(out_dims)
context_name = infer_context_name(desc)
img = gpu_contiguous(as_gpuarray_variable(img, context_name))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论