提交 45c53b93 authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op TensorDot

上级 1ed4181b
......@@ -6464,6 +6464,23 @@ class TensorDot(Op):
e.args = e.args + (x.shape, y.shape, self.axes)
raise
def infer_shape(self, node, in_shapes):
shape_x, shape_y = in_shapes
out_shape = []
if isinstance(self.axes, (list, tuple)):
iter = (i for i in range(len(shape_x))
for j in self.axes[0] if i != j)
for i in iter:
out_shape.append(shape_x[i])
iter = (i for i in range(len(shape_y))
for j in self.axes[1] if i != j)
for i in iter:
out_shape.append(shape_y[i])
else:
out_shape = list(shape_x)[shape_x.ndim - self.axes] + \
list(shape_y)[shape_y.ndim - self.axes, shape_y.ndim]
return [out_shape]
def grad(self, inp, grads):
x, y = inp
gz, = grads
......
......@@ -6036,9 +6036,20 @@ class TestInferShape(utt.InferShapeTester):
tensordot_grad(axes)(admat, bdmat, gzdmat),
[admat_val, bdmat_val, gzdmat_val], tensordot_grad)
# tensordot
axes = 1
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
axes = ((1, ), (0, ))
self._compile_and_check([admat, bdmat],
[TensorDot(axes)(admat, bdmat)],
[admat_val, bdmat_val], TensorDot)
if __name__ == '__main__':
t = TestInferShape('setUp')
t.setUp()
t.test_infer_shape()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论