提交 a3dc0a72 作者: Ricardo 提交者: Ricardo Vieira

Broadcast input matrices in Gemm

上级 89353bd2
...@@ -138,7 +138,6 @@ try: ...@@ -138,7 +138,6 @@ try:
except ImportError: except ImportError:
pass pass
from functools import reduce
from typing import Tuple, Union from typing import Tuple, Union
import aesara.scalar import aesara.scalar
...@@ -630,8 +629,10 @@ class GemmRelated(COp): ...@@ -630,8 +629,10 @@ class GemmRelated(COp):
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
""" """
# broadcast_xy = None
check_dims = """ check_dims = """
if (Nx[0] != Nz[0]) if (Nx[0] !=1 && Nz[0] != 1 && Nx[0] != Nz[0])
{ {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Shape mismatch: x has %%ld rows but z has %%ld rows", "Shape mismatch: x has %%ld rows but z has %%ld rows",
...@@ -645,7 +646,7 @@ class GemmRelated(COp): ...@@ -645,7 +646,7 @@ class GemmRelated(COp):
(long int)Nx[1], (long int)Nx[0], (long int)Ny[0], (long int)Ny[1]); (long int)Nx[1], (long int)Nx[0], (long int)Ny[0], (long int)Ny[1]);
%(fail)s; %(fail)s;
} }
if (Ny[1] != Nz[1]) if (Ny[1] != 1 && Nz[1]!= 1 && Ny[1] != Nz[1])
{ {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Shape mismatch: y has %%ld cols but z has %%ld cols", "Shape mismatch: y has %%ld cols but z has %%ld cols",
...@@ -822,14 +823,14 @@ class GemmRelated(COp): ...@@ -822,14 +823,14 @@ class GemmRelated(COp):
else: else:
setup_z_Nz_Sz = self.setup_z_Nz_Sz setup_z_Nz_Sz = self.setup_z_Nz_Sz
return reduce( return "".join(
str.__add__,
( (
self.declare_NS, self.declare_NS,
self.check_xyz_rank2, self.check_xyz_rank2,
setup_z_Nz_Sz, setup_z_Nz_Sz,
self.check_xyz_double_or_float, self.check_xyz_double_or_float,
self.check_ab_double_or_float, self.check_ab_double_or_float,
self.broadcast_xy,
self.check_dims, self.check_dims,
self.check_strides, self.check_strides,
self.encode_strides_in_unit, self.encode_strides_in_unit,
...@@ -842,8 +843,7 @@ class GemmRelated(COp): ...@@ -842,8 +843,7 @@ class GemmRelated(COp):
self.case_double_ab_constants, self.case_double_ab_constants,
self.case_double_gemm, self.case_double_gemm,
self.end_switch_typenum, self.end_switch_typenum,
), )
"",
) )
def build_gemm_version(self): def build_gemm_version(self):
...@@ -973,6 +973,11 @@ class Gemm(GemmRelated): ...@@ -973,6 +973,11 @@ class Gemm(GemmRelated):
z.itemset(z * a + b * np.dot(x, y)) z.itemset(z * a + b * np.dot(x, y))
zout[0] = z zout[0] = z
else: else:
# Broadcast Z if needed
if (x.shape[0] > z.shape[0]) or (y.shape[1] > z.shape[1]):
z = np.broadcast_to(
z, (max(x.shape[0], z.shape[0]), max(y.shape[1], z.shape[1]))
).copy()
if b == 0.0: if b == 0.0:
if a == 1.0: if a == 1.0:
z[:] = np.dot(x, y) z[:] = np.dot(x, y)
...@@ -993,88 +998,135 @@ class Gemm(GemmRelated): ...@@ -993,88 +998,135 @@ class Gemm(GemmRelated):
zout[0] = z zout[0] = z
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]] z_shape, _, x_shape, y_shape, _ = input_shapes
return [
(
aesara.scalar.scalar_maximum(z_shape[0], x_shape[0]),
aesara.scalar.scalar_maximum(z_shape[1], y_shape[1]),
)
]
setup_z_Nz_Sz_inplace = """ setup_z_Nz_Sz_inplace = """
if (%(_zout)s != %(_z)s) // Needs broadcasting
{ if (PyArray_DIMS(%(_z)s)[0] < Nx[0] || PyArray_DIMS(%(_z)s)[1] < Ny[1]){
if (%(_zout)s)
npy_intp dims[2];
dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];
// Check if we need to allocate new array
if((NULL == %(_zout)s)
|| (PyArray_DIMS(%(_zout)s)[0] != dims[0])
|| (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
{ {
Py_DECREF(%(_zout)s); // fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]);
Py_XDECREF(%(_zout)s);
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
}
// fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]);
if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
{
%(fail)s;
}
} else {
if (%(_zout)s != %(_z)s)
{
Py_XDECREF(%(_zout)s);
%(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s);
} }
%(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s);
} }
Nz = PyArray_DIMS(%(_z)s);
Sz = PyArray_STRIDES(%(_z)s); Nz = PyArray_DIMS(%(_zout)s);
Sz = PyArray_STRIDES(%(_zout)s);
""" """
setup_z_Nz_Sz_outplace = """ setup_z_Nz_Sz_outplace = """
npy_intp dims[2];
dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];
// Check if we need to allocate new array
if ((NULL == %(_zout)s) if ((NULL == %(_zout)s)
|| (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_z)s)[0]) || (PyArray_DIMS(%(_zout)s)[0] != dims[0])
|| (PyArray_DIMS(%(_zout)s)[1] != PyArray_DIMS(%(_z)s)[1]) || (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
|| (PyArray_STRIDES(%(_zout)s)[0] <= 0)
|| (PyArray_STRIDES(%(_zout)s)[1] <= 0)
|| (PyArray_STRIDES(%(_zout)s)[0] MOD type_size)
|| (PyArray_STRIDES(%(_zout)s)[1] MOD type_size)
|| ((PyArray_STRIDES(%(_zout)s)[0] != type_size)
&& (PyArray_STRIDES(%(_zout)s)[1] != type_size)))
{ {
Py_XDECREF(%(_zout)s); Py_XDECREF(%(_zout)s);
npy_intp dims[2]; %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
dims[0] = PyArray_DIMS(%(_z)s)[0]; // fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]);
dims[1] = PyArray_DIMS(%(_z)s)[1];
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims,
PyArray_TYPE(%(_z)s));
//fprintf(stderr, "Gemm Allocating %%i %%i\\n", dims[0], dims[1]);
if(!%(_zout)s) { if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemm_no_inplace output"); "failed to alloc gemm_no_inplace output");
%(fail)s %(fail)s
} }
} }
// fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]);
if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
{
%(fail)s
}
Nz = PyArray_DIMS(%(_zout)s); Nz = PyArray_DIMS(%(_zout)s);
Sz = PyArray_STRIDES(%(_zout)s); Sz = PyArray_STRIDES(%(_zout)s);
"""
if (PyArray_DESCR(%(_zout)s)->type_num == NPY_FLOAT) broadcast_xy = """
// Broadcast X if needed
if (Nz[0] > Nx[0])
{ {
float * zoutdata = (float*)PyArray_DATA(%(_zout)s); npy_intp dims[2];
int zoi = Sz[0] / sizeof(float); dims[0] = Nz[0];
int zoj = Sz[1] / sizeof(float); dims[1] = Nx[1];
const float * zdata = (float*)PyArray_DATA(%(_z)s); // fprintf(stderr, "Gemm Broadcasting X into shape (%%i %%i)\\n", dims[0], dims[1]);
int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(float); PyArrayObject *x_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(float); if(!x_new) {
for (int i = 0; i < Nz[0]; ++i) PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemm_inplace input");
%(fail)s
}
if(PyArray_MoveInto(x_new, %(_x)s) == -1)
{ {
for (int j = 0; j < Nz[1]; ++j) %(fail)s
{
zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
}
} }
Py_DECREF(%(_x)s);
%(_x)s = x_new;
Nx = PyArray_DIMS(%(_x)s);
Sx = PyArray_STRIDES(%(_x)s);
} }
else if (PyArray_DESCR(%(_zout)s)->type_num == NPY_DOUBLE)
// Broadcast Y if needed
if (Nz[1] > Ny[1])
{ {
double * zoutdata = (double*) PyArray_DATA(%(_zout)s); npy_intp dims[2];
int zoi = Sz[0] / sizeof(double); dims[0] = Ny[0];
int zoj = Sz[1] / sizeof(double); dims[1] = Nz[1];
const double * zdata = (double*)PyArray_DATA(%(_z)s); // fprintf(stderr, "Gemm Broadcasting Y into shape (%%i %%i)\\n", dims[0], dims[1]);
int zi = PyArray_STRIDES(%(_z)s)[0]/sizeof(double); PyArrayObject *y_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
int zj = PyArray_STRIDES(%(_z)s)[1]/sizeof(double); if(!y_new) {
for (int i = 0; i < Nz[0]; ++i) PyErr_SetString(PyExc_MemoryError,
"failed to alloc gemm_inplace input");
%(fail)s
}
if(PyArray_MoveInto(y_new, %(_y)s) == -1)
{ {
for (int j = 0; j < Nz[1]; ++j) %(fail)s
{
zoutdata[zoi*i + zoj*j] = zdata[zi*i + zj*j];
}
} }
Py_DECREF(%(_y)s);
%(_y)s = y_new;
Ny = PyArray_DIMS(%(_y)s);
Sy = PyArray_STRIDES(%(_y)s);
} }
else
{ """
PyErr_SetString(PyExc_AssertionError,
"neither float nor double dtype");
%(fail)s
}
"""
case_float_ab_constants = """ case_float_ab_constants = """
#define REAL float #define REAL float
...@@ -1108,7 +1160,7 @@ class Gemm(GemmRelated): ...@@ -1108,7 +1160,7 @@ class Gemm(GemmRelated):
def c_code_cache_version(self): def c_code_cache_version(self):
gv = self.build_gemm_version() gv = self.build_gemm_version()
if gv: if gv:
return (6,) + gv return (7,) + gv
else: else:
return gv return gv
...@@ -1182,7 +1234,6 @@ def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True): ...@@ -1182,7 +1234,6 @@ def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
if M.owner and M.owner.op == _dot22: if M.owner and M.owner.op == _dot22:
Ml, Mr = M.owner.inputs Ml, Mr = M.owner.inputs
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)] rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
# print 'GEMM 0', rval, beta, L, alpha, M
return rval, M return rval, M
# it also might be the case that there is a dimshuffle between the + # it also might be the case that there is a dimshuffle between the +
...@@ -1650,6 +1701,7 @@ class Dot22(GemmRelated): ...@@ -1650,6 +1701,7 @@ class Dot22(GemmRelated):
Sz = PyArray_STRIDES(%(_zout)s); Sz = PyArray_STRIDES(%(_zout)s);
""" """
broadcast_xy = ""
check_ab_double_or_float = "" check_ab_double_or_float = ""
case_float_ab_constants = """ case_float_ab_constants = """
float a = 1.0; float a = 1.0;
...@@ -1933,6 +1985,7 @@ class Dot22Scalar(GemmRelated): ...@@ -1933,6 +1985,7 @@ class Dot22Scalar(GemmRelated):
return [[input_shapes[0][0], input_shapes[1][1]]] return [[input_shapes[0][0], input_shapes[1][1]]]
setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz
broadcast_xy = ""
check_ab_double_or_float = """ check_ab_double_or_float = """
if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE) if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE)
......
from copy import copy from copy import copy
from itertools import product from itertools import product
from random import shuffle
import numpy as np import numpy as np
import pytest import pytest
...@@ -67,6 +68,7 @@ from aesara.tensor.type import ( ...@@ -67,6 +68,7 @@ from aesara.tensor.type import (
matrix, matrix,
row, row,
scalar, scalar,
scalars,
tensor, tensor,
tensor3, tensor3,
tensor4, tensor4,
...@@ -1042,6 +1044,41 @@ def test_inplace1(): ...@@ -1042,6 +1044,41 @@ def test_inplace1():
assert [n.op for n in f.maker.fgraph.apply_nodes] == [gemm_no_inplace] assert [n.op for n in f.maker.fgraph.apply_nodes] == [gemm_no_inplace]
@pytest.mark.parametrize("linker", ("py", "cvm"))
@pytest.mark.parametrize("inplace", (False, True))
def test_gemm_broadcasting(inplace, linker):
    """Compare ``Gemm`` against NumPy for broadcastable ``z``, ``x`` and ``y``.

    A single function computing ``gemm(z, a, x, y, b)`` is compiled with the
    requested ``linker`` (in-place or not, depending on ``inplace``) and is
    then called with every combination of the shapes listed below, using
    ``a = b = 1``.  Each result must match ``z + np.dot(x, y)`` as computed by
    NumPy.
    """
    a, b = scalars("a", "b")
    z, x, y = matrices("z", "x", "y")

    mode = Mode(linker=linker)
    if inplace:
        out = gemm_inplace(z, a, x, y, b)
        f = aesara.function([z, x, y, a, b], out, accept_inplace=True, mode=mode)
        # The compiled graph must consist of the in-place Gemm only.
        assert [node.op for node in f.maker.fgraph.toposort()] == [gemm_inplace]
    else:
        out = gemm_no_inplace(z, a, x, y, b)
        f = aesara.function([z, x, y, a, b], out, mode=mode)
        # The compiled graph must consist of the out-of-place Gemm only.
        assert [node.op for node in f.maker.fgraph.toposort()] == [gemm_no_inplace]

    shapes_z = [(5, 3), (1, 3), (5, 1), (1, 1)]
    shapes_x = [(5, 4), (1, 4)]
    shapes_y = [(4, 3), (4, 1)]

    # Seed the generator so that the input values of a failing run can be
    # regenerated.
    rng = np.random.default_rng(unittest_tools.fetch_seed())

    # NOTE(review): the visiting order of the shapes is randomised, presumably
    # so that the same compiled function sees differently shaped inputs in a
    # varying sequence -- confirm the intent with the author.
    shuffle(shapes_z)
    shuffle(shapes_x)
    shuffle(shapes_y)

    for shape_z, shape_x, shape_y in product(shapes_z, shapes_x, shapes_y):
        z_v = rng.random(size=shape_z).astype(config.floatX)
        x_v = rng.random(size=shape_x).astype(config.floatX)
        y_v = rng.random(size=shape_y).astype(config.floatX)
        # We have to copy for the inplace case
        z_v_np = z_v.copy()
        np.testing.assert_allclose(
            f(z_v, x_v, y_v, 1, 1),
            z_v_np + np.dot(x_v, y_v),
            atol=2e-6,
            # Report the failing combination, since the order is shuffled.
            err_msg=f"shape_z={shape_z}, shape_x={shape_x}, shape_y={shape_y}",
        )
def test_dot22(): def test_dot22():
for dtype1 in ["float32", "float64", "complex64", "complex128"]: for dtype1 in ["float32", "float64", "complex64", "complex128"]:
a = matrix(dtype=dtype1) a = matrix(dtype=dtype1)
...@@ -2476,6 +2513,40 @@ class TestInferShape(unittest_tools.InferShapeTester): ...@@ -2476,6 +2513,40 @@ class TestInferShape(unittest_tools.InferShapeTester):
Gemm, Gemm,
) )
def test_gemm_broadcast(self):
    """Run the shape-inference check on ``Gemm`` with broadcast inputs.

    Two cases are passed to ``self._compile_and_check``: one where ``z`` has
    a single row while ``dot(x, y)`` has several, and one where ``x`` has a
    single row while ``z`` has several.
    """
    rng = np.random.default_rng(unittest_tools.fetch_seed())
    x, y, z = matrices("xyz")
    a, b = scalar("a"), scalar("b")

    # Each entry is (shape of x, shape of y, value of a, shape of z, value of b).
    cases = (
        # Broadcast Z
        ((2, 3), (3, 4), 0.5, (1, 4), 0.5),
        # Broadcast dot(X, Y)
        ((1, 3), (3, 4), 0.5, (5, 4), 1),
    )

    for x_shape, y_shape, a_value, z_shape, b_value in cases:
        # Draw the random inputs in the order x, y, z.
        x_value = rng.random(x_shape).astype(config.floatX)
        y_value = rng.random(y_shape).astype(config.floatX)
        z_value = rng.random(z_shape).astype(config.floatX)
        numeric_inputs = [
            x_value,
            y_value,
            np.asarray(a_value, dtype=config.floatX),
            z_value,
            np.asarray(b_value, dtype=config.floatX),
        ]
        self._compile_and_check(
            [x, y, a, z, b],
            [gemm(z, a, x, y, b)],
            numeric_inputs,
            Gemm,
        )
def test_gemv(self): def test_gemv(self):
rng = np.random.default_rng(unittest_tools.fetch_seed()) rng = np.random.default_rng(unittest_tools.fetch_seed())
A = matrix("A") A = matrix("A")
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论