Rename precision to dtype in spatialtf

上级 105b33b2
......@@ -27,7 +27,7 @@ int APPLY_SPECIFIC(spatialtf_desc)(cudnnSpatialTransformerDescriptor_t * desc,
// Currently, only the bilinear sampler is supported by cuDNN,
// so it is not available as a parameter
err = cudnnSetSpatialTransformerNdDescriptor( *desc, CUDNN_SAMPLER_BILINEAR,
params->precision, params->nb_dims, out_tensor_dims );
params->dtype, params->nb_dims, out_tensor_dims );
if ( CUDNN_STATUS_SUCCESS != err )
{
PyErr_Format( PyExc_MemoryError,
......
......@@ -2840,9 +2840,9 @@ class GpuDnnSpatialTfDesc(COp):
operations.
"""
__props__ = ('dimensions', 'precision')
__props__ = ('dimensions', 'dtype')
params_type = ParamsType(nimages=int_t, nchannels=int_t, height=int_t, width=int_t,
nb_dims=int_t, precision=cudnn.cudnnDataType_t)
nb_dims=int_t, dtype=cudnn.cudnnDataType_t)
def c_headers(self):
return ['cudnn.h', 'cudnn_helper.h']
......@@ -2859,7 +2859,7 @@ class GpuDnnSpatialTfDesc(COp):
def do_constant_folding(self, node):
return False
def __init__(self, dimensions, precision="float32"):
def __init__(self, dimensions, dtype="float32"):
COp.__init__(self, ["c_code/spatialtf_desc.c"], "APPLY_SPECIFIC(spatialtf_desc)")
# dimensions must have at least width and height
......@@ -2871,8 +2871,8 @@ class GpuDnnSpatialTfDesc(COp):
# not exceed 4 dimensions (width, height, num_feature_maps, num_images)
assert len(self.dimensions) <= 4
assert cudnn.cudnnDataType_t.has_alias(precision)
self.precision = precision
assert cudnn.cudnnDataType_t.has_alias(dtype)
self.dtype = dtype
def make_node(self):
node = Apply(self, [],
......@@ -2908,14 +2908,14 @@ class GpuDnnGridGenerator(DnnBase):
operations.
"""
__props__ = ('precision',)
__props__ = ('dtype',)
_cop_num_inputs = 3
_cop_num_outputs = 1
def __init__(self, precision):
def __init__(self, dtype):
DnnBase.__init__(self, ["c_code/spatialtf_grid.c"], "spatialtf_grid")
self.precision = precision
self.dtype = dtype
def dnn_context(self, node):
return node.outputs[0].type.context_name
......@@ -2930,7 +2930,7 @@ class GpuDnnGridGenerator(DnnBase):
assert theta.ndim == 3
# Allocate GPU memory for grid of coordinates
grid = GpuArrayType(dtype=self.precision,
grid = GpuArrayType(dtype=self.dtype,
broadcastable=(False, False, False, False,),
context_name=context_name)()
......@@ -2947,14 +2947,14 @@ class GpuDnnGridSampler(DnnBase):
operations.
"""
__props__ = ('precision',)
__props__ = ('dtype',)
_cop_num_inputs = 6
_cop_num_outputs = 1
def __init__(self, precision):
def __init__(self, dtype):
DnnBase.__init__(self, ["c_code/spatialtf_sampler.c"], "spatialtf_sampler")
self.precision = precision
self.dtype = dtype
def dnn_context(self, node):
return node.outputs[0].type.context_name
......@@ -2967,7 +2967,7 @@ class GpuDnnGridSampler(DnnBase):
grid = as_gpuarray_variable(grid, context_name)
grid_dimensions = as_tensor_variable(grid_dimensions)
output = GpuArrayType(dtype=self.precision,
output = GpuArrayType(dtype=self.dtype,
broadcastable=img.type.ndim * (False,),
context_name=context_name)()
......@@ -2993,7 +2993,7 @@ class GpuDnnGridSampler(DnnBase):
pass
def dnn_spatialtf(img, theta, grid_dims, alpha=None, beta=None, precision=None):
def dnn_spatialtf(img, theta, grid_dims, alpha=None, beta=None, dtype=None):
"""
GPU spatial transformer using cuDNN from NVIDIA.
"""
......@@ -3001,18 +3001,18 @@ def dnn_spatialtf(img, theta, grid_dims, alpha=None, beta=None, precision=None):
img = gpu_contiguous(img)
theta = gpu_contiguous(theta)
precision = get_precision(precision, [img, theta])
dtype = get_precision(dtype, [img, theta])
# Create spatial transformer descriptor
desc = GpuDnnSpatialTfDesc(grid_dims, precision)()
desc = GpuDnnSpatialTfDesc(grid_dims, dtype)()
# Create grid dimensions variable
grid_dims_var = as_tensor_variable(grid_dims)
# Setup grid of coordinates
grid_coord = GpuDnnGridGenerator(precision)(grid_dims_var, theta, desc)
grid_coord = GpuDnnGridGenerator(dtype)(grid_dims_var, theta, desc)
grid_sampler = GpuDnnGridSampler(precision)(img, grid_coord, grid_dims_var, desc,
grid_sampler = GpuDnnGridSampler(dtype)(img, grid_coord, grid_dims_var, desc,
alpha, beta)
return grid_sampler
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论