提交 4dd87c9e 作者: notoraptor

Always raise an error when precision is float16 for gradient convolutions, instead of silently falling back to float32.

上级 27f03a00
......@@ -982,9 +982,6 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
"""
# Establish dtype in which to perform the computation of the convolution
precision = get_precision(precision, [img, kerns])
if workmem is not None:
if algo is not None:
raise ValueError("You can't use both algo and workmem")
......@@ -1008,8 +1005,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
shape_i(img, 3, fgraph) - shape_i(kerns, 3, fgraph) + 1)
out_shp = assert_conv_shape(out_shp)
out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp)
if precision == 'float16':
precision = 'float32'
precision = get_precision(precision, [img, kerns], for_grad=True)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1), dilation=(1, 1),
conv_mode='cross', precision=precision)(out.shape)
conv = GpuDnnConvGradW()(img, kerns, out, desc)
......@@ -1029,8 +1025,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
shape_i(img, 3, fgraph) + (shape_i(kerns, 3, fgraph) - 1) * dilation[1])
out_shp = assert_conv_shape(out_shp)
out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp)
if precision == 'float16':
precision = 'float32'
precision = get_precision(precision, [img, kerns], for_grad=True)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1), dilation=dilation,
conv_mode=conv_mode, precision=precision)(kerns.shape)
return GpuDnnConvGradI()(kerns, img, out, desc)
......@@ -1040,6 +1035,8 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
# if the img contains negative strides
img = gpu_contiguous(img)
kerns = gpu_contiguous(kerns)
# Establish dtype in which to perform the computation of the convolution
precision = get_precision(precision, [img, kerns])
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, dilation=dilation,
conv_mode=conv_mode, precision=precision,
num_groups=num_groups)(kerns.shape)
......@@ -1113,9 +1110,6 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1
"""
# Establish dtype in which to perform the computation of the convolution
precision = get_precision(precision, [img, kerns])
fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None)
ctx_name = infer_context_name(img, kerns)
if (border_mode == 'valid' and subsample == (1, 1, 1) and dilation == (1, 1, 1) and
......@@ -1135,8 +1129,7 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1
shape_i(img, 4, fgraph) - shape_i(kerns, 4, fgraph) + 1)
out_shp = assert_conv_shape(out_shp)
out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp)
if precision == 'float16':
precision = 'float32'
precision = get_precision(precision, [img, kerns], for_grad=True)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1, 1), dilation=(1, 1, 1),
conv_mode='cross', precision=precision)(out.shape)
conv = GpuDnnConvGradW()(img, kerns, out, desc)
......@@ -1157,8 +1150,7 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1
shape_i(img, 4, fgraph) + (shape_i(kerns, 4, fgraph) - 1) * dilation[2])
out_shp = assert_conv_shape(out_shp)
out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp)
if precision == 'float16':
precision = 'float32'
precision = get_precision(precision, [img, kerns], for_grad=True)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1, 1), dilation=dilation,
conv_mode=conv_mode, precision=precision)(kerns.shape)
return GpuDnnConvGradI()(kerns, img, out, desc)
......@@ -1168,6 +1160,8 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1
# if the img contains negative strides
img = gpu_contiguous(img)
kerns = gpu_contiguous(kerns)
# Establish dtype in which to perform the computation of the convolution
precision = get_precision(precision, [img, kerns])
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, dilation=dilation,
conv_mode=conv_mode, precision=precision,
num_groups=num_groups)(kerns.shape)
......
......@@ -171,17 +171,29 @@ def test_dnn_conv_inplace():
assert len([n for n in topo if isinstance(n.op, GpuAllocEmpty)]) == 2
def test_dnn_conv_invalid_precision():
img = T.tensor4()
kerns = T.tensor4()
topgrad = T.tensor4()
shape = (1, 2, 3, 4)
def run_dnn_conv_invalid_precision(ndim):
bc = (False,) * (ndim + 2)
img = T.tensor(theano.config.floatX, broadcastable=bc)
kerns = T.tensor(theano.config.floatX, broadcastable=bc)
topgrad = T.tensor(theano.config.floatX, broadcastable=bc)
shape = np.arange(ndim + 2)
if ndim == 2:
dnn_conv_func = dnn.dnn_conv
dnn_gradw_func = dnn.dnn_gradweight
dnn_gradi_func = dnn.dnn_gradinput
elif ndim == 3:
dnn_conv_func = dnn.dnn_conv3d
dnn_gradw_func = dnn.dnn_gradweight3d
dnn_gradi_func = dnn.dnn_gradinput3d
def dnn_gradw(precision):
return dnn.dnn_gradweight(img, topgrad, shape, precision=precision)
return dnn_gradw_func(img, topgrad, shape, precision=precision)
def dnn_gradi(precision):
return dnn.dnn_gradinput(kerns, topgrad, shape, precision=precision)
return dnn_gradi_func(kerns, topgrad, shape, precision=precision)
def dnn_conv(precision, border_mode, direction_hint):
return dnn_conv_func(img, kerns, border_mode=border_mode, direction_hint=direction_hint, precision=precision)
dnn_gradw('float64')
dnn_gradw('float32')
......@@ -191,6 +203,22 @@ def test_dnn_conv_invalid_precision():
dnn_gradi('float32')
assert_raises(TypeError, dnn_gradi, 'float16')
for precision in ('float64', 'float32'):
dnn_conv(precision, 'valid', None)
dnn_conv(precision, 'valid', 'bprop weights')
dnn_conv(precision, 'full', None)
dnn_conv(precision, 'full', 'forward!')
dnn_conv('float16', 'valid', None)
assert_raises(TypeError, dnn_conv, 'float16', 'valid', 'bprop weights')
assert_raises(TypeError, dnn_conv, 'float16', 'full', None)
dnn_conv('float16', 'full', 'forward!')
# Generator-style test: yields one (callable, ndim) pair for ndim == 2 and one
# for ndim == 3, matching the two branches run_dnn_conv_invalid_precision
# selects between (the 2D dnn_conv functions vs. the 3D dnn_conv3d functions).
# NOTE(review): presumably collected by a nose-style runner that calls each
# yielded tuple as func(*args) -- confirm against the project's test runner.
def test_dnn_conv_invalid_precision():
yield (run_dnn_conv_invalid_precision, 2)
yield (run_dnn_conv_invalid_precision, 3)
def test_pooling():
if not dnn.dnn_available(test_ctx_name):
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论