提交 99fa5f39 authored 作者: Frederic's avatar Frederic

small code simplification and reuse the provided shapes.

上级 b731111f
......@@ -6504,11 +6504,7 @@ class TensorDotGrad(Op):
assert gx[0].shape == x.shape
def infer_shape(self, node, in_shapes):
inp0_shp = [node.inputs[0].shape[i]
for i in range(node.inputs[0].ndim)]
inp1_shp = [node.inputs[1].shape[i]
for i in range(node.inputs[1].ndim)]
return [inp0_shp, inp1_shp]
return in_shapes[:2]
tensordot_grad = TensorDotGrad
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论