提交 d3cb3ad4 authored 作者: Boris Fomitchev's avatar Boris Fomitchev 提交者: notoraptor

Fixed 3d convolution params

上级 35535b19
...@@ -857,7 +857,6 @@ class GpuDnnConvGradI(DnnBase): ...@@ -857,7 +857,6 @@ class GpuDnnConvGradI(DnnBase):
if algo is None: if algo is None:
algo = config.dnn.conv.algo_bwd_data algo = config.dnn.conv.algo_bwd_data
self.algo = algo self.algo = algo
assert cudnn.cudnnConvolutionBwdDataAlgo_t.has_alias(self.algo) or self.algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME assert cudnn.cudnnConvolutionBwdDataAlgo_t.has_alias(self.algo) or self.algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME
self.conv_algo = cudnn.cudnnConvolutionBwdDataAlgo_t.CUDNN_CONVOLUTION_BWD_DATA_ALGO_0 self.conv_algo = cudnn.cudnnConvolutionBwdDataAlgo_t.CUDNN_CONVOLUTION_BWD_DATA_ALGO_0
...@@ -1205,13 +1204,13 @@ def dnn_gradweight(img, topgrad, kerns_shp, border_mode='valid', ...@@ -1205,13 +1204,13 @@ def dnn_gradweight(img, topgrad, kerns_shp, border_mode='valid',
def dnn_gradweight3d(img, topgrad, kerns_shp, border_mode='valid', def dnn_gradweight3d(img, topgrad, kerns_shp, border_mode='valid',
subsample=(1, 1, 1), dilation=(1, 1, 1), conv_mode='conv', subsample=(1, 1, 1), dilation=(1, 1, 1), conv_mode='conv',
precision=None, num_groups=1): precision=None, algo=None, num_groups=1):
""" """
3d version of dnn_gradweight 3d version of dnn_gradweight
""" """
return dnn_gradweight(img, topgrad, kerns_shp, border_mode, return dnn_gradweight(img, topgrad, kerns_shp, border_mode,
subsample, dilation, conv_mode, precision, subsample, dilation, conv_mode, precision,
num_groups) algo, num_groups)
def dnn_gradinput(kerns, topgrad, img_shp, border_mode='valid', def dnn_gradinput(kerns, topgrad, img_shp, border_mode='valid',
...@@ -1237,12 +1236,12 @@ def dnn_gradinput(kerns, topgrad, img_shp, border_mode='valid', ...@@ -1237,12 +1236,12 @@ def dnn_gradinput(kerns, topgrad, img_shp, border_mode='valid',
def dnn_gradinput3d(kerns, topgrad, img_shp, border_mode='valid', def dnn_gradinput3d(kerns, topgrad, img_shp, border_mode='valid',
subsample=(1, 1, 1), dilation=(1, 1, 1), conv_mode='conv', subsample=(1, 1, 1), dilation=(1, 1, 1), conv_mode='conv',
precision=None, num_groups=1): precision=None, algo=None, num_groups=1):
""" """
3d version of `dnn_gradinput`. 3d version of `dnn_gradinput`.
""" """
return dnn_gradinput(kerns, topgrad, img_shp, border_mode, subsample, return dnn_gradinput(kerns, topgrad, img_shp, border_mode, subsample,
dilation, conv_mode, precision, dilation, conv_mode, precision, algo,
num_groups) num_groups)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论