提交 1430e999 authored 作者: Bart van Merriënboer's avatar Bart van Merriënboer

Merge pull request #2602 from lamblin/fix_bart

Fix the broadcastable pattern of tensordot's output
......@@ -4997,6 +4997,7 @@ def tensordot(a, b, axes=2):
'of b (b.ndim=%i, axes=%i)' % (b.ndim, axes))
outshape = concatenate([a.shape[:a.ndim - axes], b.shape[axes:]])
outbcast = a.broadcastable[:a.ndim - axes] + b.broadcastable[axes:]
outndim = a.ndim + b.ndim - (2 * axes)
a_shape_0 = b_shape_0 = a_shape_1 = b_shape_1 = 1
......@@ -5012,7 +5013,10 @@ def tensordot(a, b, axes=2):
a_reshaped = a.reshape((a_shape_0, a_shape_1), ndim=2)
b_reshaped = b.reshape((b_shape_0, b_shape_1), ndim=2)
return _dot(a_reshaped, b_reshaped).reshape(outshape, outndim)
out = _dot(a_reshaped, b_reshaped).reshape(outshape, outndim)
# Make sure the broadcastable pattern of the result is correct,
# since some shape information can be lost in the reshapes.
return patternbroadcast(out, outbcast)
# if 'axes' is a list, transpose a and b such that the summed axes of a
# are last and the summed axes of b are first.
......
......@@ -5529,6 +5529,29 @@ class test_tensordot(unittest.TestCase):
f3(aval, bval)))
utt.verify_grad(self.TensorDot(axes), [aval, bval])
def test_broadcastable1(self):
x = TensorType(dtype=floatX, broadcastable=(True, False, False))('x')
y = tensor3('y')
z = tensordot(x, y)
assert z.broadcastable == (True, False)
f = inplace_func([x, y], z)
xv = rand(1, 3, 4)
yv = rand(3, 4, 5)
zv = f(xv, yv)
self.assertTrue(numpy.allclose(numpy.tensordot(xv, yv), zv))
def test_broadcastable2(self):
x = TensorType(dtype=floatX, broadcastable=(True, False, False))('x')
y = tensor3('y')
axes = [[2, 1], [0, 1]]
z = tensordot(x, y, axes=axes)
assert z.broadcastable == (True, False)
f = inplace_func([x, y], z)
xv = rand(1, 3, 4)
yv = rand(4, 3, 5)
zv = f(xv, yv)
self.assertTrue(numpy.allclose(numpy.tensordot(xv, yv, axes=axes), zv))
def test_smallest_stack():
sx, sy = dscalar(), dscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论