提交 72ed94f5 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

factored out the _is_zero logic

上级 7681f4c7
...@@ -929,14 +929,9 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -929,14 +929,9 @@ def _populate_grad_dict(var_to_node_to_idx,
# it's not undefined or disconnected # it's not undefined or disconnected
# The only other valid thing it can be is 0 # The only other valid thing it can be is 0
no_constant_value = True is_zero = _is_zero(term)
try: assert is_zero in ['yes', 'no', 'maybe']:
constant_value = theano.get_constant_value(term) if is_zero == 'maybe':
no_constant_value = False
except TypeError:
pass
if no_constant_value:
msg = "%s.grad returned %s of type %s for input" msg = "%s.grad returned %s of type %s for input"
msg += " %d. This input's only connections to " msg += " %d. This input's only connections to "
msg += "the cost through this op are via " msg += "the cost through this op are via "
...@@ -950,8 +945,7 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -950,8 +945,7 @@ def _populate_grad_dict(var_to_node_to_idx,
msg = msg % (str(node.op), str(term), msg = msg % (str(node.op), str(term),
str(type(term)), i) str(type(term)), i)
raise ValueError(msg) if is_zero == 'no':
if constant_value != 0:
msg = "%s.grad returned %s of type %s for input" msg = "%s.grad returned %s of type %s for input"
msg += " %d. Since this input is only connected " msg += " %d. Since this input is only connected "
msg += "to integer-valued outputs, it should " msg += "to integer-valued outputs, it should "
...@@ -959,7 +953,7 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -959,7 +953,7 @@ def _populate_grad_dict(var_to_node_to_idx,
msg += "%s." msg += "%s."
msg % (str(node.op), str(term), str(type(term)), msg % (str(node.op), str(term), str(type(term)),
i, str(constant_value)) i, str(theano.get_constant_value(term)))
raise ValueError(msg) raise ValueError(msg)
...@@ -1688,3 +1682,30 @@ def hessian(cost, wrt, consider_constant=None, ...@@ -1688,3 +1682,30 @@ def hessian(cost, wrt, consider_constant=None,
"script that generated the error)") "script that generated the error)")
hessians.append(hess) hessians.append(hess)
return format_as(using_list, using_tuple, hessians) return format_as(using_list, using_tuple, hessians)
def _is_zero(x):
"""
Returns 'yes', 'no', or 'maybe' indicating whether x
is always 0.
'maybe' means that x is an expression that is complicated enough
that we can't tell that it simplifies to 0.
"""
if not hasattr(x, 'type'):
return np.all(x == 0.)
if isinstance(term.type, NullType):
return 'no'
if isinstance(term.type, DisconnectedType):
return 'yes'
no_constant_value = True
try:
constant_value = theano.get_constant_value(term)
no_constant_value = False
except TypeError:
pass
if no_constant_value:
return 'maybe'
if constant_value != 0.:
return no
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论