提交 f7101c41 authored 作者: briancheung's avatar briancheung

Added small gradient calculation for batched_dot

上级 85159185
......@@ -199,6 +199,20 @@ class BatchedDotOp(GpuOp):
} while (0)
"""
def grad(self, inp, grads):
x, y = inp
gz, = grads
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
rval = xgrad, ygrad
for elem in rval:
assert elem.dtype.find('float') != -1
return rval
batched_dot = BatchedDotOp()
class GpuDot22(GpuOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论