Commit 95e8b482, authored by Jeremiah Lowin

reversed tensordot arguments to remove transpose for xgrad. updated comments.

ygrad still requires a transpose.
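
To illustrate why swapping the tensordot operands removes the transpose for xgrad, here is a minimal NumPy sketch. It assumes `np.tensordot` follows the same axis conventions as the Theano `tensordot` used in this commit; the shapes and the random test data are made up for illustration, and `range` is wrapped in `list` because the sketch targets Python 3.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))   # x.shape[-1] must match y.shape[-2]
y = rng.standard_normal((5, 4, 6))
# The output grad has shape x.shape[:-1] + y.shape[:-2] + y.shape[-1:].
gz = rng.standard_normal(x.shape[:-1] + y.shape[:-2] + y.shape[-1:])

xdim, ydim, gdim = x.ndim, y.ndim, gz.ndim

# Old order: tensordot(y, gz) puts y's surviving axis first,
# so it has to be moved to the end with a transpose.
x_axes0 = list(range(ydim - 2)) + [ydim - 1]
x_axes1 = list(range(xdim - 1, gdim))
x_tdims = list(range(1, xdim)) + [0]
xgrad_old = np.tensordot(y, gz, [x_axes0, x_axes1]).transpose(x_tdims)

# New order: tensordot(gz, y) keeps gz's leading axes (those of x) first
# and y's surviving axis last, which is already the layout of x.
gy_axes = list(range(xdim - 1, gdim))
y_axes = [ax for ax in range(ydim) if ax != ydim - 2]
xgrad_new = np.tensordot(gz, y, [gy_axes, y_axes])

print(xgrad_old.shape, xgrad_new.shape)
print(np.allclose(xgrad_old, xgrad_new))
```

The same trick does not help ygrad: whichever operand comes first, the axis contributed by x has to end up second-to-last, between the axes taken from gz.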
Parent 4beb4dea
......
@@ -6977,34 +6977,33 @@ class Dot(Op):
  # x or y is tensor, grad is tensor
  #
- # the grad has the same dim as the dot product output, namely
- # (x.shape[:-1] + y.shape[:-2] + [y.shape[-1]]). To get the grad
+ # the output grad has the same dim as the dot product output, namely
+ # x.shape[:-1] + y.shape[:-2] + y.shape[-1:]. To get the grad
  # wrt x or y, a tensordot is used to sum out non-compatible dims.
  #
- # for grad x:
- # gradient is a tensordot over y and grad, summing out all but
- # the second-to-last dim of y (if it exists) and transposed
- # such that the first resulting dim goes last.
+ # for grad wrt x:
+ # grad is a tensordot of the output grad and y, summing out all but
+ # the second-to-last dim of y. If y is a vector, no sum is taken.
  #
- # for grad y:
- # gradient is a tensordot over x and grad, summing out all but
- # the last dim of x and transposed such that the first
- # resulting dim goes second-to-last (unless ydim == 1, when
- # it goes last).
+ # for grad wrt y:
+ # grad is a tensordot of the output grad and x, summing out
+ # all but the last dim of x. If y is not a vector, the tensordot is
+ # transposed so that its last dim becomes its second-to-last.
  else:
-     x_axes0 = range(ydim-2)
-     x_axes1 = range(xdim - 1, gdim)
-     x_tdims = range(1, xdim) + [0]
-     if ydim >= 2:
-         x_axes0 += [ydim - 1]
-     xgrad = tensordot(y, gz, [x_axes0, x_axes1]).transpose(x_tdims)
-     y_axes0 = range(xdim - 1)
-     y_axes1 = range(xdim - 1)
-     y_tdims = range(1, ydim - 1) + [0]
-     if ydim >= 2:
-         y_tdims += [ydim - 1]
-     ygrad = tensordot(x, gz, [y_axes0, y_axes1]).transpose(y_tdims)
+     gy_axes = range(xdim - 1, gdim)
+     if ydim != 1:
+         y_axes = [ax for ax in range(ydim) if ax != ydim - 2]
+     else:
+         y_axes = []
+     xgrad = tensordot(gz, y, [gy_axes, y_axes])
+     gx_axes = y_axes = range(xdim - 1)
+     if ydim != 1:
+         t_dims = range(ydim)
+         t_dims[-2], t_dims[-1] = t_dims[-1], t_dims[-2]
+         ygrad = tensordot(gz, x, [gx_axes, y_axes]).transpose(t_dims)
+     else:
+         ygrad = tensordot(gz, x, [gx_axes, y_axes])
  rval = xgrad, ygrad
......
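
As a sanity check on the axis bookkeeping in the added branch, the following NumPy sketch mirrors it and compares the result against central finite differences. It is not Theano code: it assumes `np.dot` and `np.tensordot` behave like the Theano ops in the diff, and the names `dot_grads`, `numeric_grad`, and `check` are hypothetical helpers introduced only for this example.

```python
import numpy as np


def dot_grads(x, y, gz):
    """NumPy mirror of the added branch (Python 3: ranges made into lists)."""
    xdim, ydim, gdim = x.ndim, y.ndim, gz.ndim

    # Gradient wrt x: contract gz's trailing (y-derived) axes with every
    # axis of y except the second-to-last. A vector y contracts nothing.
    gy_axes = list(range(xdim - 1, gdim))
    if ydim != 1:
        y_axes = [ax for ax in range(ydim) if ax != ydim - 2]
    else:
        y_axes = []
    xgrad = np.tensordot(gz, y, [gy_axes, y_axes])

    # Gradient wrt y: contract the leading axes that gz shares with x,
    # then swap the last two axes unless y is a vector.
    gx_axes = x_axes = list(range(xdim - 1))
    ygrad = np.tensordot(gz, x, [gx_axes, x_axes])
    if ydim != 1:
        t_dims = list(range(ydim))
        t_dims[-2], t_dims[-1] = t_dims[-1], t_dims[-2]
        ygrad = ygrad.transpose(t_dims)
    return xgrad, ygrad


def numeric_grad(f, a, eps=1e-6):
    """Central-difference gradient of the scalar f() wrt array a (in place)."""
    g = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + eps
        hi = f()
        a[idx] = old - eps
        lo = f()
        a[idx] = old
        g[idx] = (hi - lo) / (2 * eps)
    return g


def check(x_shape, y_shape, seed=0):
    """Compare dot_grads with finite differences of sum(dot(x, y) * gz)."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(x_shape)
    y = rng.standard_normal(y_shape)
    gz = rng.standard_normal(np.dot(x, y).shape)
    loss = lambda: np.sum(np.dot(x, y) * gz)
    xgrad, ygrad = dot_grads(x, y, gz)
    return (np.allclose(xgrad, numeric_grad(loss, x), atol=1e-5)
            and np.allclose(ygrad, numeric_grad(loss, y), atol=1e-5))


for shapes in [((2, 3, 4), (5, 4, 6)), ((3, 4), (4,)), ((4,), (5, 4, 6))]:
    print(shapes, check(*shapes))
```

The three shape pairs cover the tensor-tensor case, a vector y (no transpose on ygrad), and a vector x (empty contraction for ygrad).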