Commit 95e8b482 authored by Jeremiah Lowin

reversed tensordot arguments to remove transpose for xgrad. updated comments.

ygrad still requires a transpose.
Parent 4beb4dea
@@ -6977,34 +6977,33 @@ class Dot(Op):
 # x or y is tensor, grad is tensor
 #
-# the grad has the same dim as the dot product output, namely
-# (x.shape[:-1] + y.shape[:-2] + [y.shape[-1]]). To get the grad
+# the output grad has the same dim as the dot product output, namely
+# x.shape[:-1] + y.shape[:-2] + y.shape[-1:]. To get the grad
 # wrt x or y, a tensordot is used to sum out non-compatible dims.
 #
-# for grad x:
-# gradient is a tensordot over y and grad, summing out all but
-# the second-to-last dim of y (if it exists) and transposed
-# such that the first resulting dim goes last.
+# for grad wrt x:
+# grad is a tensordot of the output grad and y, summing out all but
+# the second-to-last dim of y. If y is a vector, no sum is taken.
 #
-# for grad y:
-# gradient is a tensordot over x and grad, summing out all but
-# the last dim of x and transposed such that the first
-# resulting dim goes second-to-last (unless ydim == 1, when
-# it goes last).
+# for grad wrt y:
+# grad is a tensordot of the output grad and x, summing out
+# all but the last dim of x. If y is not a vector, the tensordot is
+# transposed so that its last dim becomes its second-to-last.
 else:
-    x_axes0 = range(ydim-2)
-    x_axes1 = range(xdim - 1, gdim)
-    x_tdims = range(1, xdim) + [0]
-    if ydim >= 2:
-        x_axes0 += [ydim - 1]
-    xgrad = tensordot(y, gz, [x_axes0, x_axes1]).transpose(x_tdims)
+    gy_axes = range(xdim - 1, gdim)
+    if ydim != 1:
+        y_axes = [ax for ax in range(ydim) if ax != ydim - 2]
+    else:
+        y_axes = []
+    xgrad = tensordot(gz, y, [gy_axes, y_axes])
-    y_axes0 = range(xdim - 1)
-    y_axes1 = range(xdim - 1)
-    y_tdims = range(1, ydim - 1) + [0]
-    if ydim >= 2:
-        y_tdims += [ydim - 1]
-    ygrad = tensordot(x, gz, [y_axes0, y_axes1]).transpose(y_tdims)
+    gx_axes = y_axes = range(xdim - 1)
+    if ydim != 1:
+        t_dims = range(ydim)
+        t_dims[-2], t_dims[-1] = t_dims[-1], t_dims[-2]
+        ygrad = tensordot(gz, x, [gx_axes, y_axes]).transpose(t_dims)
+    else:
+        ygrad = tensordot(gz, x, [gx_axes, y_axes])
     rval = xgrad, ygrad
......
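The axis bookkeeping in the new code can be checked without any tensor library by tracking shapes only. The sketch below is not part of the commit: `tensordot_shape` and `dot_grad_shapes` are hypothetical helpers written for illustration, and they assume the usual tensordot convention that the result keeps the unsummed dims of the first argument followed by those of the second. It also assumes that when y is a vector the output gradient has shape x.shape[:-1], since the vector's only dim is summed by the dot product. It is written for Python 3, so `range` is wrapped in `list`.

```python
def tensordot_shape(a_shape, b_shape, a_axes, b_axes):
    """Shape of tensordot(a, b, [a_axes, b_axes]): kept dims of a, then kept dims of b."""
    a_axes, b_axes = list(a_axes), list(b_axes)
    # Each pair of summed dims must have the same length.
    for i, j in zip(a_axes, b_axes):
        assert a_shape[i] == b_shape[j], "summed dims must match"
    kept_a = [d for i, d in enumerate(a_shape) if i not in a_axes]
    kept_b = [d for j, d in enumerate(b_shape) if j not in b_axes]
    return tuple(kept_a + kept_b)


def dot_grad_shapes(x_shape, y_shape):
    """Return (xgrad_shape, ygrad_shape) using the axis choices from the new code."""
    xdim, ydim = len(x_shape), len(y_shape)

    # Shape of the output gradient gz (assumption for vector y: its dim is summed out).
    if ydim == 1:
        gz_shape = tuple(x_shape[:-1])
    else:
        gz_shape = tuple(x_shape[:-1]) + tuple(y_shape[:-2]) + tuple(y_shape[-1:])
    gdim = len(gz_shape)

    # grad wrt x: sum gz's trailing dims against every dim of y except its second-to-last.
    gy_axes = list(range(xdim - 1, gdim))
    y_axes = [ax for ax in range(ydim) if ax != ydim - 2] if ydim != 1 else []
    xgrad_shape = tensordot_shape(gz_shape, y_shape, gy_axes, y_axes)

    # grad wrt y: sum gz's leading dims against every dim of x except its last.
    gx_axes = x_axes = list(range(xdim - 1))
    ygrad_shape = tensordot_shape(gz_shape, x_shape, gx_axes, x_axes)
    if ydim != 1:
        # Swap the last two dims, as the transpose in the new code does.
        t_dims = list(range(ydim))
        t_dims[-2], t_dims[-1] = t_dims[-1], t_dims[-2]
        ygrad_shape = tuple(ygrad_shape[i] for i in t_dims)
    return xgrad_shape, ygrad_shape


if __name__ == "__main__":
    print(dot_grad_shapes((2, 3, 4), (5, 4, 6)))
    print(dot_grad_shapes((2, 3, 4), (4,)))
```

In each case the gradient shapes should equal the shapes of x and y themselves, which is the property the xgrad and ygrad expressions need to satisfy. Only ygrad needs the final swap of its last two dims, matching the commit message.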