提交 18d6bd2f authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added check that gradient of int is 0

上级 4699d448
...@@ -741,6 +741,18 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -741,6 +741,18 @@ def _populate_grad_dict(var_to_node_to_idx,
# Do type checking on the result # Do type checking on the result
#List of bools indicating if each output is an integer dtype
output_is_int = [ hasattr(output.type,'dtype') and
output.type.dtype.find('int') != -1
for output in node.outputs]
#List of bools indicating if each input only has integer outputs
only_connected_to_int = [ (True not in
[ in_to_out and out_to_cost and not out_int
for in_to_out, out_to_cost, out_int in
zip(in_to_outs, outputs_connected, output_is_int) ])
for in_to_outs in connection_pattern ]
for i, term in enumerate(input_grads): for i, term in enumerate(input_grads):
# Disallow Nones # Disallow Nones
...@@ -755,14 +767,35 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -755,14 +767,35 @@ def _populate_grad_dict(var_to_node_to_idx,
if not isinstance(term.type, if not isinstance(term.type,
(NullType,DisconnectedType)): (NullType,DisconnectedType)):
if not hasattr(term.type,'dtype'):
print term
assert False
if term.type.dtype.find('float') == -1: if term.type.dtype.find('float') == -1:
raise TypeError(str(node.op)+'.grad illegally ' raise TypeError(str(node.op)+'.grad illegally '
' returned an integer-valued variable.' ' returned an integer-valued variable.'
' (Input index %d, dtype %s)' % (i, ' (Input index %d, dtype %s)' % (i,
term.type.dtype)) term.type.dtype))
if only_connected_to_int[i]:
# This term has only integer outputs and we know
# it's not undefined or disconnected
# The only other valid thing it can be is 0
is_zero = False
try:
if tensor.get_constant_value(term) == 0:
is_zero = True
except:
pass
if not is_zero:
msg = "%s.grad returned %s of type %s for input"
msg += " %d. This input is only connected to "
msg += "integer-valued outputs so it should be "
msg += "NullType, DisconnectedType, or some form "
msg += "of zeros."
msg = msg % (str(node.op), str(term),
str(type(term)), i)
raise ValueError(msg)
#Check that op.connection_pattern matches the connectivity #Check that op.connection_pattern matches the connectivity
#logic driving the op.grad method #logic driving the op.grad method
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论