提交 674cd4f1 authored 作者: carriepl's avatar carriepl 提交者: Frederic

Add precision param to dnnConv (gpua backend)

上级 46c83387
...@@ -740,7 +740,7 @@ class GpuDnnConvGradI(DnnBase): ...@@ -740,7 +740,7 @@ class GpuDnnConvGradI(DnnBase):
def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
conv_mode='conv', direction_hint=None, workmem=None, conv_mode='conv', direction_hint=None, workmem=None,
algo=None): algo=None, precision=None):
""" """
GPU convolution using cuDNN from NVIDIA. GPU convolution using cuDNN from NVIDIA.
...@@ -774,6 +774,10 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -774,6 +774,10 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
Convolution implementation to use. Some of its values may Convolution implementation to use. Some of its values may
require certain versions of CuDNN to be installed. Default is require certain versions of CuDNN to be installed. Default is
the value of :attr:`config.dnn.conv.algo_fwd`. the value of :attr:`config.dnn.conv.algo_fwd`.
precision : {'float16', 'float32', 'float64', 'floatX'}
Description of the dtype in which the computation of the convolution
should be done. Default is the value of
:attr:`config.dnn.conv.precision`.
.. warning:: The cuDNN library only works with GPUs that have a compute .. warning:: The cuDNN library only works with GPUs that have a compute
capability of 3.0 or higher. This means that older GPUs will not capability of 3.0 or higher. This means that older GPUs will not
...@@ -803,7 +807,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -803,7 +807,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
shape_i(kerns, 1, fgraph), shape_i(kerns, 1, fgraph),
shape_i(img, 1, fgraph), shape2, shape3) shape_i(img, 1, fgraph), shape2, shape3)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1), desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1),
conv_mode='cross')(out.shape) conv_mode='cross', precision=precision)(out.shape)
conv = GpuDnnConvGradW()(img, kerns, out, desc) conv = GpuDnnConvGradW()(img, kerns, out, desc)
return as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3), ctx_name) return as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3), ctx_name)
...@@ -821,7 +825,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -821,7 +825,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
shape_i(kerns, 1, fgraph), shape_i(kerns, 1, fgraph),
shape2, shape3) shape2, shape3)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1), desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1),
conv_mode=conv_mode)(kerns.shape) conv_mode=conv_mode, precision=precision)(kerns.shape)
return GpuDnnConvGradI()(kerns, img, out, desc) return GpuDnnConvGradI()(kerns, img, out, desc)
# Standard case: We use GpuDnnConv with suitable padding. # Standard case: We use GpuDnnConv with suitable padding.
...@@ -830,7 +834,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -830,7 +834,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
img = gpu_contiguous(img) img = gpu_contiguous(img)
kerns = gpu_contiguous(kerns) kerns = gpu_contiguous(kerns)
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample,
conv_mode=conv_mode)(kerns.shape) conv_mode=conv_mode, precision=precision)(kerns.shape)
desc_op = desc.owner.op desc_op = desc.owner.op
out_shp = GpuDnnConv.get_out_shape(img.shape, kerns.shape, out_shp = GpuDnnConv.get_out_shape(img.shape, kerns.shape,
desc_op.border_mode, desc_op.border_mode,
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论