提交 50acd143 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

better type checking in grad function

上级 99773d3b
......@@ -776,8 +776,42 @@ def _populate_grad_dict(var_to_node_to_idx,
input_to_outputs in connection_pattern
]
if True in inputs_connected:
# At least one input of this op is connected to the cost so we must
#List of bools indicating if each output is an integer dtype
output_is_int = [hasattr(output.type, 'dtype') and
output.type.dtype in theano.tensor.discrete_dtypes
for output in node.outputs]
#List of bools indicating if each output is NullType
ograd_is_nan = [isinstance(output.type, NullType)
for output in output_grads]
# List of bools indicating if each input only has NullType outputs
only_connected_to_nan = [(True not in
[in_to_out and out_to_cost and not out_nan
for in_to_out, out_to_cost, out_nan in
zip(in_to_outs, outputs_connected, ograd_is_nan)])
for in_to_outs in connection_pattern]
if True not in inputs_connected:
# All outputs of this op are disconnected so we can skip
# Calling the op's grad method and report that the inputs
# are disconnected
# (The op's grad method could do this too, but this saves the
# implementer the trouble of worrying about this case)
input_grads = [DisconnectedType()() for ipt in inputs]
elif False not in only_connected_to_nan:
# All inputs are only connected to nan gradients, so we don't
# need to bother calling the grad method. We know the gradient
# with respect to all connected inputs is nan.
input_grads = []
for connected in inputs_connected:
if connected:
input_grads.append(NullType()())
else:
input_grads.append(DisconnectedType()())
else:
# At least one input of this op is connected to the cost so and
# not all output gradients are undefined so we must
# call the op's grad method
# Each Op's grad function requires inputs and output_grads
......@@ -848,13 +882,6 @@ def _populate_grad_dict(var_to_node_to_idx,
if len(input_grads) != len(inputs):
raise ValueError(("%s returned the wrong number of" +\
" gradient terms.") % str(node.op))
else:
# All outputs of this op are disconnected so we can skip
# Calling the op's grad method and report that the inputs
# are disconnected
# (The op's grad method could do this too, but this saves the
# implementer the trouble of worrying about this case)
input_grads = [DisconnectedType()() for ipt in inputs]
# must convert to list in case the op returns a tuple
# we won't be able to post-process out the Nones if it does that
......@@ -862,18 +889,15 @@ def _populate_grad_dict(var_to_node_to_idx,
# Do type checking on the result
#List of bools indicating if each output is an integer dtype
output_is_int = [hasattr(output.type, 'dtype') and
output.type.dtype in theano.tensor.discrete_dtypes
for output in node.outputs]
#List of bools indicating if each input only has integer outputs
# List of bools indicating if each input only has integer outputs
only_connected_to_int = [(True not in
[in_to_out and out_to_cost and not out_int
for in_to_out, out_to_cost, out_int in
zip(in_to_outs, outputs_connected, output_is_int)])
for in_to_outs in connection_pattern]
for i, term in enumerate(input_grads):
# Disallow Nones
......@@ -898,6 +922,10 @@ def _populate_grad_dict(var_to_node_to_idx,
' returned an integer-valued variable.'
' (Input index %d, dtype %s)' % (i,
term.type.dtype))
if only_connected_to_nan[i]:
assert isinstance(term.type, NullType)
if only_connected_to_int[i]:
# This term has only integer outputs and we know
# it's not undefined or disconnected
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论