Commit a082d96d authored by Ian Goodfellow

more to scalar dtype fix

Parent 481b6632
@@ -449,7 +449,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
     # g_cost may be Disconnected or NullType. A creative use of the function,
     # sure, but nonetheless one we can and should support. So before we try
     # to cast it make sure it even has a dtype
-    if hasattr(g_cost.type, 'dtype') and cost.dtype not in tensor.discrete_dtypes:
+    if hasattr(g_cost.type, 'dtype') and cost.type.dtype not in tensor.discrete_dtypes:
         # Here we enforce the constraint that floating point variables have
         # the same dtype as their gradient.
         g_cost = g_cost.astype(cost.type.dtype)
...
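The guard above casts the supplied output gradient to the cost's own dtype, but only when the gradient actually has a dtype (it may be a Disconnected or NullType placeholder) and the cost is a floating-point variable. A minimal standalone sketch of that logic, using NumPy arrays as stand-ins for Theano variables (`align_grad_dtype` is a hypothetical helper, not part of Theano):

```python
import numpy as np

# Stand-in for tensor.discrete_dtypes: integer dtypes are exempt from the
# "gradient matches cost dtype" constraint.
DISCRETE_DTYPES = {'int8', 'int16', 'int32', 'int64',
                   'uint8', 'uint16', 'uint32', 'uint64'}

def align_grad_dtype(cost, g_cost):
    # Mirror of the guard in the diff: skip objects without a dtype
    # (e.g. a Disconnected/NullType-style placeholder) and skip
    # discrete-valued costs; otherwise force the gradient to use the
    # same floating-point dtype as the cost.
    if hasattr(g_cost, 'dtype') and str(cost.dtype) not in DISCRETE_DTYPES:
        return g_cost.astype(cost.dtype)
    return g_cost

# A float32 cost with a float64 gradient gets downcast to float32.
cost = np.array(1.0, dtype='float32')
g_cost = np.array(2.0, dtype='float64')
aligned = align_grad_dtype(cost, g_cost)
print(aligned.dtype)  # float32

# An integer cost leaves the gradient untouched.
int_cost = np.array(3, dtype='int64')
print(align_grad_dtype(int_cost, g_cost).dtype)  # float64
```

Note the fix itself: the original code read `cost.dtype`, while the check on the other operand used `g_cost.type`; the commit makes the access consistent by going through `cost.type.dtype`.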