提交 8476ebe8 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

R operator of Dot Op.

上级 a0048ab1
...@@ -5102,6 +5102,16 @@ class Dot(Op): ...@@ -5102,6 +5102,16 @@ class Dot(Op):
rval = dot(gz, y.T), dot(x.T, gz) rval = dot(gz, y.T), dot(x.T, gz)
return cast(rval[0], x.dtype), cast(rval[1], y.dtype) return cast(rval[0], x.dtype), cast(rval[1], y.dtype)
def R_op(self, inputs, eval_points):
# R_op for a \dot b evaluted at c for a and d for b is
# simply c \dot b + a \dot d
if None in eval_points:
return [None]
t1 = self.make_node(eval_points[0], inputs[1]).outputs[0]
t2 = self.make_node(inputs[0], eval_points[1]).outputs[0]
return [t1+t2]
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
xshp, yshp = shapes xshp, yshp = shapes
x, y = node.inputs x, y = node.inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论