Rename _GpuDnnTransformerDescriptor to GpuDnnTransformerDescriptor

上级 a947598a
...@@ -2833,7 +2833,7 @@ def local_abstractconv3d_cudnn_graph(op, context_name, inputs, outputs): ...@@ -2833,7 +2833,7 @@ def local_abstractconv3d_cudnn_graph(op, context_name, inputs, outputs):
return [rval] return [rval]
class _GpuDnnTransformerDescriptor(COp): class GpuDnnTransformerDescriptor(COp):
""" """
This Op builds a spatial transformer descriptor for use in spatial transformer network This Op builds a spatial transformer descriptor for use in spatial transformer network
...@@ -2883,7 +2883,7 @@ class _GpuDnnTransformerDescriptor(COp): ...@@ -2883,7 +2883,7 @@ class _GpuDnnTransformerDescriptor(COp):
return node return node
def c_code_cache_version(self): def c_code_cache_version(self):
return (super(_GpuDnnTransformerDescriptor, self).c_code_cache_version(), version()) return (super(GpuDnnTransformerDescriptor, self).c_code_cache_version(), version())
class GpuDnnTransformer(DnnBase): class GpuDnnTransformer(DnnBase):
...@@ -3052,7 +3052,7 @@ def dnn_spatialtf(inp, theta, scale_width=1, scale_height=1, alpha=None, beta=No ...@@ -3052,7 +3052,7 @@ def dnn_spatialtf(inp, theta, scale_width=1, scale_height=1, alpha=None, beta=No
theta = gpu_contiguous(theta) theta = gpu_contiguous(theta)
# Create spatial transformer descriptor # Create spatial transformer descriptor
desc = _GpuDnnTransformerDescriptor(dtype)(grid_dims) desc = GpuDnnTransformerDescriptor(dtype)(grid_dims)
# Create grid dimensions variable # Create grid dimensions variable
grid_dims_var = as_tensor_variable(grid_dims) grid_dims_var = as_tensor_variable(grid_dims)
# Setup spatial transformer # Setup spatial transformer
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论