Add initial implementation of grid generator op and fix typos

上级 f9a62a99
......@@ -2842,7 +2842,7 @@ class GpuDnnSpatialTfDesc(COp):
__props__ = ('dimensions', 'precision')
params_type = ParamsType(dim0=int_t, dim1=int_t, dim2=int_t, dim3=int_t,
precision=cudnn.cudnnDataType_t)
nb_dims=int_t, precision=cudnn.cudnnDataType_t)
def c_headers(self):
return ['cudnn.h', 'cudnn_helper.h']
......@@ -2888,7 +2888,7 @@ class GpuDnnSpatialTfDesc(COp):
# Grid height
dim1 = property(lambda self: self.dimensions[1])
# Number of feature maps
dim2 = property(lambda self: self.dimensions[2] if len(self.subsample) > 2 else 1)
dim2 = property(lambda self: self.dimensions[2] if len(self.dimensions) > 2 else 1)
# Number of images
dim3 = property(lambda self: self.dimensions[3] if len(self.dimensions) > 3 else 1)
# Number of dimensions in the output tensor
......@@ -2913,19 +2913,29 @@ class GpuDnnGridGeneratorOp(DnnBase):
DnnBase.__init__(self, ["c_code/spatialtf_grid.c"], "spatialtf_grid")
def dnn_context(self, node):
return node.outputs[1].type.context_name
return node.outputs[0].type.context_name
def make_node(self, desc, theta, cx=None):
def make_node(self, desc, dimensions, theta, cx=None):
if cx is None:
context_name = infer_context_name(theta)
context_name = infer_context_name(desc, theta)
else:
context_name = infer_context_name(theta, cx)
context_name = infer_context_name(desc, theta, cx)
# TODO: create output grid
grid = GpuArrayType()
precision = get_precision(None, [theta])
inputs = [desc, theta]
outputs = []
width, height = dimensions[:2]
num_feature_maps = dimensions[2] if len(dimensions) > 2 else 1
num_images = dimensions[3] if len(dimensions) > 3 else 1
dimensions_var = as_tensor_variable(dimensions)
# Allocate GPU memory for grid of coordinates
grid = GpuArrayType(dtype=precision,
broadcastable=(False, False, False, False,),
context_name=context_name)()
inputs = [desc, theta, dimensions_var]
outputs = [grid]
return Apply(self, inputs, outputs)
......@@ -2966,8 +2976,8 @@ def dnn_spatialtf_context(dimensions, precision="float32"):
return GpuDnnSpatialTfDesc(dimensions, precision)()
def dnn_spatialtf_grid():
pass
def dnn_spatialtf_grid(desc, dimensions, theta):
return GpuDnnGridGeneratorOp()(desc, dimensions, theta)
def dnn_spatialtf_sampler():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论