提交 1ed4181b authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: Op TensorDotGrad

上级 6e9e6d6b
......@@ -6381,6 +6381,14 @@ class TensorDotGrad(Op):
newshapey[[newpos for newpos in idy]] = range(y.ndim)
gy[0] = numpy.transpose(_gy, newshapey)
def infer_shape(self, node, in_shapes):
inp0_shp = [node.inputs[0].shape[i]
for i in range(node.inputs[0].ndim)]
inp1_shp = [node.inputs[1].shape[i]
for i in range(node.inputs[1].ndim)]
return [inp0_shp, inp1_shp]
tensordot_grad = TensorDotGrad
......
......@@ -6016,6 +6016,40 @@ def test_transpose():
assert tensor.transpose(tensor.dmatrix()).name is None
class TestInferShape(utt.InferShapeTester):
def test_infer_shape(self):
# tensordot_grad
admat = dmatrix()
bdmat = dmatrix()
gzdmat = dmatrix()
admat_val = rand(4, 5)
bdmat_val = rand(5, 3)
gzdmat_val = rand(4, 3)
axes = 1
self._compile_and_check([admat, bdmat, gzdmat],
tensordot_grad(axes)(admat, bdmat, gzdmat),
[admat_val, bdmat_val, gzdmat_val], tensordot_grad)
axes = ((1, ), (0, ))
self._compile_and_check([admat, bdmat, gzdmat],
tensordot_grad(axes)(admat, bdmat, gzdmat),
[admat_val, bdmat_val, gzdmat_val], tensordot_grad)
if __name__ == '__main__':
t = TestInferShape('setUp')
t.setUp()
t.test_infer_shape()
"""
if __name__ == '__main__':
if 0:
unittest.main()
......@@ -6025,3 +6059,4 @@ if __name__ == '__main__':
suite = unittest.TestLoader()
suite = suite.loadTestsFromTestCase(testcase)
unittest.TextTestRunner(verbosity=2).run(suite)
"""
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论