提交 5c5cc7ca authored 作者: Ian Goodfellow's avatar Ian Goodfellow

make grad method enforce right number of dimensions

上级 bd861e48
......@@ -873,6 +873,7 @@ def _populate_grad_dict(var_to_node_to_idx,
# populate grad_dict[var] and return it
def access_grad_cache(var):
if var not in grad_dict:
# If var is not in grad_dict already, we must compute it
if var in var_to_node_to_idx:
terms = []
node_to_idx = var_to_node_to_idx[var]
......@@ -895,6 +896,11 @@ def _populate_grad_dict(var_to_node_to_idx,
if isinstance(term.type, DisconnectedType):
continue
if hasattr(var,'ndim') and term.ndim != var.ndim:
raise ValueError(("%s.grad returned a term with"
" %d dimensions, but %d are required.") % (
str(node.op), term.ndim, var.ndim))
terms.append(term)
# Add up the terms to get the total gradient on this variable
......@@ -911,6 +917,7 @@ def _populate_grad_dict(var_to_node_to_idx,
# this variable isn't connected to the cost in the computational
# graph
grad_dict[var] = DisconnectedType()()
# end if cache miss
return grad_dict[var]
rval = [access_grad_cache(elem) for elem in wrt]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论