提交 80acf202 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rename XDot Op to Dot

上级 e9219159
...@@ -145,7 +145,7 @@ def softmax(x, dim=None): ...@@ -145,7 +145,7 @@ def softmax(x, dim=None):
return exp_x / exp_x.sum(dim=dim) return exp_x / exp_x.sum(dim=dim)
class XDot(XOp): class Dot(XOp):
"""Matrix multiplication between two XTensorVariables. """Matrix multiplication between two XTensorVariables.
This operation performs matrix multiplication between two tensors, automatically This operation performs matrix multiplication between two tensors, automatically
...@@ -247,6 +247,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None): ...@@ -247,6 +247,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
if d not in union: if d not in union:
raise ValueError(f"Dimension {d} not found in either input") raise ValueError(f"Dimension {d} not found in either input")
result = XDot(dims=tuple(dim_set))(x, y) result = Dot(dims=tuple(dim_set))(x, y)
return result return result
...@@ -4,12 +4,12 @@ from pytensor.graph import node_rewriter ...@@ -4,12 +4,12 @@ from pytensor.graph import node_rewriter
from pytensor.tensor import einsum from pytensor.tensor import einsum
from pytensor.tensor.shape import specify_shape from pytensor.tensor.shape import specify_shape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.math import XDot from pytensor.xtensor.math import Dot
from pytensor.xtensor.rewriting.utils import register_lower_xtensor from pytensor.xtensor.rewriting.utils import register_lower_xtensor
@register_lower_xtensor @register_lower_xtensor
@node_rewriter(tracks=[XDot]) @node_rewriter(tracks=[Dot])
def lower_dot(fgraph, node): def lower_dot(fgraph, node):
"""Rewrite XDot to tensor.dot. """Rewrite XDot to tensor.dot.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论