提交 f669bd8e authored 作者: João Victor Risso's avatar João Victor Risso

Rename grid_dims to out_dims in spatial transformer ops

上级 eaa12937
...@@ -2772,13 +2772,13 @@ class GpuDnnTransformerDesc(COp): ...@@ -2772,13 +2772,13 @@ class GpuDnnTransformerDesc(COp):
assert cudnn.cudnnDataType_t.has_alias(precision) assert cudnn.cudnnDataType_t.has_alias(precision)
self.precision = precision self.precision = precision
def make_node(self, dimensions): def make_node(self, out_dims):
dimensions = as_tensor_variable(dimensions) out_dims = as_tensor_variable(out_dims)
assert dimensions.dtype in theano.tensor.basic.integer_dtypes assert out_dims.dtype in theano.tensor.basic.integer_dtypes
assert dimensions.ndim == 1 assert out_dims.ndim == 1
dimensions = theano.tensor.basic.cast(dimensions, 'int64') out_dims = theano.tensor.basic.cast(out_dims, 'int64')
node = Apply(self, [dimensions], node = Apply(self, [out_dims],
[CDataType("cudnnSpatialTransformerDescriptor_t", [CDataType("cudnnSpatialTransformerDescriptor_t",
freefunc="cudnnDestroySpatialTransformerDescriptor")()]) freefunc="cudnnDestroySpatialTransformerDescriptor")()])
# DebugMode cannot compare the values of CDataType variables, so by # DebugMode cannot compare the values of CDataType variables, so by
...@@ -2814,26 +2814,26 @@ class GpuDnnTransformerGrid(DnnBase): ...@@ -2814,26 +2814,26 @@ class GpuDnnTransformerGrid(DnnBase):
assert theta.dtype in ('float16', 'float32', 'float64') assert theta.dtype in ('float16', 'float32', 'float64')
assert theta.ndim == 3 assert theta.ndim == 3
# Setup grid dimensions using input from descriptor # Setup output dimensions using input from descriptor
grid_dims = as_tensor_variable(desc.owner.inputs[0]) out_dims = as_tensor_variable(desc.owner.inputs[0])
assert grid_dims.dtype in theano.tensor.basic.integer_dtypes assert out_dims.dtype in theano.tensor.basic.integer_dtypes
assert grid_dims.ndim == 1 assert out_dims.ndim == 1
# Ensure 64-bit ints are passed to the C code # Ensure 64-bit ints are passed to the C code
grid_dims = theano.tensor.basic.cast(grid_dims, 'int64') out_dims = theano.tensor.basic.cast(out_dims, 'int64')
grid = GpuArrayType(dtype=theta.dtype, grid = GpuArrayType(dtype=theta.dtype,
broadcastable=(theta.type.ndim + 1) * (False,), broadcastable=(theta.type.ndim + 1) * (False,),
context_name=context_name)() context_name=context_name)()
inputs = [theta, grid_dims, desc] inputs = [theta, out_dims, desc]
outputs = [grid] outputs = [grid]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def grad(self, inputs, grads): def grad(self, inputs, grads):
theta, grid_dims, desc = inputs theta, out_dims, desc = inputs
dgrid = grads[0] dgrid = grads[0]
dtheta = GpuDnnTransformerGradT()(dgrid, desc) dtheta = GpuDnnTransformerGradT()(dgrid, desc)
return [dtheta, grad_not_implemented(self, 1, grid_dims), DisconnectedType()()] return [dtheta, grad_not_implemented(self, 1, out_dims), DisconnectedType()()]
def connection_pattern(self, node): def connection_pattern(self, node):
# not connected to desc # not connected to desc
...@@ -3005,13 +3005,13 @@ def dnn_spatialtf(img, theta, scale_width=1, scale_height=1, precision=theano.co ...@@ -3005,13 +3005,13 @@ def dnn_spatialtf(img, theta, scale_width=1, scale_height=1, precision=theano.co
Also, the only grid sampler method available is the bilinear interpolation. Also, the only grid sampler method available is the bilinear interpolation.
""" """
grid_dims = (img.shape[0], img.shape[1], out_dims = (img.shape[0], img.shape[1],
img.shape[2] * scale_height, img.shape[2] * scale_height,
img.shape[3] * scale_width) img.shape[3] * scale_width)
grid_dims = tuple([as_scalar(v).astype('int32') for v in grid_dims]) out_dims = tuple([as_scalar(v).astype('int64') for v in out_dims])
# Create spatial transformer descriptor # Create spatial transformer descriptor
desc = GpuDnnTransformerDesc(precision)(grid_dims) desc = GpuDnnTransformerDesc(precision)(out_dims)
context_name = infer_context_name(desc) context_name = infer_context_name(desc)
img = gpu_contiguous(as_gpuarray_variable(img, context_name)) img = gpu_contiguous(as_gpuarray_variable(img, context_name))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论