Unverified 提交 6ad1c5cf authored 作者: Pham Nguyen Hung's avatar Pham Nguyen Hung 提交者: GitHub

Implement Dot and BatchedDot in PyTensor (#878)

上级 426931b0
from pytensor.link.pytorch.linker import PytorchLinker
......@@ -2,9 +2,12 @@
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
# # Load dispatch specializations
import pytensor.link.pytorch.dispatch.blas
import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.math
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.sort
# isort: on
import torch
from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.blas import BatchedDot
@pytorch_funcify.register(BatchedDot)
def pytorch_funcify_BatchedDot(op, **kwargs):
def batched_dot(a, b):
if a.shape[0] != b.shape[0]:
raise TypeError("Shapes must match in the 0-th dimension")
return torch.bmm(a, b)
return batched_dot
import torch
from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.math import Dot
@pytorch_funcify.register(Dot)
def pytorch_funcify_Dot(op, **kwargs):
def dot(x, y):
return torch.matmul(x, y)
return dot
import numpy as np
import pytest
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import blas as pt_blas
from pytensor.tensor.type import tensor3
from tests.link.pytorch.test_basic import compare_pytorch_and_py
def test_pytorch_BatchedDot():
# tensor3 . tensor3
a = tensor3("a")
a_test = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
b = tensor3("b")
b_test = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
out = pt_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out])
pytensor_pytorch_fn, _ = compare_pytorch_and_py(fgraph, [a_test, b_test])
# A dimension mismatch should raise a TypeError for compatibility
inputs = [a_test[:-1], b_test]
with pytest.raises(TypeError):
pytensor_pytorch_fn(*inputs)
import numpy as np
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.type import matrix, scalar, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
def test_pytorch_dot():
y = vector("y")
y_test = np.r_[1.0, 2.0].astype(config.floatX)
x = vector("x")
x_test = np.r_[3.0, 4.0].astype(config.floatX)
A = matrix("A")
A_test = np.array([[6, 3], [3, 0]], dtype=config.floatX)
alpha = scalar("alpha")
alpha_test = np.array(3.0, dtype=config.floatX)
beta = scalar("beta")
beta_test = np.array(5.0, dtype=config.floatX)
# 2D * 2D
out = A.dot(A * alpha) + beta * A
fgraph = FunctionGraph([A, alpha, beta], [out])
compare_pytorch_and_py(fgraph, [A_test, alpha_test, beta_test])
# 1D * 2D and 1D * 1D
out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_pytorch_and_py(fgraph, [y_test, x_test, A_test, alpha_test, beta_test])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论