提交 df5c67f8 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

add some type checking

上级 2196f4a4
...@@ -446,11 +446,21 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, ...@@ -446,11 +446,21 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
raise TypeError('Elements of consider_constant must be ' raise TypeError('Elements of consider_constant must be '
'variables, but got ' + str(type(elem))) 'variables, but got ' + str(type(elem)))
if isinstance(wrt, set):
raise TypeError("wrt must not be a set. sets have no defined "
"iteration order, so we can't return gradients in a matching"
" order.")
using_list = isinstance(wrt, list) using_list = isinstance(wrt, list)
using_tuple = isinstance(wrt, tuple) using_tuple = isinstance(wrt, tuple)
if not using_list and not using_tuple: if not using_list and not using_tuple:
wrt = [wrt] wrt = [wrt]
for elem in wrt:
if not isinstance(elem, Variable):
raise TypeError("Expected Variable, got " + str(elem) +
" of type "+str(type(elem)))
var_to_node_to_idx = _populate_var_to_node_to_idx([cost], wrt) var_to_node_to_idx = _populate_var_to_node_to_idx([cost], wrt)
# build a dict mapping var to the gradient of cost with respect to var # build a dict mapping var to the gradient of cost with respect to var
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论