提交 f8c749ee authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add support for float16 convolution with cudnn.

上级 d5944c96
...@@ -393,6 +393,8 @@ _one = constant(numpy.asarray(1.0, dtype='float64')) ...@@ -393,6 +393,8 @@ _one = constant(numpy.asarray(1.0, dtype='float64'))
def ensure_dt(val, default, name, dtype): def ensure_dt(val, default, name, dtype):
if dtype == 'float16':
dtype = 'float32'
if val is None: if val is None:
val = default.clone() val = default.clone()
if not isinstance(val, Variable): if not isinstance(val, Variable):
...@@ -422,7 +424,7 @@ class GpuDnnConv(DnnBase): ...@@ -422,7 +424,7 @@ class GpuDnnConv(DnnBase):
Default is the value of :attr:`config.dnn.conv.algo_fwd`. Default is the value of :attr:`config.dnn.conv.algo_fwd`.
""" """
_f16_ok = True
__props__ = ('algo', 'inplace') __props__ = ('algo', 'inplace')
def __init__(self, algo=None, inplace=False): def __init__(self, algo=None, inplace=False):
...@@ -605,7 +607,7 @@ class GpuDnnConvGradW(DnnBase): ...@@ -605,7 +607,7 @@ class GpuDnnConvGradW(DnnBase):
Default is the value of :attr:`config.dnn.conv.algo_bwd_filter`. Default is the value of :attr:`config.dnn.conv.algo_bwd_filter`.
""" """
_f16_ok = True
__props__ = ('algo', 'inplace') __props__ = ('algo', 'inplace')
def __init__(self, inplace=False, algo=None): def __init__(self, inplace=False, algo=None):
...@@ -720,7 +722,6 @@ gpu_dnn_conv_gradW.cache = {} ...@@ -720,7 +722,6 @@ gpu_dnn_conv_gradW.cache = {}
class GpuDnnConvGradI(DnnBase): class GpuDnnConvGradI(DnnBase):
""" """
The convolution gradient with respect to the inputs. The convolution gradient with respect to the inputs.
...@@ -735,7 +736,7 @@ class GpuDnnConvGradI(DnnBase): ...@@ -735,7 +736,7 @@ class GpuDnnConvGradI(DnnBase):
Default is the value of :attr:`config.dnn.conv.algo_bwd_data`. Default is the value of :attr:`config.dnn.conv.algo_bwd_data`.
""" """
_f16_ok = True
__props__ = ('algo', 'inplace',) __props__ = ('algo', 'inplace',)
def __init__(self, inplace=False, algo=None): def __init__(self, inplace=False, algo=None):
...@@ -1149,7 +1150,7 @@ class GpuDnnPool(DnnBase): ...@@ -1149,7 +1150,7 @@ class GpuDnnPool(DnnBase):
(padX, padY) or (padX, padY, padZ) (padX, padY) or (padX, padY, padZ)
""" """
_f16_ok = True
__props__ = ('mode',) __props__ = ('mode',)
def __init__(self, mode='max'): def __init__(self, mode='max'):
...@@ -1234,7 +1235,7 @@ class GpuDnnPoolGrad(DnnBase): ...@@ -1234,7 +1235,7 @@ class GpuDnnPoolGrad(DnnBase):
(padX, padY) or (padX, padY, padZ) (padX, padY) or (padX, padY, padZ)
""" """
_f16_ok = True
__props__ = ('mode',) __props__ = ('mode',)
def __init__(self, mode='max'): def __init__(self, mode='max'):
......
...@@ -308,6 +308,8 @@ class Scalar(Type): ...@@ -308,6 +308,8 @@ class Scalar(Type):
""" % locals() """ % locals()
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True):
if self.dtype == 'float16':
raise NotImplementedError('float16')
specs = self.dtype_specs() specs = self.dtype_specs()
if(check_input): if(check_input):
pre = """ pre = """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论