提交 35e87e0a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove duplicated BLAS rewriting code

Accidentally introduced in c655b028 Also move tests to the rewriting test file
上级 a0fe30de
差异被折叠。
......@@ -2,11 +2,39 @@ import numpy as np
import pytest
from pytensor import function
from pytensor import tensor as pt
from pytensor.compile import get_default_mode
from pytensor.tensor import matmul, tensor, vectorize
from pytensor.graph import FunctionGraph
from pytensor.tensor import (
col,
dscalar,
dvector,
matmul,
matrix,
mul,
neg,
row,
scalar,
sqrt,
tensor,
vector,
vectorize,
)
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.blas import (
_as_scalar,
_factor_canonicalized,
_gemm_canonicalize,
_is_real_matrix,
res_is_a,
specialize_matmul_to_batched_dot,
)
def XYZab():
return matrix(), matrix(), matrix(), scalar(), scalar()
@pytest.mark.parametrize("valid_case", (True, False))
......@@ -46,3 +74,136 @@ def test_specialize_matmul_to_batched_dot(valid_case):
vectorize_pt(x_test, y_test),
vectorize_np(x_test, y_test),
)
def test_gemm_factor():
X, Y = matrix("X"), matrix("Y")
assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
assert [(2.0, X)] == _factor_canonicalized([(1.0, X), (1.0, X)])
def test_gemm_canonicalize():
X, Y, Z, a, b = (
matrix("X"),
matrix("Y"),
matrix("Z"),
scalar("a"),
scalar("b"),
)
c, d = scalar("c"), scalar("d")
u = row("u")
v = vector("v")
w = col("w")
can = []
fg = FunctionGraph([X, Y, Z], [X + Y + Z], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, Z)]
can = []
fg = FunctionGraph([X, Y, u], [X + Y + u], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, u)], can
can = []
fg = FunctionGraph([X, Y, v], [X + Y + v], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
# [(1.0, X), (1.0, Y), (1.0, InplaceDimShuffle{x,0}(v))]
assert can[:2] == [(1.0, X), (1.0, Y)]
assert isinstance(can[2], tuple)
assert len(can[2]) == 2
assert can[2][0] == 1.0
assert can[2][1].owner
assert isinstance(can[2][1].owner.op, DimShuffle)
assert can[2][1].owner.inputs == [v]
can = []
fg = FunctionGraph([X, Y, w], [X + Y + w], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, w)], can
can = []
fg = FunctionGraph([a, X, Y, b, Z, c], [a * X + Y - b * Z * c], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can[0] == (a, X)
assert can[1] == (1.0, Y)
assert can[2][0].owner.op == mul
assert can[2][0].owner.inputs[0].owner.op == neg
assert can[2][0].owner.inputs[0].owner.inputs[0] == c
assert can[2][0].owner.inputs[1] == b
can = []
fg = FunctionGraph(
[a, X, Y, b, Z, c, d], [(-d) * X - (a * X + Y - b * Z * c)], clone=False
)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can[0][0].owner.op == neg
assert can[0][0].owner.inputs[0] == d
assert can[0][1] == X
assert can[1][0].owner.op == neg
assert can[1][0].owner.inputs[0] == a
assert can[2] == (-1.0, Y)
assert can[3][0].owner.op == mul
assert can[3][0].owner.inputs == [c, b]
def test_res_is_a():
X, Y, Z, a, b = XYZab()
assert not res_is_a(None, a, sqrt)
assert not res_is_a(None, a + a, sqrt)
assert res_is_a(None, sqrt(a + a), sqrt)
sqrt_term = sqrt(a + a)
fg = FunctionGraph([a], [2 * sqrt_term], clone=False)
assert res_is_a(fg, sqrt_term, sqrt, 2)
assert not res_is_a(fg, sqrt_term, sqrt, 0)
class TestAsScalar:
def test_basic(self):
# Test that it works on scalar constants
a = pt.constant(2.5)
b = pt.constant(np.asarray([[[0.5]]]))
b2 = b.dimshuffle()
assert b2.ndim == 0
d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_b = DimShuffle(input_ndim=3, new_order=[0, 2, 1])(b)
d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x", "x"])(a)
assert _as_scalar(a) == a
assert _as_scalar(b) != b
assert _as_scalar(d_a) != d_a
assert _as_scalar(d_b) != d_b
assert _as_scalar(d_a2) != d_a2
def test_basic_1(self):
# Test that it fails on nonscalar constants
a = pt.constant(np.ones(5))
assert _as_scalar(a) is None
assert _as_scalar(DimShuffle(input_ndim=1, new_order=[0, "x"])(a)) is None
def test_basic_2(self):
# Test that it works on scalar variables
a = dscalar()
d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x"])(a)
assert _as_scalar(a) is a
assert _as_scalar(d_a) is a
assert _as_scalar(d_a2) is a
def test_basic_3(self):
# Test that it fails on nonscalar variables
a = matrix()
assert _as_scalar(a) is None
assert _as_scalar(DimShuffle(input_ndim=2, new_order=[0, "x", 1])(a)) is None
class TestRealMatrix:
def test_basic(self):
assert _is_real_matrix(DimShuffle(input_ndim=2, new_order=[1, 0])(matrix()))
assert not _is_real_matrix(
DimShuffle(input_ndim=1, new_order=["x", 0])(dvector())
)
......@@ -16,7 +16,6 @@ from pytensor.compile.mode import Mode
from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import in2out
from pytensor.graph.utils import InconsistencyError
from pytensor.tensor import inplace
......@@ -28,12 +27,8 @@ from pytensor.tensor.blas import (
Gemm,
Gemv,
Ger,
_as_scalar,
_dot22,
_dot22scalar,
_factor_canonicalized,
_gemm_canonicalize,
_is_real_matrix,
batched_dot,
batched_tensordot,
gemm,
......@@ -44,19 +39,15 @@ from pytensor.tensor.blas import (
gemv_no_inplace,
ger,
ger_destructive,
res_is_a,
)
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, dot, mean, mul, neg, outer, sigmoid, sqrt
from pytensor.tensor.math import Dot, dot, mean, mul, outer, sigmoid
from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger
from pytensor.tensor.type import (
cmatrix,
col,
cscalar,
dmatrix,
drow,
dscalar,
dvector,
fmatrix,
fscalar,
imatrix,
......@@ -65,7 +56,6 @@ from pytensor.tensor.type import (
ivector,
matrices,
matrix,
row,
scalar,
scalars,
tensor,
......@@ -572,67 +562,6 @@ class TestGemmNoFlags:
self.run_gemm(dtype, alpha, beta, tA, tB, tC, sA, sB, sC, rng)
def test_res_is_a():
X, Y, Z, a, b = XYZab()
assert not res_is_a(None, a, sqrt)
assert not res_is_a(None, a + a, sqrt)
assert res_is_a(None, sqrt(a + a), sqrt)
sqrt_term = sqrt(a + a)
fg = FunctionGraph([a], [2 * sqrt_term], clone=False)
assert res_is_a(fg, sqrt_term, sqrt, 2)
assert not res_is_a(fg, sqrt_term, sqrt, 0)
class TestAsScalar:
def test_basic(self):
# Test that it works on scalar constants
a = pt.constant(2.5)
b = pt.constant(np.asarray([[[0.5]]]))
b2 = b.dimshuffle()
assert b2.ndim == 0
d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_b = DimShuffle(input_ndim=3, new_order=[0, 2, 1])(b)
d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x", "x"])(a)
assert _as_scalar(a) == a
assert _as_scalar(b) != b
assert _as_scalar(d_a) != d_a
assert _as_scalar(d_b) != d_b
assert _as_scalar(d_a2) != d_a2
def test_basic_1(self):
# Test that it fails on nonscalar constants
a = pt.constant(np.ones(5))
assert _as_scalar(a) is None
assert _as_scalar(DimShuffle(input_ndim=1, new_order=[0, "x"])(a)) is None
def test_basic_2(self):
# Test that it works on scalar variables
a = dscalar()
d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x"])(a)
assert _as_scalar(a) is a
assert _as_scalar(d_a) is a
assert _as_scalar(d_a2) is a
def test_basic_3(self):
# Test that it fails on nonscalar variables
a = matrix()
assert _as_scalar(a) is None
assert _as_scalar(DimShuffle(input_ndim=2, new_order=[0, "x", 1])(a)) is None
class TestRealMatrix:
def test_basic(self):
assert _is_real_matrix(DimShuffle(input_ndim=2, new_order=[1, 0])(matrix()))
assert not _is_real_matrix(
DimShuffle(input_ndim=1, new_order=["x", 0])(dvector())
)
"""
This test suite ensures that Gemm is inserted where it belongs, and
that the resulting functions compute the same things as the originals.
......@@ -774,78 +703,6 @@ def test_gemm_opt_double_gemm():
assert max_abs_err <= eps, "GEMM is computing the wrong output. max_rel_err ="
def test_gemm_canonicalize():
X, Y, Z, a, b = (
matrix("X"),
matrix("Y"),
matrix("Z"),
scalar("a"),
scalar("b"),
)
c, d = scalar("c"), scalar("d")
u = row("u")
v = vector("v")
w = col("w")
can = []
fg = FunctionGraph([X, Y, Z], [X + Y + Z], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, Z)]
can = []
fg = FunctionGraph([X, Y, u], [X + Y + u], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, u)], can
can = []
fg = FunctionGraph([X, Y, v], [X + Y + v], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
# [(1.0, X), (1.0, Y), (1.0, InplaceDimShuffle{x,0}(v))]
assert can[:2] == [(1.0, X), (1.0, Y)]
assert isinstance(can[2], tuple)
assert len(can[2]) == 2
assert can[2][0] == 1.0
assert can[2][1].owner
assert isinstance(can[2][1].owner.op, DimShuffle)
assert can[2][1].owner.inputs == [v]
can = []
fg = FunctionGraph([X, Y, w], [X + Y + w], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, w)], can
can = []
fg = FunctionGraph([a, X, Y, b, Z, c], [a * X + Y - b * Z * c], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can[0] == (a, X)
assert can[1] == (1.0, Y)
assert can[2][0].owner.op == mul
assert can[2][0].owner.inputs[0].owner.op == neg
assert can[2][0].owner.inputs[0].owner.inputs[0] == c
assert can[2][0].owner.inputs[1] == b
can = []
fg = FunctionGraph(
[a, X, Y, b, Z, c, d], [(-d) * X - (a * X + Y - b * Z * c)], clone=False
)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can[0][0].owner.op == neg
assert can[0][0].owner.inputs[0] == d
assert can[0][1] == X
assert can[1][0].owner.op == neg
assert can[1][0].owner.inputs[0] == a
assert can[2] == (-1.0, Y)
assert can[3][0].owner.op == mul
assert can[3][0].owner.inputs == [c, b]
def test_gemm_factor():
X, Y = matrix("X"), matrix("Y")
assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
assert [(2.0, X)] == _factor_canonicalized([(1.0, X), (1.0, X)])
def test_upcasting_scalar_nogemm():
# Test that the optimization does not crash when the scale has an incorrect
# dtype, and forces upcasting of the result
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论