提交 d3cb3ad4 authored 作者: Boris Fomitchev's avatar Boris Fomitchev 提交者: notoraptor

Fixed 3d convolution params

上级 35535b19
......@@ -857,7 +857,6 @@ class GpuDnnConvGradI(DnnBase):
if algo is None:
algo = config.dnn.conv.algo_bwd_data
self.algo = algo
assert cudnn.cudnnConvolutionBwdDataAlgo_t.has_alias(self.algo) or self.algo in SUPPORTED_DNN_CONV_ALGO_RUNTIME
self.conv_algo = cudnn.cudnnConvolutionBwdDataAlgo_t.CUDNN_CONVOLUTION_BWD_DATA_ALGO_0
......@@ -1205,13 +1204,13 @@ def dnn_gradweight(img, topgrad, kerns_shp, border_mode='valid',
def dnn_gradweight3d(img, topgrad, kerns_shp, border_mode='valid',
subsample=(1, 1, 1), dilation=(1, 1, 1), conv_mode='conv',
precision=None, num_groups=1):
precision=None, algo=None, num_groups=1):
"""
3d version of dnn_gradweight
"""
return dnn_gradweight(img, topgrad, kerns_shp, border_mode,
subsample, dilation, conv_mode, precision,
num_groups)
algo, num_groups)
def dnn_gradinput(kerns, topgrad, img_shp, border_mode='valid',
......@@ -1237,12 +1236,12 @@ def dnn_gradinput(kerns, topgrad, img_shp, border_mode='valid',
def dnn_gradinput3d(kerns, topgrad, img_shp, border_mode='valid',
subsample=(1, 1, 1), dilation=(1, 1, 1), conv_mode='conv',
precision=None, num_groups=1):
precision=None, algo=None, num_groups=1):
"""
3d version of `dnn_gradinput`.
"""
return dnn_gradinput(kerns, topgrad, img_shp, border_mode, subsample,
dilation, conv_mode, precision,
dilation, conv_mode, precision, algo,
num_groups)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论