提交 89d0523c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix vectorize_node function name

上级 e1809275
...@@ -2948,9 +2948,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None ...@@ -2948,9 +2948,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
@_vectorize_node.register(Dot) @_vectorize_node.register(Dot)
def vectorize_node_to_matmul(op, node, batched_x, batched_y): def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
old_x, old_y = node.inputs old_x, old_y = node.inputs
if old_x.type.ndim == 2 and old_y.type.ndim == 2: if old_x.type.ndim == 2 and old_y.type.ndim == 2:
# If original input is equivalent to a matrix-matrix product,
# return specialized Matmul Op to avoid unnecessary new Ops.
return matmul(batched_x, batched_y).owner return matmul(batched_x, batched_y).owner
else: else:
return vectorize_node_fallback(op, node, batched_x, batched_y) return vectorize_node_fallback(op, node, batched_x, batched_y)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论