提交 05dc20de authored 作者: --global's avatar --global

Ensure direction hint has valid value

上级 564cd12b
......@@ -975,6 +975,10 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
capability of 3.0 or higer. This means that older GPU will not
work with this Op.
"""
# Ensure the value of direction_hint is supported
assert direction_hint in [None, 'bprop weights', 'forward']
fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None)
if (border_mode == 'valid' and subsample == (1, 1) and
direction_hint == 'bprop weights'):
......@@ -1059,6 +1063,10 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
:warning: dnn_conv3d only works with cuDNN library 3.0
"""
# Ensure the value of direction_hint is supported
assert direction_hint in [None, 'bprop weights', 'forward']
fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None)
if (border_mode == 'valid' and subsample == (1, 1, 1) and
direction_hint == 'bprop weights'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论