提交 99fa5f39 authored 作者: Frederic's avatar Frederic

small code simplification and reuse the provided shapes.

上级 b731111f
...@@ -6504,11 +6504,7 @@ class TensorDotGrad(Op): ...@@ -6504,11 +6504,7 @@ class TensorDotGrad(Op):
assert gx[0].shape == x.shape assert gx[0].shape == x.shape
def infer_shape(self, node, in_shapes): def infer_shape(self, node, in_shapes):
inp0_shp = [node.inputs[0].shape[i] return in_shapes[:2]
for i in range(node.inputs[0].ndim)]
inp1_shp = [node.inputs[1].shape[i]
for i in range(node.inputs[1].ndim)]
return [inp0_shp, inp1_shp]
tensordot_grad = TensorDotGrad tensordot_grad = TensorDotGrad
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论