提交 95701b81 authored 作者: Tanjay94's avatar Tanjay94

Fixed infer_shape method.

上级 d16a6c4d
......@@ -478,7 +478,21 @@ Op Example
def infer_shape(self, node, shapes):
xshp, yshp = shapes
return [xshp[:-1] + yshp[-1:]]
x, y = node.inputs
# vector / vector
if x.ndim == 1 and y.ndim == 1:
return [()]
# matrix / vector
if x.ndim == 2 and y.ndim == 1:
return [xshp[:-1]]
# vector / matrix
if x.ndim == 1 and y.ndim == 2:
return [yshp[-1:]]
# matrix / matrix
if x.ndim == 2 and y.ndim == 2:
return [xshp[:-1] + yshp[-1:]]
raise NotImplementedError()
You can try it as follows:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论