提交 9af28c1c authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

update to infer_shape in PermuteRowElements

上级 775fd596
......@@ -5816,14 +5816,10 @@ class PermuteRowElements(Op):
def infer_shape(self, node, in_shapes):
shp_x = in_shapes[0]
shp_y = in_shapes[1]
if len(shp_x) > len(shp_y):
out_shape = shp_x
elif len(shp_x) < len(shp_y):
out_shape = shp_y
else:
out_shape = []
for i in range(len(shp_x)):
out_shape.append(maximum(shp_x[i], shp_y[i]))
assert len(shp_x) == len(shp_y)
out_shape = []
for i in range(len(shp_x)):
out_shape.append(maximum(shp_x[i], shp_y[i]))
return [out_shape]
def grad(self, inp, grads):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论