提交 532297aa authored 作者: James Bergstra's avatar James Bergstra

type checking on tensor grad

上级 70036906
...@@ -1342,6 +1342,9 @@ def grad(cost, wrt, g_cost=None): ...@@ -1342,6 +1342,9 @@ def grad(cost, wrt, g_cost=None):
kind of zero is returned. kind of zero is returned.
""" """
if not isinstance(cost, TensorResult):
raise TypeError('In tensor.grad(), cost argument should be a TensorResult.', cost)
if g_cost is None: if g_cost is None:
g_cost = ones_like(cost) g_cost = ones_like(cost)
inputs = gof.graph.inputs([cost]) inputs = gof.graph.inputs([cost])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论