提交 b438341d authored 作者: David Warde-Farley's avatar David Warde-Farley

PEP8/readability.

上级 19f7768f
...@@ -5339,19 +5339,18 @@ class TensorDotGrad(Op): ...@@ -5339,19 +5339,18 @@ class TensorDotGrad(Op):
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, y, gz = inp x, y, gz = inp
gx, gy = out gx, gy = out
sum_over_y = range(y.ndim) sum_over_y = range(y.ndim)
[sum_over_y.remove(q) for q in self.axes[1]] [sum_over_y.remove(q) for q in self.axes[1]]
sum_over_x = range(x.ndim) sum_over_x = range(x.ndim)
[sum_over_x.remove(q) for q in self.axes[0]] [sum_over_x.remove(q) for q in self.axes[0]]
tdot_axes = [range(x.ndim - len(self.axes[0]), gz.ndim), sum_over_y]
_gx = numpy.tensordot(gz, y, [range(x.ndim-len(self.axes[0]),gz.ndim), sum_over_y]) _gx = numpy.tensordot(gz, y, tdot_axes)
idx = numpy.hstack((sum_over_x, self.axes[0])) idx = numpy.hstack((sum_over_x, self.axes[0]))
newshapex = numpy.zeros(x.ndim) newshapex = numpy.zeros(x.ndim)
newshapex[[newpos for newpos in idx]] = [i for i in xrange(x.ndim)] newshapex[[newpos for newpos in idx]] = [i for i in xrange(x.ndim)]
gx[0] = numpy.transpose(_gx, newshapex) gx[0] = numpy.transpose(_gx, newshapex)
tdot_axes = [sum_over_x, range(x.ndim - len(self.axes[0]))]
_gy = numpy.tensordot(x, gz, [sum_over_x, range(x.ndim-len(self.axes[0]))]) _gy = numpy.tensordot(x, gz, tdot_axes)
idy = numpy.hstack((self.axes[1], sum_over_y)) idy = numpy.hstack((self.axes[1], sum_over_y))
newshapey = numpy.zeros(y.ndim) newshapey = numpy.zeros(y.ndim)
newshapey[[newpos for newpos in idy]] = [i for i in xrange(y.ndim)] newshapey[[newpos for newpos in idy]] = [i for i in xrange(y.ndim)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论