提交 35e87e0a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove duplicated BLAS rewriting code

Accidentally introduced in c655b028 Also move tests to the rewriting test file
上级 a0fe30de
......@@ -79,7 +79,6 @@ import functools
import logging
import os
import shlex
import time
from pathlib import Path
import numpy as np
......@@ -103,10 +102,8 @@ from pytensor.scalar import bool as bool_t
from pytensor.tensor import basic as ptb
from pytensor.tensor.basic import expand_dims
from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import add, mul, neg, sub, variadic_add
from pytensor.tensor.shape import shape_padright, specify_broadcastable
from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
from pytensor.tensor.type import DenseTensorType, tensor
_logger = logging.getLogger("pytensor.tensor.blas")
......@@ -1148,322 +1145,6 @@ pprint.assign(gemm_inplace, FunctionPrinter(["gemm_inplace"]))
pprint.assign(gemm_no_inplace, FunctionPrinter(["gemm_no_inplace"]))
def res_is_a(fgraph, var, op, maxclients=None):
if maxclients is not None and var in fgraph.clients:
retval = len(fgraph.get_clients(var)) <= maxclients
else:
retval = True
return var.owner and var.owner.op == op and retval
def _as_scalar(res, dtype=None):
"""Return ``None`` or a `TensorVariable` of float type"""
if dtype is None:
dtype = config.floatX
if all(s == 1 for s in res.type.shape):
while res.owner and isinstance(res.owner.op, DimShuffle):
res = res.owner.inputs[0]
# may still have some number of True's
if res.type.ndim > 0:
rval = res.dimshuffle()
else:
rval = res
if rval.type.dtype in integer_dtypes:
# We check that the upcast of res and dtype won't change dtype.
# If dtype is float64, we will cast int64 to float64.
# This is valid when res is a scalar used as input to a dot22
# as the cast of the scalar can be done before or after the dot22
# and this will give the same result.
if pytensor.scalar.upcast(res.dtype, dtype) == dtype:
return ptb.cast(rval, dtype)
else:
return None
return rval
def _is_real_matrix(res):
return (
res.type.dtype in ("float16", "float32", "float64")
and res.type.ndim == 2
and res.type.shape[0] != 1
and res.type.shape[1] != 1
) # cope with tuple vs. list
def _is_real_vector(res):
return (
res.type.dtype in ("float16", "float32", "float64")
and res.type.ndim == 1
and res.type.shape[0] != 1
)
def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
# print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
# EXPRESSION: (beta * L) + (alpha * M)
# we've already checked the client counts, now just make the type check.
# if res_is_a(M, _dot22, 1):
if M.owner and M.owner.op == _dot22:
Ml, Mr = M.owner.inputs
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
return rval, M
# it also might be the case that there is a dimshuffle between the +
# and the dot22. local_dot_to_dot22 in particular will put in such things.
if (
M.owner
and isinstance(M.owner.op, DimShuffle)
and M.owner.inputs[0].owner
and isinstance(M.owner.inputs[0].owner.op, Dot22)
):
MM = M.owner.inputs[0]
if M.owner.op.new_order == (0,):
# it is making a column MM into a vector
MMl, MMr = MM.owner.inputs
g = gemm_no_inplace(L.dimshuffle(0, "x"), alpha, MMl, MMr, beta)
rval = [g.dimshuffle(0)]
return rval, MM
if M.owner.op.new_order == (1,):
# it is making a row MM into a vector
MMl, MMr = MM.owner.inputs
g = gemm_no_inplace(L.dimshuffle("x", 0), alpha, MMl, MMr, beta)
rval = [g.dimshuffle(1)]
return rval, MM
if len(M.owner.op.new_order) == 0:
# it is making a row MM into a vector
MMl, MMr = MM.owner.inputs
g = gemm_no_inplace(L.dimshuffle("x", "x"), alpha, MMl, MMr, beta)
rval = [g.dimshuffle()]
return rval, MM
if recurse_flip:
return _beta_L_plus_alpha_M(fgraph, alpha, M, beta, L, recurse_flip=False)
else:
return False, False
def _gemm_canonicalize(fgraph, r, scale, rval, maxclients):
# Tries to interpret node as a sum of scalars * (vectors or matrices)
def scaled(thing):
if scale == 1:
return thing
if scale == -1 and thing.type.dtype != "bool":
return -thing
else:
return scale * thing
if not isinstance(r.type, TensorType):
return None
if (r.type.ndim not in (1, 2)) or r.type.dtype not in (
"float16",
"float32",
"float64",
"complex64",
"complex128",
):
rval.append(scaled(r))
return rval
if maxclients and len(fgraph.clients[r]) > maxclients:
rval.append((scale, r))
return rval
if r.owner and r.owner.op == sub:
_gemm_canonicalize(fgraph, r.owner.inputs[0], scale, rval, 1)
_gemm_canonicalize(fgraph, r.owner.inputs[1], -scale, rval, 1)
elif r.owner and r.owner.op == add:
for i in r.owner.inputs:
_gemm_canonicalize(fgraph, i, scale, rval, 1)
elif r.owner and r.owner.op == neg:
_gemm_canonicalize(fgraph, r.owner.inputs[0], -scale, rval, 1)
elif r.owner and r.owner.op == mul:
scalars = []
vectors = []
matrices = []
for i in r.owner.inputs:
if all(s == 1 for s in i.type.shape):
while i.owner and isinstance(i.owner.op, DimShuffle):
i = i.owner.inputs[0]
if i.type.ndim > 0:
scalars.append(i.dimshuffle())
else:
scalars.append(i)
elif _is_real_vector(i):
vectors.append(i)
elif _is_real_matrix(i):
matrices.append(i)
else:
# just put the original arguments as in the base case
rval.append((scale, r))
return rval
if len(matrices) == 1:
assert len(vectors) == 0
m = matrices[0]
if len(scalars) == 0:
_gemm_canonicalize(fgraph, m, scale, rval, 1)
elif len(scalars) == 1:
_gemm_canonicalize(fgraph, m, scaled(scalars[0]), rval, 1)
else:
_gemm_canonicalize(
fgraph, m, mul(scaled(scalars[0]), *scalars[1:]), rval, 1
)
elif len(vectors) == 1:
assert len(matrices) == 0
v = vectors[0]
if len(scalars) == 0:
_gemm_canonicalize(fgraph, v, scale, rval, 1)
elif len(scalars) == 1:
_gemm_canonicalize(fgraph, v, scaled(scalars[0]), rval, 1)
else:
_gemm_canonicalize(
fgraph, v, mul(scaled(scalars[0]), *scalars[1:]), rval, 1
)
else: # lets not open this up
rval.append((scale, r))
else:
rval.append((scale, r))
return rval
def _factor_canonicalized(lst):
# remove duplicates from canonicalized list
# we only delete out of the right end of the list,
# once i has touched a list element, it is permantent
lst = list(lst)
# print 'FACTOR', lst
# for t in lst:
# if not isinstance(t, (list, tuple)):
# t = (t,)
# for e in t:
# try:
# pytensor.printing.debugprint(e)
# except TypeError:
# print e, type(e)
i = 0
while i < len(lst) - 1:
try:
s_i, M_i = lst[i]
except Exception:
i += 1
continue
j = i + 1
while j < len(lst):
try:
s_j, M_j = lst[j]
except Exception:
j += 1
continue
if M_i is M_j:
s_i = s_i + s_j
lst[i] = (s_i, M_i)
del lst[j]
else:
j += 1
i += 1
return lst
def _gemm_from_factored_list(fgraph, lst):
"""
Returns None, or a list to replace node.outputs.
"""
lst2 = []
# Remove the tuple that can't be cast correctly.
# This can happen when we try to cast a complex to a real
for sM in lst:
# Make every pair in list have matching dtypes
# sM can be a tuple of 2 elements or an PyTensor variable.
if isinstance(sM, tuple):
sm0, sm1 = sM
sm0 = ptb.as_tensor_variable(sm0)
if pytensor.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype:
lst2.append((ptb.cast(sm0, sm1.dtype), sM[1]))
lst = lst2
def item_to_var(t):
try:
s, M = t
except Exception:
return t
if s == 1:
return M
if s == -1:
return -M
return s * M
# Try every pair in the sM_list, trying to turn it into a gemm operation
for i in range(len(lst) - 1):
s_i, M_i = lst[i]
for j in range(i + 1, len(lst)):
s_j, M_j = lst[j]
if not M_j.type.in_same_class(M_i.type):
continue
# print 'TRYING', (s_i, M_i, s_j, M_j)
gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M(
fgraph, s_i, M_i, s_j, M_j
)
# print 'GOT IT', gemm_of_sM_list
if gemm_of_sM_list:
assert len(gemm_of_sM_list) == 1
add_inputs = [
item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
]
add_inputs.extend(gemm_of_sM_list)
rval = [variadic_add(*add_inputs)]
return rval, old_dot22
def _gemm_from_node2(fgraph, node):
"""
TODO: In many expressions, there are many ways to turn it into a
gemm. For example dot(a,b) + c + d. This function should return all
of them, so that if one version of gemm causes a cycle in the graph, then
another application of gemm can be tried.
"""
lst = []
t0 = time.perf_counter()
_gemm_canonicalize(fgraph, node.outputs[0], 1.0, lst, 0)
t1 = time.perf_counter()
if len(lst) > 1:
lst = _factor_canonicalized(lst)
t2 = time.perf_counter()
rval = _gemm_from_factored_list(fgraph, lst)
t3 = time.perf_counter()
# It can happen that _factor_canonicalized and
# _gemm_from_factored_list return a node with an incorrect
# type. This happens in particular when one of the scalar
# factors forces the upcast of the whole expression. In that
# case, we simply skip that candidate for Gemm. This was
# discussed in
# http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
# but never made it into a trac ticket.
if rval and rval[0][0].type.in_same_class(node.outputs[0].type):
return rval, t1 - t0, t2 - t1, t3 - t2
return None, t1 - t0, 0, 0
class Dot22(GemmRelated):
"""Compute a matrix-matrix product.
......
......@@ -2,11 +2,39 @@ import numpy as np
import pytest
from pytensor import function
from pytensor import tensor as pt
from pytensor.compile import get_default_mode
from pytensor.tensor import matmul, tensor, vectorize
from pytensor.graph import FunctionGraph
from pytensor.tensor import (
col,
dscalar,
dvector,
matmul,
matrix,
mul,
neg,
row,
scalar,
sqrt,
tensor,
vector,
vectorize,
)
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.blas import (
_as_scalar,
_factor_canonicalized,
_gemm_canonicalize,
_is_real_matrix,
res_is_a,
specialize_matmul_to_batched_dot,
)
def XYZab():
return matrix(), matrix(), matrix(), scalar(), scalar()
@pytest.mark.parametrize("valid_case", (True, False))
......@@ -46,3 +74,136 @@ def test_specialize_matmul_to_batched_dot(valid_case):
vectorize_pt(x_test, y_test),
vectorize_np(x_test, y_test),
)
def test_gemm_factor():
X, Y = matrix("X"), matrix("Y")
assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
assert [(2.0, X)] == _factor_canonicalized([(1.0, X), (1.0, X)])
def test_gemm_canonicalize():
X, Y, Z, a, b = (
matrix("X"),
matrix("Y"),
matrix("Z"),
scalar("a"),
scalar("b"),
)
c, d = scalar("c"), scalar("d")
u = row("u")
v = vector("v")
w = col("w")
can = []
fg = FunctionGraph([X, Y, Z], [X + Y + Z], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, Z)]
can = []
fg = FunctionGraph([X, Y, u], [X + Y + u], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, u)], can
can = []
fg = FunctionGraph([X, Y, v], [X + Y + v], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
# [(1.0, X), (1.0, Y), (1.0, InplaceDimShuffle{x,0}(v))]
assert can[:2] == [(1.0, X), (1.0, Y)]
assert isinstance(can[2], tuple)
assert len(can[2]) == 2
assert can[2][0] == 1.0
assert can[2][1].owner
assert isinstance(can[2][1].owner.op, DimShuffle)
assert can[2][1].owner.inputs == [v]
can = []
fg = FunctionGraph([X, Y, w], [X + Y + w], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, w)], can
can = []
fg = FunctionGraph([a, X, Y, b, Z, c], [a * X + Y - b * Z * c], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can[0] == (a, X)
assert can[1] == (1.0, Y)
assert can[2][0].owner.op == mul
assert can[2][0].owner.inputs[0].owner.op == neg
assert can[2][0].owner.inputs[0].owner.inputs[0] == c
assert can[2][0].owner.inputs[1] == b
can = []
fg = FunctionGraph(
[a, X, Y, b, Z, c, d], [(-d) * X - (a * X + Y - b * Z * c)], clone=False
)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can[0][0].owner.op == neg
assert can[0][0].owner.inputs[0] == d
assert can[0][1] == X
assert can[1][0].owner.op == neg
assert can[1][0].owner.inputs[0] == a
assert can[2] == (-1.0, Y)
assert can[3][0].owner.op == mul
assert can[3][0].owner.inputs == [c, b]
def test_res_is_a():
X, Y, Z, a, b = XYZab()
assert not res_is_a(None, a, sqrt)
assert not res_is_a(None, a + a, sqrt)
assert res_is_a(None, sqrt(a + a), sqrt)
sqrt_term = sqrt(a + a)
fg = FunctionGraph([a], [2 * sqrt_term], clone=False)
assert res_is_a(fg, sqrt_term, sqrt, 2)
assert not res_is_a(fg, sqrt_term, sqrt, 0)
class TestAsScalar:
def test_basic(self):
# Test that it works on scalar constants
a = pt.constant(2.5)
b = pt.constant(np.asarray([[[0.5]]]))
b2 = b.dimshuffle()
assert b2.ndim == 0
d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_b = DimShuffle(input_ndim=3, new_order=[0, 2, 1])(b)
d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x", "x"])(a)
assert _as_scalar(a) == a
assert _as_scalar(b) != b
assert _as_scalar(d_a) != d_a
assert _as_scalar(d_b) != d_b
assert _as_scalar(d_a2) != d_a2
def test_basic_1(self):
# Test that it fails on nonscalar constants
a = pt.constant(np.ones(5))
assert _as_scalar(a) is None
assert _as_scalar(DimShuffle(input_ndim=1, new_order=[0, "x"])(a)) is None
def test_basic_2(self):
# Test that it works on scalar variables
a = dscalar()
d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x"])(a)
assert _as_scalar(a) is a
assert _as_scalar(d_a) is a
assert _as_scalar(d_a2) is a
def test_basic_3(self):
# Test that it fails on nonscalar variables
a = matrix()
assert _as_scalar(a) is None
assert _as_scalar(DimShuffle(input_ndim=2, new_order=[0, "x", 1])(a)) is None
class TestRealMatrix:
def test_basic(self):
assert _is_real_matrix(DimShuffle(input_ndim=2, new_order=[1, 0])(matrix()))
assert not _is_real_matrix(
DimShuffle(input_ndim=1, new_order=["x", 0])(dvector())
)
......@@ -16,7 +16,6 @@ from pytensor.compile.mode import Mode
from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import in2out
from pytensor.graph.utils import InconsistencyError
from pytensor.tensor import inplace
......@@ -28,12 +27,8 @@ from pytensor.tensor.blas import (
Gemm,
Gemv,
Ger,
_as_scalar,
_dot22,
_dot22scalar,
_factor_canonicalized,
_gemm_canonicalize,
_is_real_matrix,
batched_dot,
batched_tensordot,
gemm,
......@@ -44,19 +39,15 @@ from pytensor.tensor.blas import (
gemv_no_inplace,
ger,
ger_destructive,
res_is_a,
)
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import Dot, dot, mean, mul, neg, outer, sigmoid, sqrt
from pytensor.tensor.math import Dot, dot, mean, mul, outer, sigmoid
from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger
from pytensor.tensor.type import (
cmatrix,
col,
cscalar,
dmatrix,
drow,
dscalar,
dvector,
fmatrix,
fscalar,
imatrix,
......@@ -65,7 +56,6 @@ from pytensor.tensor.type import (
ivector,
matrices,
matrix,
row,
scalar,
scalars,
tensor,
......@@ -572,67 +562,6 @@ class TestGemmNoFlags:
self.run_gemm(dtype, alpha, beta, tA, tB, tC, sA, sB, sC, rng)
def test_res_is_a():
X, Y, Z, a, b = XYZab()
assert not res_is_a(None, a, sqrt)
assert not res_is_a(None, a + a, sqrt)
assert res_is_a(None, sqrt(a + a), sqrt)
sqrt_term = sqrt(a + a)
fg = FunctionGraph([a], [2 * sqrt_term], clone=False)
assert res_is_a(fg, sqrt_term, sqrt, 2)
assert not res_is_a(fg, sqrt_term, sqrt, 0)
class TestAsScalar:
def test_basic(self):
# Test that it works on scalar constants
a = pt.constant(2.5)
b = pt.constant(np.asarray([[[0.5]]]))
b2 = b.dimshuffle()
assert b2.ndim == 0
d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_b = DimShuffle(input_ndim=3, new_order=[0, 2, 1])(b)
d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x", "x"])(a)
assert _as_scalar(a) == a
assert _as_scalar(b) != b
assert _as_scalar(d_a) != d_a
assert _as_scalar(d_b) != d_b
assert _as_scalar(d_a2) != d_a2
def test_basic_1(self):
# Test that it fails on nonscalar constants
a = pt.constant(np.ones(5))
assert _as_scalar(a) is None
assert _as_scalar(DimShuffle(input_ndim=1, new_order=[0, "x"])(a)) is None
def test_basic_2(self):
# Test that it works on scalar variables
a = dscalar()
d_a = DimShuffle(input_ndim=0, new_order=[])(a)
d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x"])(a)
assert _as_scalar(a) is a
assert _as_scalar(d_a) is a
assert _as_scalar(d_a2) is a
def test_basic_3(self):
# Test that it fails on nonscalar variables
a = matrix()
assert _as_scalar(a) is None
assert _as_scalar(DimShuffle(input_ndim=2, new_order=[0, "x", 1])(a)) is None
class TestRealMatrix:
def test_basic(self):
assert _is_real_matrix(DimShuffle(input_ndim=2, new_order=[1, 0])(matrix()))
assert not _is_real_matrix(
DimShuffle(input_ndim=1, new_order=["x", 0])(dvector())
)
"""
This test suite ensures that Gemm is inserted where it belongs, and
that the resulting functions compute the same things as the originals.
......@@ -774,78 +703,6 @@ def test_gemm_opt_double_gemm():
assert max_abs_err <= eps, "GEMM is computing the wrong output. max_rel_err ="
def test_gemm_canonicalize():
X, Y, Z, a, b = (
matrix("X"),
matrix("Y"),
matrix("Z"),
scalar("a"),
scalar("b"),
)
c, d = scalar("c"), scalar("d")
u = row("u")
v = vector("v")
w = col("w")
can = []
fg = FunctionGraph([X, Y, Z], [X + Y + Z], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, Z)]
can = []
fg = FunctionGraph([X, Y, u], [X + Y + u], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, u)], can
can = []
fg = FunctionGraph([X, Y, v], [X + Y + v], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
# [(1.0, X), (1.0, Y), (1.0, InplaceDimShuffle{x,0}(v))]
assert can[:2] == [(1.0, X), (1.0, Y)]
assert isinstance(can[2], tuple)
assert len(can[2]) == 2
assert can[2][0] == 1.0
assert can[2][1].owner
assert isinstance(can[2][1].owner.op, DimShuffle)
assert can[2][1].owner.inputs == [v]
can = []
fg = FunctionGraph([X, Y, w], [X + Y + w], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can == [(1.0, X), (1.0, Y), (1.0, w)], can
can = []
fg = FunctionGraph([a, X, Y, b, Z, c], [a * X + Y - b * Z * c], clone=False)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can[0] == (a, X)
assert can[1] == (1.0, Y)
assert can[2][0].owner.op == mul
assert can[2][0].owner.inputs[0].owner.op == neg
assert can[2][0].owner.inputs[0].owner.inputs[0] == c
assert can[2][0].owner.inputs[1] == b
can = []
fg = FunctionGraph(
[a, X, Y, b, Z, c, d], [(-d) * X - (a * X + Y - b * Z * c)], clone=False
)
_gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
assert can[0][0].owner.op == neg
assert can[0][0].owner.inputs[0] == d
assert can[0][1] == X
assert can[1][0].owner.op == neg
assert can[1][0].owner.inputs[0] == a
assert can[2] == (-1.0, Y)
assert can[3][0].owner.op == mul
assert can[3][0].owner.inputs == [c, b]
def test_gemm_factor():
X, Y = matrix("X"), matrix("Y")
assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
assert [(2.0, X)] == _factor_canonicalized([(1.0, X), (1.0, X)])
def test_upcasting_scalar_nogemm():
# Test that the optimization does not crash when the scale has an incorrect
# dtype, and forces upcasting of the result
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论