提交 bf67c08e authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added r operator for ScalarFromTensor as well as raising and error if the

output is not differentiable.
上级 e5afb22c
......@@ -1583,6 +1583,10 @@ class ScalarFromTensor(Op):
s, = inp
dt, = grads
return [tensor_from_scalar(dt)]
def R_op(self, inputs, eval_points):
return None
def __str__(self):
return self.__class__.__name__
def c_code(self, node, name, inputs, outputs, sub):
......@@ -5329,6 +5333,9 @@ def Rop(f, wrt, eval_points):
for out in f:
if out in wrt:
rval.append( eval_points[wrt.index(out)])
elif seen_nodes[out.owner][out.owner.outputs.index(out)] is None:
raise ValueError(( 'The function is not differentiable with '
'respect to the provided inputs !'))
else:
rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)] )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论