提交 9abe4125 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add assertion after code review

上级 41b83184
......@@ -795,6 +795,17 @@ def _populate_grad_dict(var_to_node_to_idx,
else:
new_output_grads.append(og)
# Make sure that, if new_output_grads[i] has a dtype:
# - it is the same dtype as outputs[i]
# - if the dtype is an int, then new_output_grads[i] is 0.
for o, ng in zip(node.outputs, new_output_grads):
o_dt = getattr(o.type, 'dtype', None)
ng_dt = getattr(ng.type, 'dtype', None)
if ng_dt:
assert ng_dt == o_dt
if ng_dt in theano.tensor.discrete_dtypes:
assert theano.get_constant_value(ng) == 0
input_grads = node.op.grad(inputs, new_output_grads)
if input_grads is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论