提交 05dc20de authored 作者: --global's avatar --global

Ensure direction hint has valid value

上级 564cd12b
...@@ -975,6 +975,10 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -975,6 +975,10 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
capability of 3.0 or higer. This means that older GPU will not capability of 3.0 or higer. This means that older GPU will not
work with this Op. work with this Op.
""" """
# Ensure the value of direction_hint is supported
assert direction_hint in [None, 'bprop weights', 'forward']
fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None) fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None)
if (border_mode == 'valid' and subsample == (1, 1) and if (border_mode == 'valid' and subsample == (1, 1) and
direction_hint == 'bprop weights'): direction_hint == 'bprop weights'):
...@@ -1059,6 +1063,10 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), ...@@ -1059,6 +1063,10 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
:warning: dnn_conv3d only works with cuDNN library 3.0 :warning: dnn_conv3d only works with cuDNN library 3.0
""" """
# Ensure the value of direction_hint is supported
assert direction_hint in [None, 'bprop weights', 'forward']
fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None) fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None)
if (border_mode == 'valid' and subsample == (1, 1, 1) and if (border_mode == 'valid' and subsample == (1, 1, 1) and
direction_hint == 'bprop weights'): direction_hint == 'bprop weights'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论