提交 5e6820a6 authored 作者: --global's avatar --global

Update GpuDnnConv3dGradI to use new config flags

上级 c586d80b
...@@ -994,15 +994,25 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI): ...@@ -994,15 +994,25 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI):
:param descr: the convolution descriptor :param descr: the convolution descriptor
""" """
__props__ = ('inplace',) __props__ = ('algo', 'inplace',)
__input_name__ = ('kernel', 'grad', 'output', 'descriptor', 'alpha', 'beta') __input_name__ = ('kernel', 'grad', 'output', 'descriptor', 'alpha', 'beta')
def __init__(self, inplace=False, workmem=None): def __init__(self, inplace=False, workmem=None, algo=None):
### deterministic (default value) is not yet supported for conv3d """
if workmem == None: :param workmem: *deprecated*, use param algo instead
workmem = 'none' :param algo: either 'none', 'guess_once', 'guess_on_shape_change',
super(GpuDnnConv3dGradI, self).__init__(inplace, workmem) 'time_once' or 'time_on_shape_change'.
assert self.workmem in ['none', 'time', 'guess', 'guess_once'] Default is the value of :attr:`config.dnn.conv.algo_bwd.
"""
if workmem is not None:
warnings.warn(("GpuDnnConv3dGradI: parameter 'workmem' is "
"deprecated. Use 'algo' instead."), stacklevel=3)
assert algo == None
algo = workmem
super(GpuDnnConv3dGradI, self).__init__(inplace=inplace,
algo="guess_once")
assert self.algo in ['none', 'guess_once', 'guess_on_shape_change']
def grad(self, inp, grads): def grad(self, inp, grads):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论