提交 14b81781 authored 作者: --global's avatar --global

Update GpuDnnConv3dGradW to use new config flags

上级 319cf621
......@@ -796,6 +796,7 @@ class GpuDnnConvGradW(DnnBase, COp):
def infer_shape(self, node, shape):
return [shape[2]]
class GpuDnnConv3dGradW(GpuDnnConvGradW):
"""
The convolution gradient with respect to the weights.
......@@ -805,12 +806,25 @@ class GpuDnnConv3dGradW(GpuDnnConvGradW):
:param descr: the convolution descriptor
"""
__props__ = ('workmem', 'inplace',)
__props__ = ('algo', 'inplace',)
__input_name__ = ('image', 'grad', 'output', 'descriptor', 'alpha', 'beta')
def __init__(self, inplace=False, workmem=None):
super(GpuDnnConv3dGradW, self).__init__(inplace=inplace, workmem='none')
assert self.workmem in ['none', 'time','guess', 'guess_once']
def __init__(self, inplace=False, workmem=None, algo=None):
"""
:param workmem: *deprecated*, use param algo instead
:param algo: either 'none', 'guess_once', 'guess_on_shape_change',
'time_once' or 'time_on_shape_change'.
Default is the value of :attr:`config.dnn.conv.algo_bwd.
"""
if workmem is not None:
warnings.warn(("GpuDnnConv3dGradW: parameter 'workmem' is "
"deprecated. Use 'algo' instead."), stacklevel=3)
assert algo == None
algo = workmem
super(GpuDnnConv3dGradW, self).__init__(inplace=inplace,
algo='guess_once')
assert self.algo in ['none', 'guess_once', 'guess_on_shape_change']
def grad(self, inp, grads):
img, top, output, desc, alpha, beta = inp
......@@ -849,7 +863,6 @@ class GpuDnnConv3dGradW(GpuDnnConvGradW):
[output.type()])
class GpuDnnConvGradI(DnnBase, COp):
"""
The convolution gradient with respect to the inputs.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论