提交 c012bcdf authored 作者: affanv14's avatar affanv14

add condition to check that convolution is not 3D

上级 4c443363
...@@ -565,6 +565,8 @@ class GpuDnnConv(DnnBase): ...@@ -565,6 +565,8 @@ class GpuDnnConv(DnnBase):
SUPPORTED_DNN_CONV_ALGO_RUNTIME): SUPPORTED_DNN_CONV_ALGO_RUNTIME):
raise ValueError("convolution algo %s can't be used for " raise ValueError("convolution algo %s can't be used for "
"3d convolutions", (self.algo,)) "3d convolutions", (self.algo,))
if img.type.ndim == 5 and self.num_groups != 1:
raise ValueError("Grouped convolutions not implemented for 3D convolutions")
if (not isinstance(desc.type, CDataType) or if (not isinstance(desc.type, CDataType) or
desc.type.ctype != 'cudnnConvolutionDescriptor_t'): desc.type.ctype != 'cudnnConvolutionDescriptor_t'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论