提交 674cd4f1 authored 作者: carriepl's avatar carriepl 提交者: Frederic

Add precision param to dnnConv (gpua backend)

上级 46c83387
......@@ -740,7 +740,7 @@ class GpuDnnConvGradI(DnnBase):
def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
conv_mode='conv', direction_hint=None, workmem=None,
algo=None):
algo=None, precision=None):
"""
GPU convolution using cuDNN from NVIDIA.
......@@ -774,6 +774,10 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
Convolution implementation to use. Some of its values may
require certain versions of CuDNN to be installed. Default is
the value of :attr:`config.dnn.conv.algo_fwd`.
precision : {'float16', 'float32', 'float64', 'floatX'}
Description of the dtype in which the computation of the convolution
should be done. Default is the value of
:attr:`config.dnn.conv.precision`.
.. warning:: The cuDNN library only works with GPUs that have a compute
capability of 3.0 or higher. This means that older GPUs will not
......@@ -803,7 +807,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
shape_i(kerns, 1, fgraph),
shape_i(img, 1, fgraph), shape2, shape3)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1),
conv_mode='cross')(out.shape)
conv_mode='cross', precision=precision)(out.shape)
conv = GpuDnnConvGradW()(img, kerns, out, desc)
return as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3), ctx_name)
......@@ -821,7 +825,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
shape_i(kerns, 1, fgraph),
shape2, shape3)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1),
conv_mode=conv_mode)(kerns.shape)
conv_mode=conv_mode, precision=precision)(kerns.shape)
return GpuDnnConvGradI()(kerns, img, out, desc)
# Standard case: We use GpuDnnConv with suitable padding.
......@@ -830,7 +834,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
img = gpu_contiguous(img)
kerns = gpu_contiguous(kerns)
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample,
conv_mode=conv_mode)(kerns.shape)
conv_mode=conv_mode, precision=precision)(kerns.shape)
desc_op = desc.owner.op
out_shp = GpuDnnConv.get_out_shape(img.shape, kerns.shape,
desc_op.border_mode,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论