提交 0ef4d76e authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

updates to TensorDotGrad and TensorDot

上级 e9c3bf4e
...@@ -6428,7 +6428,6 @@ class TensorDotGrad(Op): ...@@ -6428,7 +6428,6 @@ class TensorDotGrad(Op):
for i in range(node.inputs[1].ndim)] for i in range(node.inputs[1].ndim)]
return [inp0_shp, inp1_shp] return [inp0_shp, inp1_shp]
tensordot_grad = TensorDotGrad tensordot_grad = TensorDotGrad
...@@ -6508,12 +6507,10 @@ class TensorDot(Op): ...@@ -6508,12 +6507,10 @@ class TensorDot(Op):
shape_x, shape_y = in_shapes shape_x, shape_y = in_shapes
out_shape = [] out_shape = []
if isinstance(self.axes, (list, tuple)): if isinstance(self.axes, (list, tuple)):
iter = (i for i in range(len(shape_x)) iter = (i for i in range(len(shape_x)) if i not in self.axes[0])
for j in self.axes[0] if i != j)
for i in iter: for i in iter:
out_shape.append(shape_x[i]) out_shape.append(shape_x[i])
iter = (i for i in range(len(shape_y)) iter = (i for i in range(len(shape_y)) if i not in self.axes[1])
for j in self.axes[1] if i != j)
for i in iter: for i in iter:
out_shape.append(shape_y[i]) out_shape.append(shape_y[i])
else: else:
......
...@@ -6034,22 +6034,162 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6034,22 +6034,162 @@ class TestInferShape(utt.InferShapeTester):
self._compile_and_check([admat, bdmat, gzdmat], self._compile_and_check([admat, bdmat, gzdmat],
tensordot_grad(axes)(admat, bdmat, gzdmat), tensordot_grad(axes)(admat, bdmat, gzdmat),
[admat_val, bdmat_val, gzdmat_val], tensordot_grad) [admat_val, bdmat_val, gzdmat_val], tensordot_grad)
admat_val = rand(5, 4)
bdmat_val = rand(5, 4)
gzdscal = dscalar()
gzdscal_val = rand()
axes = 2
self._compile_and_check([admat, bdmat, gzdscal],
tensordot_grad(axes)(admat, bdmat, gzdscal),
[admat_val, bdmat_val, gzdscal_val], tensordot_grad)
admat_val = rand(4, 5)
bdmat_val = rand(5, 3)
gzdmat_val = rand(4, 3)
axes = ((1, ), (0, )) axes = ((1, ), (0, ))
self._compile_and_check([admat, bdmat, gzdmat], self._compile_and_check([admat, bdmat, gzdmat],
tensordot_grad(axes)(admat, bdmat, gzdmat), tensordot_grad(axes)(admat, bdmat, gzdmat),
[admat_val, bdmat_val, gzdmat_val], tensordot_grad) [admat_val, bdmat_val, gzdmat_val], tensordot_grad)
axes = ((1, 0))
self._compile_and_check([admat, bdmat, gzdmat],
tensordot_grad(axes)(admat, bdmat, gzdmat),
[admat_val, bdmat_val, gzdmat_val], tensordot_grad)
admat_val = rand(4, 5)
bdmat_val = rand(3, 4)
gzdmat_val = rand(5, 3)
axes = ((0, ), (1, ))
self._compile_and_check([admat, bdmat, gzdmat],
tensordot_grad(axes)(admat, bdmat, gzdmat),
[admat_val, bdmat_val, gzdmat_val], tensordot_grad)
gzdscal = dscalar()
admat_val = rand(5, 4)
bdmat_val = rand(5, 4)
gzdscal_val = rand()
axes = ((0, 1), (0, 1))
self._compile_and_check([admat, bdmat, gzdscal],
tensordot_grad(axes)(admat, bdmat, gzdscal),
[admat_val, bdmat_val, gzdscal_val], tensordot_grad)
# Note: The two following tests currently fail and should unlikely be
# for the shape of a grad is very simple and similar to that of the inputs.
# Additional tests involving 3-tensors and 4-tensors will be included once
# this is resolved. Suggestion: first solve issues with 'tensordot' next.
"""
gzdscal = dscalar()
admat_val = rand(5, 4)
bdmat_val = rand(4, 5)
gzdscal_val = rand()
axes = ((0, 1), (1, 0))
self._compile_and_check([admat, bdmat, gzdscal],
tensordot_grad(axes)(admat, bdmat, gzdscal),
[admat_val, bdmat_val, gzdscal_val], tensordot_grad)
gzdscal = dscalar()
admat_val = rand(5, 4)
bdmat_val = rand(5, 4)
gzdscal_val = rand()
axes = ((1, 0 ), (1, 0))
self._compile_and_check([admat, bdmat, gzdscal],
tensordot_grad(axes)(admat, bdmat, gzdscal),
[admat_val, bdmat_val, gzdscal_val], tensordot_grad)
"""
# tensordot # tensordot
admat = dmatrix()
bdmat = dmatrix()
admat_val = rand(4, 5)
bdmat_val = rand(5, 3)
axes = 1 axes = 1
self._compile_and_check([admat, bdmat], self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)], [TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot) [admat_val, bdmat_val], TensorDot)
admat_val = rand(5, 4)
bdmat_val = rand(5, 4)
axes = 2
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(4, 5)
bdmat_val = rand(5, 3)
axes = ((1, ), (0, )) axes = ((1, ), (0, ))
self._compile_and_check([admat, bdmat], self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)], [TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot) [admat_val, bdmat_val], TensorDot)
axes = ((1, 0))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(4, 5)
bdmat_val = rand(3, 4)
axes = ((0, ), (1, ))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(5, 4)
bdmat_val = rand(4, 5)
axes = ((1,), (0,))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(5, 4)
bdmat_val = rand(5, 4)
axes = ((0, 1), (0, 1))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
admat_val = rand(5, 4)
bdmat_val = rand(4, 5)
axes = ((1, 0), (0, 1))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
adtens3 = dtensor3()
admat_val = rand(5, 4)
adtens3_val = rand(5, 4, 3)
axes = 2
self._compile_and_check([admat, adtens3],
[TensorDot(axes)(admat, adtens3)],
[admat_val, adtens3_val], TensorDot)
adtens3_val = rand(4, 5, 3)
axes = ((1, 0), (0, 1))
self._compile_and_check([admat, adtens3],
[TensorDot(axes)(admat, adtens3)],
[admat_val, adtens3_val], TensorDot)
adtens3_val = rand(4, 3, 5)
axes = ((1, 0), (0, 2))
self._compile_and_check([admat, adtens3],
[TensorDot(axes)(admat, adtens3)],
[admat_val, adtens3_val], TensorDot)
adtens4 = dtensor4()
admat_val = rand(5, 4)
adtens4_val = rand(5, 4, 3, 2)
axes = 2
self._compile_and_check([admat, adtens4],
[TensorDot(axes)(admat, adtens4)],
[admat_val, adtens4_val], TensorDot)
adtens4_val = rand(4, 3, 2, 5)
axes = ((1, 0), (0, 3))
self._compile_and_check([admat, adtens4],
[TensorDot(axes)(admat, adtens4)],
[admat_val, adtens4_val], TensorDot)
# Flatten # Flatten
adtens = tensor3() adtens = tensor3()
adtens_val = rand(4, 5, 3) adtens_val = rand(4, 5, 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论