Unverified 提交 ddcc92dc 作者: Frédéric Bastien 提交者: GitHub

Merge pull request #6604 from abergeron/dnn_mix

Fix for mixed-dtype input in dnn_conv
......@@ -947,10 +947,14 @@ def _dnn_conv(img, kerns, alpha=1, beta=0, out=None, border_mode='valid', subsam
conv_mode='conv', algo=None, precision=None, num_groups=1):
ctx_name = infer_context_name(img, kerns)
img = gpu_contiguous(as_gpuarray_variable(img, ctx_name))
kerns = gpu_contiguous(as_gpuarray_variable(kerns, ctx_name))
img = as_gpuarray_variable(img, ctx_name)
kerns = as_gpuarray_variable(kerns, ctx_name)
precision = get_precision(precision, [img, kerns])
img = gpu_contiguous(img.astype(precision))
kerns = gpu_contiguous(kerns.astype(precision))
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, dilation=dilation,
conv_mode=conv_mode, precision=precision, num_groups=num_groups)(kerns.shape)
desc_op = desc.owner.op
......@@ -974,11 +978,15 @@ def _dnn_gradweight(img, topgrad, kerns_shp, alpha=1, beta=0, out=None, border_m
dilation=(1, 1), conv_mode='conv', algo=None, precision=None, num_groups=1):
ctx_name = infer_context_name(img, topgrad)
img = gpu_contiguous(as_gpuarray_variable(img, ctx_name))
topgrad = gpu_contiguous(as_gpuarray_variable(topgrad, ctx_name))
img = as_gpuarray_variable(img, ctx_name)
topgrad = as_gpuarray_variable(topgrad, ctx_name)
kerns_shp = theano.tensor.as_tensor_variable(kerns_shp)
precision = get_precision(precision, [img, topgrad], for_grad=True)
img = gpu_contiguous(img.astype(precision))
topgrad = gpu_contiguous(topgrad.astype(precision))
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, dilation=dilation,
conv_mode=conv_mode, precision=precision, num_groups=num_groups)(kerns_shp)
if beta == 0:
......@@ -995,11 +1003,15 @@ def _dnn_gradinput(kerns, topgrad, img_shp, alpha=1, beta=0, out=None, border_mo
dilation=(1, 1), conv_mode='conv', algo=None, precision=None, num_groups=1):
ctx_name = infer_context_name(kerns, topgrad)
kerns = gpu_contiguous(as_gpuarray_variable(kerns, ctx_name))
topgrad = gpu_contiguous(as_gpuarray_variable(topgrad, ctx_name))
kerns = as_gpuarray_variable(kerns, ctx_name)
topgrad = as_gpuarray_variable(topgrad, ctx_name)
img_shp = theano.tensor.as_tensor_variable(img_shp)
precision = get_precision(precision, [kerns, topgrad], for_grad=True)
kerns = gpu_contiguous(kerns.astype(precision))
topgrad = gpu_contiguous(topgrad.astype(precision))
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, dilation=dilation,
conv_mode=conv_mode, precision=precision, num_groups=num_groups)(kerns.shape)
if beta == 0:
......
......@@ -220,6 +220,48 @@ def test_dnn_conv_invalid_precision():
yield (run_dnn_conv_invalid_precision, 3)
def test_dnn_conv_mixed_dtype():
    """Check dtype agreement for 2D dnn convolutions fed mixed-dtype inputs.

    Each of ``dnn.dnn_conv``, ``dnn.dnn_gradweight`` and ``dnn.dnn_gradinput``
    is called with one ``T.ftensor4`` and one ``T.dtensor4`` argument, in both
    orders, using ``precision='as_input'``.  For every result, the first three
    inputs of the node that produced it must share a single dtype.
    """
    flt = T.ftensor4()
    dbl = T.dtensor4()

    def check_same_dtype(result):
        # The dtype of input 0 is the reference; inputs 1 and 2 must match it.
        expected = result.owner.inputs[0].dtype
        assert result.owner.inputs[1].dtype == expected
        assert result.owner.inputs[2].dtype == expected

    # Forward convolution, both argument orders.
    fwd = dnn.dnn_conv(dbl, flt, precision='as_input')
    check_same_dtype(fwd)
    fwd = dnn.dnn_conv(flt, dbl, precision='as_input')
    check_same_dtype(fwd)

    # Gradient with respect to the weights.
    grad_w = dnn.dnn_gradweight(flt, dbl, kerns_shp=flt.shape,
                                precision='as_input')
    check_same_dtype(grad_w)
    grad_w = dnn.dnn_gradweight(dbl, flt, kerns_shp=flt.shape,
                                precision='as_input')
    check_same_dtype(grad_w)

    # Gradient with respect to the inputs.
    grad_i = dnn.dnn_gradinput(flt, dbl, img_shp=flt.shape,
                               precision='as_input')
    check_same_dtype(grad_i)
    grad_i = dnn.dnn_gradinput(dbl, flt, img_shp=flt.shape,
                               precision='as_input')
    check_same_dtype(grad_i)
def test_dnn_conv3d_mixed_dtype():
    """Check dtype agreement for 3D dnn convolutions fed mixed-dtype inputs.

    Each of ``dnn.dnn_conv3d``, ``dnn.dnn_gradweight3d`` and
    ``dnn.dnn_gradinput3d`` is called with one ``T.ftensor5`` and one
    ``T.dtensor5`` argument, in both orders, using ``precision='as_input'``.
    For every result, the first three inputs of the node that produced it
    must share a single dtype.
    """
    flt = T.ftensor5()
    dbl = T.dtensor5()

    def check_same_dtype(result):
        # The dtype of input 0 is the reference; inputs 1 and 2 must match it.
        expected = result.owner.inputs[0].dtype
        assert result.owner.inputs[1].dtype == expected
        assert result.owner.inputs[2].dtype == expected

    # Forward convolution, both argument orders.
    fwd = dnn.dnn_conv3d(dbl, flt, precision='as_input')
    check_same_dtype(fwd)
    fwd = dnn.dnn_conv3d(flt, dbl, precision='as_input')
    check_same_dtype(fwd)

    # Gradient with respect to the weights.
    grad_w = dnn.dnn_gradweight3d(flt, dbl, kerns_shp=flt.shape,
                                  precision='as_input')
    check_same_dtype(grad_w)
    grad_w = dnn.dnn_gradweight3d(dbl, flt, kerns_shp=flt.shape,
                                  precision='as_input')
    check_same_dtype(grad_w)

    # Gradient with respect to the inputs.
    grad_i = dnn.dnn_gradinput3d(flt, dbl, img_shp=flt.shape,
                                 precision='as_input')
    check_same_dtype(grad_i)
    grad_i = dnn.dnn_gradinput3d(dbl, flt, img_shp=flt.shape,
                                 precision='as_input')
    check_same_dtype(grad_i)
def test_pooling():
if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论