提交 72ed94f5 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

factored out the _is_zero logic

上级 7681f4c7
......@@ -929,14 +929,9 @@ def _populate_grad_dict(var_to_node_to_idx,
# it's not undefined or disconnected
# The only other valid thing it can be is 0
no_constant_value = True
try:
constant_value = theano.get_constant_value(term)
no_constant_value = False
except TypeError:
pass
if no_constant_value:
is_zero = _is_zero(term)
assert is_zero in ['yes', 'no', 'maybe']:
if is_zero == 'maybe':
msg = "%s.grad returned %s of type %s for input"
msg += " %d. This input's only connections to "
msg += "the cost through this op are via "
......@@ -950,8 +945,7 @@ def _populate_grad_dict(var_to_node_to_idx,
msg = msg % (str(node.op), str(term),
str(type(term)), i)
raise ValueError(msg)
if constant_value != 0:
if is_zero == 'no':
msg = "%s.grad returned %s of type %s for input"
msg += " %d. Since this input is only connected "
msg += "to integer-valued outputs, it should "
......@@ -959,7 +953,7 @@ def _populate_grad_dict(var_to_node_to_idx,
msg += "%s."
msg % (str(node.op), str(term), str(type(term)),
i, str(constant_value))
i, str(theano.get_constant_value(term)))
raise ValueError(msg)
......@@ -1688,3 +1682,30 @@ def hessian(cost, wrt, consider_constant=None,
"script that generated the error)")
hessians.append(hess)
return format_as(using_list, using_tuple, hessians)
def _is_zero(x):
"""
Returns 'yes', 'no', or 'maybe' indicating whether x
is always 0.
'maybe' means that x is an expression that is complicated enough
that we can't tell that it simplifies to 0.
"""
if not hasattr(x, 'type'):
return np.all(x == 0.)
if isinstance(term.type, NullType):
return 'no'
if isinstance(term.type, DisconnectedType):
return 'yes'
no_constant_value = True
try:
constant_value = theano.get_constant_value(term)
no_constant_value = False
except TypeError:
pass
if no_constant_value:
return 'maybe'
if constant_value != 0.:
return no
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论