Commit f02b01d4 authored by Ian Goodfellow

got rid of warn_type flag--the new type system doesn't enforce that a
gradient and an input have the same type, so the warning makes no sense

Parent 3625c701
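The behavioral change is easiest to see outside the diff. Below is a minimal plain-Python sketch contrasting the two behaviors; `node` and `input_grads` are hypothetical stand-ins for the locals of `_populate_grad_dict`, not runnable Theano code:

def process_none_terms_old(node, input_grads):
    # Old behavior: a None gradient term was silently replaced with zeros
    # (and warn_type could additionally log gradient/input type mismatches).
    for i, term in enumerate(input_grads):
        if term is None:
            input_grads[i] = node.inputs[i].zeros_like()
    return input_grads

def process_none_terms_new(node, input_grads):
    # New behavior: a None term is rejected outright, because its meaning
    # (undefined, zero, or disconnected) had become too muddied.
    for i, term in enumerate(input_grads):
        if term is None:
            raise TypeError('%s returned None for a gradient term, '
                            'this is prohibited' % node.op)
    return input_grads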
@@ -478,8 +478,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, warn_type=False,
         cost_name = cost.name
     rval = _populate_grad_dict(var_to_node_to_idx,
-                               grad_dict, wrt, warn_type,
-                               cost_name)
+                               grad_dict, wrt, cost_name)
     for i in xrange(len(rval)):
         if isinstance(rval[i].type, DisconnectedType):
@@ -632,7 +631,7 @@ def _populate_var_to_node_to_idx(outputs, wrt):
 def _populate_grad_dict(var_to_node_to_idx,
-                        grad_dict, wrt, warn_type, cost_name=None):
+                        grad_dict, wrt, cost_name=None):
     """
     Common code shared between grad_sources_inputs and grad
@@ -765,32 +764,17 @@ def _populate_grad_dict(var_to_node_to_idx,
                 msg += ' connection_pattern method for it.'
                 warnings.warn(msg)
-            # Process out any Nones
             for i, term in enumerate(input_grads):
+                # Disallow Nones
                 if term is None:
-                    # we don't know what None means. in the past it has been
-                    # used to
-                    # mean undefined, zero, or disconnected. So for now we
-                    # assume it is
-                    # zero. Assuming it is zero prevents
-                    # us from disconnecting NaNs above.
-                    # eventually we should disallow this
-                    # return type and force all ops
-                    # to return the correct thing
-                    #raise AssertionError(('%s returned None for' +\
-                    #        ' a gradient term, '
-                    #        'this is prohibited') % node.op)
-                    input_grads[i] = node.inputs[i].zeros_like()
-                if warn_type:
-                    g_r_type = term_dict[node][i].type
-                    r_type = inputs[i].type
-                    if g_r_type != r_type:
-                        _logger.warning(
-                            '%s.grad returned a different type (%s) '
-                            'for input %i of type (%s)',
-                            node.op, g_r_type, i, r_type)
+                    # We don't know what None means. in the past it has been
+                    # used to mean undefined, zero, or disconnected.
+                    # We therefore don't allow it because its usage has become
+                    # so muddied.
+                    raise TypeError(('%s returned None for' +\
+                        ' a gradient term, '
+                        'this is prohibited') % node.op)
             #cache the result
             term_dict[node] = input_grads
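One practical consequence for Op authors: an op whose grad method previously returned None and relied on the old zero-substitution must now build the zero gradient itself. A hedged sketch of such an update (the op and its grad signature here are hypothetical, for illustration only):

    def grad(self, inputs, output_grads):
        x, y = inputs
        g_out, = output_grads
        # The gradient w.r.t. y was formerly returned as None; under the
        # new rule that raises TypeError, so return explicit zeros instead.
        return [g_out, y.zeros_like()]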
@@ -845,7 +829,7 @@ def _populate_grad_dict(var_to_node_to_idx,
     return rval
 
 
-def grad_sources_inputs(sources, graph_inputs, warn_type=True):
+def grad_sources_inputs(sources, graph_inputs):
     """
     Used to compute the gradient of a cost with respect to all the
     variables between graph_input and cost, but in the special
@@ -889,10 +873,6 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
     :type graph_inputs: list of Variable
     :param graph_inputs: variables considered to be constant
         (do not backpropagate through them)
-    :type warn_type: bool
-    :param warn_type: True will trigger warnings via the logging module when
-        the gradient on an expression has a different type than the original
-        expression
 
     :rtype: dictionary whose keys and values are of type Variable
     :return: mapping from each Variable encountered in the backward
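Since warn_type is also gone from the public grad_sources_inputs signature, any caller still passing it will now fail with a TypeError from Python itself. A sketch of the caller-side update (variable names are illustrative, not from this diff):

# Before this commit:
#   grad_map = grad_sources_inputs(sources, graph_inputs, warn_type=True)
# After this commit:
grad_map = grad_sources_inputs(sources, graph_inputs)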
@@ -934,7 +914,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
             grad_dict[elem] = DisconnectedType()()
     _populate_grad_dict(var_to_node_to_idx,
-                        grad_dict, wrt, warn_type)
+                        grad_dict, wrt)
 
     # post-process out the DisconnectedTypes
     for key in grad_dict: