提交 daea21a5 authored 作者: João Victor Risso's avatar João Victor Risso

Add docstrings for make_node functions of GpuDnnTransformerSampler and GpuDnnTransformerGrid

上级 5510faf5
......@@ -2806,6 +2806,19 @@ class GpuDnnTransformerGrid(DnnBase):
DnnBase.__init__(self, ["c_code/dnn_sptf_grid.c"], "APPLY_SPECIFIC(dnn_sptf_grid)")
def make_node(self, theta, desc):
"""
Create a grid generator node for a cuDNN Spatial Transformer
Parameters
----------
theta : tensor
Affine transformation tensor containing one affine transformation
matrix per image. ``theta`` is usually generated by the localization
network.
desc : GpuDnnTransformerDesc
Spatial transformer descriptor
"""
context_name = infer_context_name(desc)
if (not isinstance(desc.type, CDataType) or
......@@ -2856,6 +2869,24 @@ class GpuDnnTransformerSampler(DnnBase):
DnnBase.__init__(self, ["c_code/dnn_sptf_sampler.c"], "APPLY_SPECIFIC(dnn_sptf_sampler)")
def make_node(self, img, grid, desc):
"""
Create a grid sampler node for a cuDNN Spatial Transformer
Parameters
----------
img : tensor
Images from which the pixels will be sampled. The implementation
assumes the tensor is in NCHW format, where N is the number of images,
C is the number of color channels, H is the height of the inputs, and
W is width of the inputs.
grid : GpuDnnTransformerGrid
Grid that contains the coordinates of the pixels to be sampled from
the inputs images.
desc : GpuDnnTransformerDesc
Spatial transformer descriptor
"""
context_name = infer_context_name(desc)
if (not isinstance(desc.type, CDataType) or
......@@ -3012,7 +3043,7 @@ def dnn_spatialtf(img, theta, scale_width=1, scale_height=1, precision=theano.co
Currently, cuDNN only supports 2D transformations with 2x3 affine
transformation matrices.
Also, the only grid sampler method available is the bilinear interpolation.
Bilinear interpolation is the only grid sampler method available.
"""
out_dims = (img.shape[0], img.shape[1],
theano.tensor.ceil(img.shape[2] * scale_height),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论