提交 77bce880 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add support for float16/float64 to Corr3dMM.

上级 0c2eb3f0
......@@ -496,6 +496,7 @@ class BaseGpuCorrMM(CGpuKernelBase):
return [os.path.dirname(__file__)]
def c_code_cache_version(self):
# Raise this whenever modifying the code below.
return (2,)
def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width=None):
......@@ -958,7 +959,7 @@ class GpuCorrMM_gradInputs(BaseGpuCorrMM):
return [[1], [1], [0], [0]] # no connection to height, width
class BaseGpuCorr3dMM(CGpuKernelBase, BlasOp):
class BaseGpuCorr3dMM(CGpuKernelBase):
"""
Base class for `GpuCorr3dMM`, `GpuCorr3dMM_gradWeights` and
`GpuCorr3dMM_gradInputs`. Cannot be used directly.
......@@ -972,10 +973,11 @@ class BaseGpuCorr3dMM(CGpuKernelBase, BlasOp):
Perform subsampling of the output (default: (1, 1, 1)).
filter_dilation
Perform subsampling of the input, also known as dilation (default: (1, 1, 1)).
"""
"""
check_broadcast = False
__props__ = ('border_mode', 'subsample', 'filter_dilation')
_f16_ok = True
def __init__(self, border_mode="valid", subsample=(1, 1, 1),
filter_dilation=(1, 1, 1)):
......@@ -1033,9 +1035,15 @@ class BaseGpuCorr3dMM(CGpuKernelBase, BlasOp):
def get_params(self, node):
return node.inputs[0].type.context
def c_headers(self):
return ["<gpuarray/array.h>", "<gpuarray/blas.h>", "gpuarray_helper.h"]
def c_header_dirs(self):
return [os.path.dirname(__file__)]
def c_code_cache_version(self):
# raise this whenever modifying any of the support_code_files
return (0, 2)
# raise this whenever modifying the code below.
return (2,)
def c_code_helper(self, bottom, weights, top, direction, sub,
height=None, width=None, depth=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论