Commit 481b6632 authored by Ian Goodfellow

fix bug looking for dtype of scalar, when scalars only have type.dtype

Parent 896cd396
@@ -449,10 +449,10 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
         # g_cost may be Disconnected or NullType. A creative use of the function,
         # sure, but nonetheless one we can and should support. So before we try
         # to cast it make sure it even has a dtype
-        if hasattr(g_cost, 'dtype') and cost.dtype not in tensor.discrete_dtypes:
+        if hasattr(g_cost.type, 'dtype') and cost.dtype not in tensor.discrete_dtypes:
             # Here we enforce the constraint that floating point variables have
             # the same dtype as their gradient.
-            g_cost = g_cost.astype(cost.dtype)
+            g_cost = g_cost.astype(cost.type.dtype)
         # DO NOT enforce g_cost to be 0 if cost is an integer.
         # This is to be enforced by the Op.grad method for the Op that outputs cost.
         assert g_cost not in tensor.discrete_dtypes
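The fix hinges on an attribute-lookup asymmetry: tensor variables expose a convenience `.dtype` attribute, but scalar variables carry their dtype only on their `.type`, so `hasattr(g_cost, 'dtype')` silently misses them while `g_cost.type.dtype` always works. A minimal sketch of that asymmetry, using hypothetical stand-in classes rather than Theano itself:

```python
# Hypothetical stand-ins (not the real Theano classes) illustrating why
# hasattr(var, 'dtype') can fail for scalars while var.type.dtype is reliable.

class ScalarType:
    def __init__(self, dtype):
        self.dtype = dtype


class ScalarVariable:
    # A scalar variable's dtype lives only on its .type.
    def __init__(self, dtype):
        self.type = ScalarType(dtype)


class TensorVariable(ScalarVariable):
    # A tensor variable additionally forwards .dtype for convenience.
    @property
    def dtype(self):
        return self.type.dtype


t = TensorVariable('float32')
s = ScalarVariable('float64')

assert hasattr(t, 'dtype')        # tensor: convenience attribute exists
assert not hasattr(s, 'dtype')    # scalar: no .dtype attribute on the variable...
assert hasattr(s.type, 'dtype')   # ...but its .type always has one
```

Checking `g_cost.type` (and casting via `cost.type.dtype`) therefore covers both variable kinds with a single code path.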