提交 9153516b authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

removed corner case, wrote grad for tensors more clearly, added more comments

上级 545d6f5c
...@@ -6943,48 +6943,67 @@ class Dot(Op): ...@@ -6943,48 +6943,67 @@ class Dot(Op):
x, y = inp x, y = inp
gz, = grads gz, = grads
xdim, ydim = x.type.ndim, y.type.ndim xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
#grad is scalar #grad is scalar, so x is scalar or vector and y is same as x
if gz.type.ndim == 0: if gdim == 0:
xgrad = gz * y xgrad = gz * y
ygrad = gz * x ygrad = gz * x
#x is scalar
#x is scalar, y is not scalar, grad.shape == y.shape
elif xdim == 0: elif xdim == 0:
xgrad = (gz * y).sum() xgrad = (gz * y).sum()
ygrad = x * gz ygrad = x * gz
#y is scalar
#x is not scalar, y is scalar, grad.shape == x.shape
elif ydim == 0: elif ydim == 0:
xgrad = y * gz xgrad = y * gz
ygrad = (gz * x).sum() ygrad = (gz * x).sum()
#x is vector, y is matrix
#x is vector, y is matrix, grad is vector
elif xdim == 1 and ydim == 2: elif xdim == 1 and ydim == 2:
xgrad = dot(gz, y.T) xgrad = dot(gz, y.T)
ygrad = outer(x.T, gz) ygrad = outer(x.T, gz)
#x is matrix, y is vector
#x is matrix, y is vector, grad is vector
elif xdim == 2 and ydim == 1: elif xdim == 2 and ydim == 1:
xgrad = outer(gz, y.T) xgrad = outer(gz, y.T)
ygrad = dot(x.T, gz) ygrad = dot(x.T, gz)
#x is matrix, y is matrix
#x is matrix, y is matrix, grad is matrix
elif xdim == ydim == 2: elif xdim == ydim == 2:
xgrad = dot(gz, y.T) xgrad = dot(gz, y.T)
ygrad = dot(x.T, gz) ygrad = dot(x.T, gz)
#x is tensor, y is vector (corner case)
elif xdim > 2 and ydim == 1: # x or y is tensor, grad is tensor
xgrad = tensordot(y, gz, 0).transpose(range(xdim)[1:] + [0]) #
ygrad = tensordot(x, gz, [range(xdim - 1)] * 2) # the grad has the same dim as the dot product output, namely
#x or y is tensor # (x.shape[:-1] + y.shape[:-2] + [y.shape[-1]]). To get the grad
# wrt x or y, a tensordot is used to sum out non-compatible dims.
#
# for grad x:
# gradient is a tensordot over y and grad, summing out all but
# the second-to-last dim of y and transposed such that the first
# resulting dim goes last.
#
# for grad y:
# gradient is a tensordot over x and grad, summing out all but
# the last dim of x and transposed such that the first resulting
# dim goes last.
else: else:
sum0, sum1 = range(xdim), range(xdim - 1) x_axes0 = range(ydim-2)
sum0.pop(-1) x_axes1 = range(xdim - 1, gdim)
dims = range(ydim) x_tdims = range(1, xdim) + [0]
dims[-1:-1] = [dims.pop(0)] if ydim >= 2:
ygrad = tensordot(x, gz, [sum0, sum1]).transpose(dims) x_axes0 += [ydim - 1]
xgrad = tensordot(y, gz, [x_axes0, x_axes1]).transpose(x_tdims)
sum0, sum1 = range(ydim), range(xdim - 1, xdim + ydim - 2)
sum0.pop(-2) y_axes0 = range(xdim - 1)
dims = range(xdim)[1:] + [0] y_axes1 = range(xdim - 1)
xgrad = tensordot(y, gz, [sum0, sum1]).transpose(dims) y_tdims = range(1, ydim - 1) + [0]
if ydim >= 2:
y_tdims += [ydim - 1]
ygrad = tensordot(x, gz, [y_axes0, y_axes1]).transpose(y_tdims)
rval = xgrad, ygrad rval = xgrad, ygrad
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论