提交 80acf202 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rename XDot Op to Dot

上级 e9219159
......@@ -145,7 +145,7 @@ def softmax(x, dim=None):
return exp_x / exp_x.sum(dim=dim)
class XDot(XOp):
class Dot(XOp):
"""Matrix multiplication between two XTensorVariables.
This operation performs matrix multiplication between two tensors, automatically
......@@ -247,6 +247,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
if d not in union:
raise ValueError(f"Dimension {d} not found in either input")
result = XDot(dims=tuple(dim_set))(x, y)
result = Dot(dims=tuple(dim_set))(x, y)
return result
......@@ -4,12 +4,12 @@ from pytensor.graph import node_rewriter
from pytensor.tensor import einsum
from pytensor.tensor.shape import specify_shape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.math import XDot
from pytensor.xtensor.math import Dot
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
@register_lower_xtensor
@node_rewriter(tracks=[XDot])
@node_rewriter(tracks=[Dot])
def lower_dot(fgraph, node):
"""Rewrite XDot to tensor.dot.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论