提交 fffe6d61 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix assertion

上级 b82aee77
...@@ -458,7 +458,8 @@ def grad(cost, wrt, consider_constant=None, ...@@ -458,7 +458,8 @@ def grad(cost, wrt, consider_constant=None,
g_cost = g_cost.astype(cost.type.dtype) g_cost = g_cost.astype(cost.type.dtype)
# DO NOT enforce g_cost to be 0 if cost is an integer. # DO NOT enforce g_cost to be 0 if cost is an integer.
# This is to be enforced by the Op.grad method for the Op that outputs cost. # This is to be enforced by the Op.grad method for the Op that outputs cost.
assert g_cost not in tensor.discrete_dtypes if hasattr(g_cost.type, 'dtype'):
assert g_cost.type.dtype not in tensor.discrete_dtypes
grad_dict[cost] = g_cost grad_dict[cost] = g_cost
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论