Refactor spatial transformer C implementation to use helper functions

上级 a203ad71
...@@ -2898,7 +2898,7 @@ class GpuDnnTransformer(DnnBase): ...@@ -2898,7 +2898,7 @@ class GpuDnnTransformer(DnnBase):
default_output = 0 default_output = 0
def __init__(self, dtype): def __init__(self, dtype):
DnnBase.__init__(self, ["c_code/dnn_sptf.c"], "dnn_sptf") DnnBase.__init__(self, ["c_code/dnn_sptf.c"], "APPLY_SPECIFIC(dnn_sptf)")
self.dtype = dtype self.dtype = dtype
def make_node(self, img, theta, output, grid_dims, desc, alpha=None, beta=None): def make_node(self, img, theta, output, grid_dims, desc, alpha=None, beta=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论