提交 5e6820a6 authored 作者: --global's avatar --global

Update GpuDnnConv3dGradI to use new config flags

上级 c586d80b
......@@ -994,15 +994,25 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI):
:param descr: the convolution descriptor
"""
__props__ = ('inplace',)
__props__ = ('algo', 'inplace',)
__input_name__ = ('kernel', 'grad', 'output', 'descriptor', 'alpha', 'beta')
def __init__(self, inplace=False, workmem=None):
### deterministic (default value) is not yet supported for conv3d
if workmem == None:
workmem = 'none'
super(GpuDnnConv3dGradI, self).__init__(inplace, workmem)
assert self.workmem in ['none', 'time', 'guess', 'guess_once']
def __init__(self, inplace=False, workmem=None, algo=None):
"""
:param workmem: *deprecated*, use param algo instead
:param algo: either 'none', 'guess_once', 'guess_on_shape_change',
'time_once' or 'time_on_shape_change'.
Default is the value of :attr:`config.dnn.conv.algo_bwd.
"""
if workmem is not None:
warnings.warn(("GpuDnnConv3dGradI: parameter 'workmem' is "
"deprecated. Use 'algo' instead."), stacklevel=3)
assert algo == None
algo = workmem
super(GpuDnnConv3dGradI, self).__init__(inplace=inplace,
algo="guess_once")
assert self.algo in ['none', 'guess_once', 'guess_on_shape_change']
def grad(self, inp, grads):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论