提交 cccef96c authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: fix comments and remove assertion

上级 a3fc110c
......@@ -3425,29 +3425,26 @@ class BatchedDot(Op):
gz, = grads
xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
# grad is a batch of scalars, so x is a batch of vectors and y is a batch of vectors
# grad is a vector, so x is a matrix and y is a matrix
if gdim == 1:
xgrad = gz.dimshuffle(0, 'x') * y
ygrad = gz.dimshuffle(0, 'x') * x
# x is a batch of vectors, y is a batch of matrices, grad is a batch of vectors
# x is a matrix, y is a tensor3, grad is a matrix
elif xdim == 2 and ydim == 3:
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = x.dimshuffle(0, 1, 'x') * gz.dimshuffle(0, 'x', 1)
# x is a batch of matrices, y is a batch of vectors, grad is a batch of vectors
# x is a tensor3, y is a matrix, grad is a matrix
elif xdim == 3 and ydim == 2:
xgrad = gz.dimshuffle(0, 1, 'x') * y.dimshuffle(0, 'x', 1)
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
# x is a batch of matrices, y is a batch of matrices, grad is a batch of matrices
# x is a tensor3, y is a tensor3, grad is a tensor3
elif xdim == ydim == 3:
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
else:
assert False
# If x or y contain broadcastable dimensions but only one of
# them know that a matching dimensions is broadcastable, the
# above code don't always return the right broadcast pattern.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论