提交 fdfcd78e authored 作者: Ian Goodfellow's avatar Ian Goodfellow

cleanup of new features

上级 6217ef05
...@@ -13,6 +13,7 @@ import warnings ...@@ -13,6 +13,7 @@ import warnings
_logger = logging.getLogger('theano.gradient') _logger = logging.getLogger('theano.gradient')
import numpy # for numeric_grad import numpy # for numeric_grad
np = numpy
import theano import theano
...@@ -551,7 +552,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, ...@@ -551,7 +552,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
grad_dict[elem] = DisconnectedType()() grad_dict[elem] = DisconnectedType()()
cost_name = None cost_name = None
if add_names: if add_names and cost is not None:
cost_name = cost.name cost_name = cost.name
# Make sure we didn't initialize the grad_dict with any ints # Make sure we didn't initialize the grad_dict with any ints
...@@ -930,7 +931,7 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -930,7 +931,7 @@ def _populate_grad_dict(var_to_node_to_idx,
# The only other valid thing it can be is 0 # The only other valid thing it can be is 0
is_zero = _is_zero(term) is_zero = _is_zero(term)
assert is_zero in ['yes', 'no', 'maybe']: assert is_zero in ['yes', 'no', 'maybe']
if is_zero == 'maybe': if is_zero == 'maybe':
msg = "%s.grad returned %s of type %s for input" msg = "%s.grad returned %s of type %s for input"
msg += " %d. This input's only connections to " msg += " %d. This input's only connections to "
...@@ -1585,14 +1586,14 @@ def _is_zero(x): ...@@ -1585,14 +1586,14 @@ def _is_zero(x):
""" """
if not hasattr(x, 'type'): if not hasattr(x, 'type'):
return np.all(x == 0.) return np.all(x == 0.)
if isinstance(term.type, NullType): if isinstance(x.type, NullType):
return 'no' return 'no'
if isinstance(term.type, DisconnectedType): if isinstance(x.type, DisconnectedType):
return 'yes' return 'yes'
no_constant_value = True no_constant_value = True
try: try:
constant_value = theano.get_constant_value(term) constant_value = theano.get_constant_value(x)
no_constant_value = False no_constant_value = False
except TypeError: except TypeError:
pass pass
...@@ -1601,4 +1602,6 @@ def _is_zero(x): ...@@ -1601,4 +1602,6 @@ def _is_zero(x):
return 'maybe' return 'maybe'
if constant_value != 0.: if constant_value != 0.:
return no return 'no'
return 'yes'
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论