提交 0fa6f785 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6026 from notoraptor/re-allow-runtime-algos-cudnn-conv3d

(small fix) Re-allow runtime algorithms for conv3d Ops.
......@@ -559,7 +559,8 @@ class GpuDnnConv(DnnBase):
raise TypeError("The number of dimensions of "
"img, kern and output must match")
if img.type.ndim == 5 and self.algo not in cudnn.conv3d_fwd_algorithms:
if img.type.ndim == 5 and self.algo not in (cudnn.conv3d_fwd_algorithms +
SUPPORTED_DNN_CONV_ALGO_RUNTIME):
raise ValueError("convolution algo %s can't be used for "
"3d convolutions", (self.algo,))
......@@ -729,7 +730,8 @@ class GpuDnnConvGradW(DnnBase):
raise TypeError("The number of dimensions of "
"img, topgrad and output must match")
if img.type.ndim == 5 and self.algo not in cudnn.conv3d_bwd_filter_algorithms:
if img.type.ndim == 5 and self.algo not in (cudnn.conv3d_bwd_filter_algorithms +
SUPPORTED_DNN_CONV_ALGO_RUNTIME):
raise ValueError("convolution algo %s can't be used for "
"3d convolutions", (self.algo,))
......@@ -832,7 +834,8 @@ class GpuDnnConvGradI(DnnBase):
raise TypeError("The number of dimensions of "
"kern, topgrad and output must match")
if kern.type.ndim == 5 and self.algo not in cudnn.conv3d_bwd_data_algorithms:
if kern.type.ndim == 5 and self.algo not in (cudnn.conv3d_bwd_data_algorithms +
SUPPORTED_DNN_CONV_ALGO_RUNTIME):
raise ValueError("convolution algo %s can't be used for "
"3d convolutions", (self.algo,))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论