提交 cc8c9e4c authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: fix grad

上级 2202eb41
......@@ -3425,26 +3425,29 @@ class BatchedDot(Op):
gz, = grads
xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
# grad is scalar, so x is vector and y is vector
if gdim == 0:
# grad is a batch of scalars, so x is a batch of vectors and y is a batch of vectors
if gdim == 1:
xgrad = gz.dimshuffle(0, 'x') * y
ygrad = gz.dimshuffle(0, 'x') * x
# x is vector, y is matrix, grad is vector
# x is a batch of vectors, y is a batch of matrices, grad is a batch of vectors
elif xdim == 2 and ydim == 3:
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = x.dimshuffle(0, 1, 'x') * gz.dimshuffle(0, 'x', 1)
# x is matrix, y is vector, grad is vector
# x is a batch of matrices, y is a batch of vectors, grad is a batch of vectors
elif xdim == 3 and ydim == 2:
xgrad = gz.dimshuffle(0, 1, 'x') * y.dimshuffle(0, 'x', 1)
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
# x is matrix, y is matrix, grad is matrix
# x is a batch of matrices, y is a batch of matrices, grad is a batch of matrices
elif xdim == ydim == 3:
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
else:
assert False
# If x or y contain broadcastable dimensions but only one of
# them know that a matching dimensions is broadcastable, the
# above code don't always return the right broadcast pattern.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论