提交 bf67c08e authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added r operator for ScalarFromTensor as well as raising and error if the

output is not differentiable.
上级 e5afb22c
...@@ -1583,6 +1583,10 @@ class ScalarFromTensor(Op): ...@@ -1583,6 +1583,10 @@ class ScalarFromTensor(Op):
s, = inp s, = inp
dt, = grads dt, = grads
return [tensor_from_scalar(dt)] return [tensor_from_scalar(dt)]
def R_op(self, inputs, eval_points):
return None
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
...@@ -5329,6 +5333,9 @@ def Rop(f, wrt, eval_points): ...@@ -5329,6 +5333,9 @@ def Rop(f, wrt, eval_points):
for out in f: for out in f:
if out in wrt: if out in wrt:
rval.append( eval_points[wrt.index(out)]) rval.append( eval_points[wrt.index(out)])
elif seen_nodes[out.owner][out.owner.outputs.index(out)] is None:
raise ValueError(( 'The function is not differentiable with '
'respect to the provided inputs !'))
else: else:
rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)] ) rval.append(seen_nodes[out.owner][out.owner.outputs.index(out)] )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论