提交 f84bb240 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add error message for incompatible static shape in Dot Op

上级 660916d6
...@@ -3025,6 +3025,11 @@ class Dot(Op): ...@@ -3025,6 +3025,11 @@ class Dot(Op):
) )
sx, sy = (input.type.shape for input in inputs) sx, sy = (input.type.shape for input in inputs)
if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
raise ValueError(
f"Incompatible shared dimension for dot product: {sx}, {sy}"
)
if len(sy) == 2: if len(sy) == 2:
sz = sx[:-1] + sy[-1:] sz = sx[:-1] + sy[-1:]
elif len(sy) == 1: elif len(sy) == 1:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论