提交 b731111f authored 作者: Frederic's avatar Frederic

Make TensorDotGrad raise an NotImplementedError for the case it return wrong shape.

A reshape fix this problem, but return the wrong results.
上级 0a532cbd
......@@ -6453,6 +6453,13 @@ pprint.assign(dot, printing.OperatorPrinter(printing.special['middle_dot'],
class TensorDotGrad(Op):
def __init__(self, axes):
self.axes = TensorDot.parse_axes(axes)
if isinstance(self.axes, (tuple, list)) and len(self.axes) == 2:
# The current perform don't implement correctly those cases
for i in range(len(self.axes[0]) - 1):
if self.axes[0][i] > self.axes[0][i + 1]:
raise NotImplementedError()
if self.axes[1][i] > self.axes[1][i + 1]:
raise NotImplementedError()
def __eq__(self, other):
return type(self) == type(other) and self.axes == other.axes
......@@ -6493,6 +6500,8 @@ class TensorDotGrad(Op):
newshapey = numpy.zeros(y.ndim)
newshapey[[newpos for newpos in idy]] = range(y.ndim)
gy[0] = numpy.transpose(_gy, newshapey)
assert gy[0].shape == y.shape
assert gx[0].shape == x.shape
def infer_shape(self, node, in_shapes):
inp0_shp = [node.inputs[0].shape[i]
......
......@@ -5124,25 +5124,40 @@ class test_tensordot(unittest.TestCase):
# Test matrix-matrix
amat = matrix()
axes = ((1,),(0,))
c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat,bmat],c)
aval = rand(4,7)
bval = rand(7,9)
self.assertTrue(numpy.allclose(numpy.tensordot(aval,bval,axes),
f3(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval])
for axes, shps in [[((0,), (0,)), [(4, 7), (4, 9)]],
[((0,), (1,)), [(4, 7), (9, 4)]],
[((1,), (0,)), [(4, 7), (7, 9)]],
[((1,), (1,)), [(4, 7), (9, 7)]],
[((0, 1), (0, 1)), [(4, 7), (4, 7)]],
# [((0, 1), (1, 0)), [(4, 7), (7, 4)]],
# [((1, 0), (1, 0)), [(4, 7), (4, 7)]],
# [((1, 0), (0, 1)), [(4, 7), (7, 4)]],
]:
c = tensordot(amat, bmat, axes)
f3 = inplace_func([amat, bmat], c)
aval = rand(*shps[0])
bval = rand(*shps[1])
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f3(aval, bval)))
utt.verify_grad(TensorDot(axes), [aval, bval])
# Test ndarray-matrix, sum over one dim of matrix
atens = tensor4()
axes = ((2,),(1,))
c = tensordot(atens, bmat, axes)
f4 = inplace_func([atens,bmat],c)
aval = rand(1,2,3,4)
bval = rand(2,3)
self.assertTrue(numpy.allclose(numpy.tensordot(aval,bval,axes),
f4(aval,bval)))
utt.verify_grad(TensorDot(axes), [aval,bval])
for axes, shps in [[((2,), (1,)), [(1, 2, 3, 4), (2, 3)]],
[((0,), (1,)), [(1, 2, 3, 4), (3, 1)]],
[((0,), (0,)), [(1, 2, 3, 4), (1, 3)]],
[((3,), (0,)), [(1, 2, 3, 4), (4, 1)]],
# [((3, 1), (0, 1)), [(1, 2, 3, 4), (4, 2)]],
# [((0, 1), (1, 0)), [(1, 2, 3, 4), (2, 1)]],
# [((3, 1), (1, 0)), [(1, 2, 3, 4), (2, 4)]],
]:
atens = tensor4()
c = tensordot(atens, bmat, axes)
f4 = inplace_func([atens, bmat], c)
aval = rand(*shps[0])
bval = rand(*shps[1])
self.assertTrue(numpy.allclose(numpy.tensordot(aval, bval, axes),
f4(aval, bval)))
utt.verify_grad(TensorDot(axes), [aval, bval])
# Test ndarray-ndarray
atens = tensor4()
......@@ -6074,10 +6089,7 @@ class TestInferShape(utt.InferShapeTester):
tensordot_grad(axes)(admat, bdmat, gzdscal),
[admat_val, bdmat_val, gzdscal_val], tensordot_grad)
# Note: The two following tests currently fail and should unlikely be,
# for the shape of a grad is very simple and similar to that of the inputs.
# Additional tests involving 3-tensors and 4-tensors will be included once
# this is resolved.
# tensordot_grad currently do not support not ordered axes
"""
gzdscal = dscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论