提交 a3dc0a72 作者: Ricardo 提交者: Ricardo Vieira

Broadcast input matrices in Gemm

上级 89353bd2
......@@ -138,7 +138,6 @@ try:
except ImportError:
pass
from functools import reduce
from typing import Tuple, Union
import aesara.scalar
......@@ -630,8 +629,10 @@ class GemmRelated(COp):
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
"""
# broadcast_xy = None
check_dims = """
if (Nx[0] != Nz[0])
if (Nx[0] !=1 && Nz[0] != 1 && Nx[0] != Nz[0])
{
PyErr_Format(PyExc_ValueError,
"Shape mismatch: x has %%ld rows but z has %%ld rows",
......@@ -645,7 +646,7 @@ class GemmRelated(COp):
(long int)Nx[1], (long int)Nx[0], (long int)Ny[0], (long int)Ny[1]);
%(fail)s;
}
if (Ny[1] != Nz[1])
if (Ny[1] != 1 && Nz[1]!= 1 && Ny[1] != Nz[1])
{
PyErr_Format(PyExc_ValueError,
"Shape mismatch: y has %%ld cols but z has %%ld cols",
......@@ -822,14 +823,14 @@ class GemmRelated(COp):
else:
setup_z_Nz_Sz = self.setup_z_Nz_Sz
return reduce(
str.__add__,
return "".join(
(
self.declare_NS,
self.check_xyz_rank2,
setup_z_Nz_Sz,
self.check_xyz_double_or_float,
self.check_ab_double_or_float,
self.broadcast_xy,
self.check_dims,
self.check_strides,
self.encode_strides_in_unit,
......@@ -842,8 +843,7 @@ class GemmRelated(COp):
self.case_double_ab_constants,
self.case_double_gemm,
self.end_switch_typenum,
),
"",
)
)
def build_gemm_version(self):
......@@ -973,6 +973,11 @@ class Gemm(GemmRelated):
z.itemset(z * a + b * np.dot(x, y))
zout[0] = z
else:
# Broadcast Z if needed
if (x.shape[0] > z.shape[0]) or (y.shape[1] > z.shape[1]):
z = np.broadcast_to(
z, (max(x.shape[0], z.shape[0]), max(y.shape[1], z.shape[1]))
).copy()
if b == 0.0:
if a == 1.0:
z[:] = np.dot(x, y)
......@@ -993,88 +998,135 @@ class Gemm(GemmRelated):
zout[0] = z
def infer_shape(self, fgraph, node, input_shapes):
    """Return the symbolic shape of the Gemm output.

    ``input_shapes`` is unpacked as ``(z, a, x, y, b)``.  The output is not
    simply ``z``'s shape: ``z`` may be broadcast against ``dot(x, y)``, so
    each output dimension is the larger of the corresponding ``z``
    dimension and the matching outer dimension of the product (rows of
    ``x``, columns of ``y``).
    """
    z_shape, _, x_shape, y_shape, _ = input_shapes
    n_rows = aesara.scalar.scalar_maximum(z_shape[0], x_shape[0])
    n_cols = aesara.scalar.scalar_maximum(z_shape[1], y_shape[1])
    return [(n_rows, n_cols)]
# C template that selects the output buffer for the in-place Gemm and binds
# Nz/Sz to it.  When z is smaller than the (x rows, y cols) result along
# either axis, z cannot be reused: an output of the broadcast shape is
# (re)allocated if the existing one does not match, and z is copied into it
# (PyArray_CopyInto performs the broadcast).  Otherwise the output simply
# aliases z.  Nz/Sz always describe the output array, never the raw input.
setup_z_Nz_Sz_inplace = """
    // Needs broadcasting
    if (PyArray_DIMS(%(_z)s)[0] < Nx[0] || PyArray_DIMS(%(_z)s)[1] < Ny[1]){

        npy_intp dims[2];
        dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
        dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];

        // Check if we need to allocate new array
        if((NULL == %(_zout)s)
            || (PyArray_DIMS(%(_zout)s)[0] != dims[0])
            || (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
        {
            // fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]);
            Py_XDECREF(%(_zout)s);
            %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
            // Do not hand a NULL destination to PyArray_CopyInto
            if(!%(_zout)s) {
                PyErr_SetString(PyExc_MemoryError,
                                "failed to alloc gemm_inplace output");
                %(fail)s
            }
        }

        // fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]);
        if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
        {
            %(fail)s;
        }

    } else {
        if (%(_zout)s != %(_z)s)
        {
            Py_XDECREF(%(_zout)s);
            %(_zout)s = %(_z)s;
            Py_INCREF(%(_zout)s);
        }
    }

    Nz = PyArray_DIMS(%(_zout)s);
    Sz = PyArray_STRIDES(%(_zout)s);
    """
# C template that prepares the output buffer for the non-in-place Gemm and
# binds Nz/Sz to it.  The output shape is the element-wise maximum of z's
# shape and (x rows, y cols); the existing output is reused only when it
# already has exactly that shape, otherwise a single new array is allocated.
# z is then copied into the output (PyArray_CopyInto broadcasts it when z is
# smaller), so the input z is never modified.
setup_z_Nz_Sz_outplace = """
    npy_intp dims[2];
    dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0];
    dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1];

    // Check if we need to allocate new array
    if ((NULL == %(_zout)s)
        || (PyArray_DIMS(%(_zout)s)[0] != dims[0])
        || (PyArray_DIMS(%(_zout)s)[1] != dims[1]))
    {
        Py_XDECREF(%(_zout)s);
        %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s));
        // fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]);
        if(!%(_zout)s) {
            PyErr_SetString(PyExc_MemoryError,
                            "failed to alloc gemm_no_inplace output");
            %(fail)s
        }
    }

    // fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]);
    if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1)
    {
        %(fail)s
    }

    Nz = PyArray_DIMS(%(_zout)s);
    Sz = PyArray_STRIDES(%(_zout)s);
    """
if (PyArray_DESCR(%(_zout)s)->type_num == NPY_FLOAT)
# C template that materialises broadcast copies of x and/or y once the
# output dimensions Nz are known.  If the output has more rows than x, x is
# copied into a fresh (Nz[0], Nx[1]) array; if it has more columns than y,
# y is copied into a fresh (Ny[0], Nz[1]) array (PyArray_MoveInto performs
# the broadcast).  The input variable is rebound to the new array and
# Nx/Sx (resp. Ny/Sy) are refreshed to describe it.
# NOTE(review): rebinding %(_x)s / %(_y)s assumes the generated code owns a
# reference to each input that is released on cleanup - confirm against the
# type's c_extract/c_cleanup.
broadcast_xy = """
    // Broadcast X if needed
    if (Nz[0] > Nx[0])
    {
        npy_intp dims[2];
        dims[0] = Nz[0];
        dims[1] = Nx[1];
        // fprintf(stderr, "Gemm Broadcasting X into shape (%%i %%i)\\n", dims[0], dims[1]);
        PyArrayObject *x_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s));
        if(!x_new) {
            PyErr_SetString(PyExc_MemoryError,
                            "failed to alloc gemm_inplace input");
            %(fail)s
        }

        if(PyArray_MoveInto(x_new, %(_x)s) == -1)
        {
            // Release the temporary so a failed copy does not leak it
            Py_DECREF(x_new);
            %(fail)s
        }

        Py_DECREF(%(_x)s);
        %(_x)s = x_new;

        Nx = PyArray_DIMS(%(_x)s);
        Sx = PyArray_STRIDES(%(_x)s);
    }

    // Broadcast Y if needed
    if (Nz[1] > Ny[1])
    {
        npy_intp dims[2];
        dims[0] = Ny[0];
        dims[1] = Nz[1];
        // fprintf(stderr, "Gemm Broadcasting Y into shape (%%i %%i)\\n", dims[0], dims[1]);
        // Allocate with y's own dtype rather than borrowing x's
        PyArrayObject *y_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_y)s));
        if(!y_new) {
            PyErr_SetString(PyExc_MemoryError,
                            "failed to alloc gemm_inplace input");
            %(fail)s
        }

        if(PyArray_MoveInto(y_new, %(_y)s) == -1)
        {
            // Release the temporary so a failed copy does not leak it
            Py_DECREF(y_new);
            %(fail)s
        }

        Py_DECREF(%(_y)s);
        %(_y)s = y_new;

        Ny = PyArray_DIMS(%(_y)s);
        Sy = PyArray_STRIDES(%(_y)s);
    }

    """
case_float_ab_constants = """
#define REAL float
......@@ -1108,7 +1160,7 @@ class Gemm(GemmRelated):
def c_code_cache_version(self):
    """Return the cache key for this Op's generated C code.

    The key is this class's own version number prepended to the tuple
    returned by ``build_gemm_version()``.  When that helper returns a falsy
    value it is passed through unchanged.

    The class-level number is 7: the C templates of this Op changed (they
    now broadcast their inputs), so modules compiled under the previous
    number must not be reused.
    """
    gemm_version = self.build_gemm_version()
    if not gemm_version:
        # Propagate the helper's falsy result as-is.
        return gemm_version
    return (7,) + gemm_version
......@@ -1182,7 +1234,6 @@ def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
if M.owner and M.owner.op == _dot22:
Ml, Mr = M.owner.inputs
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
# print 'GEMM 0', rval, beta, L, alpha, M
return rval, M
# it also might be the case that there is a dimshuffle between the +
......@@ -1650,6 +1701,7 @@ class Dot22(GemmRelated):
Sz = PyArray_STRIDES(%(_zout)s);
"""
broadcast_xy = ""
check_ab_double_or_float = ""
case_float_ab_constants = """
float a = 1.0;
......@@ -1933,6 +1985,7 @@ class Dot22Scalar(GemmRelated):
return [[input_shapes[0][0], input_shapes[1][1]]]
setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz
broadcast_xy = ""
check_ab_double_or_float = """
if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE)
......
from copy import copy
from itertools import product
from random import shuffle
import numpy as np
import pytest
......@@ -67,6 +68,7 @@ from aesara.tensor.type import (
matrix,
row,
scalar,
scalars,
tensor,
tensor3,
tensor4,
......@@ -1042,6 +1044,41 @@ def test_inplace1():
assert [n.op for n in f.maker.fgraph.apply_nodes] == [gemm_no_inplace]
@pytest.mark.parametrize("linker", ("py", "cvm"))
@pytest.mark.parametrize("inplace", (False, True))
def test_gemm_broadcasting(inplace, linker):
    """Check Gemm against NumPy for every broadcastable shape combination.

    A single compiled function (in-place or not, per ``inplace``) is called
    repeatedly with ``z``, ``x`` and ``y`` whose outer dimensions are either
    full-sized or 1, and the result is compared with
    ``z + np.dot(x, y)``.
    """
    a, b = scalars("a", "b")
    z, x, y = matrices("z", "x", "y")

    mode = Mode(linker=linker)
    if inplace:
        out = gemm_inplace(z, a, x, y, b)
        f = aesara.function([z, x, y, a, b], out, accept_inplace=True, mode=mode)
        assert [node.op for node in f.maker.fgraph.toposort()] == [gemm_inplace]
    else:
        out = gemm_no_inplace(z, a, x, y, b)
        f = aesara.function([z, x, y, a, b], out, mode=mode)
        assert [node.op for node in f.maker.fgraph.toposort()] == [gemm_no_inplace]

    shapes_z = [(5, 3), (1, 3), (5, 1), (1, 1)]
    shapes_x = [(5, 4), (1, 4)]
    shapes_y = [(4, 3), (4, 1)]

    # Draw both the shape order and the values from one generator seeded the
    # same way as the other tests in this file, so that a failure can be
    # reproduced by re-running with the same seed.
    rng = np.random.default_rng(unittest_tools.fetch_seed())
    rng.shuffle(shapes_z)
    rng.shuffle(shapes_x)
    rng.shuffle(shapes_y)

    for shape_z, shape_x, shape_y in product(shapes_z, shapes_x, shapes_y):
        z_v = rng.random(size=shape_z).astype(config.floatX)
        x_v = rng.random(size=shape_x).astype(config.floatX)
        y_v = rng.random(size=shape_y).astype(config.floatX)
        # We have to copy for the inplace case
        z_v_np = z_v.copy()
        np.testing.assert_allclose(
            f(z_v, x_v, y_v, 1, 1), z_v_np + np.dot(x_v, y_v), atol=2e-6
        )
def test_dot22():
for dtype1 in ["float32", "float64", "complex64", "complex128"]:
a = matrix(dtype=dtype1)
......@@ -2476,6 +2513,40 @@ class TestInferShape(unittest_tools.InferShapeTester):
Gemm,
)
def test_gemm_broadcast(self):
    """Check Gemm shape inference when ``z`` and ``dot(x, y)`` differ in rows.

    Two scenarios are run through ``_compile_and_check``: one where ``z`` has
    a single row and must be broadcast up to the product's shape, and one
    where ``x`` has a single row so the product is broadcast up to ``z``.
    """
    rng = np.random.default_rng(unittest_tools.fetch_seed())
    x, y, z = matrices("xyz")
    a = scalar("a")
    b = scalar("b")

    # (x shape, y shape, z shape, value fed to b)
    scenarios = (
        # Broadcast Z
        ((2, 3), (3, 4), (1, 4), 0.5),
        # Broadcast dot(X, Y)
        ((1, 3), (3, 4), (5, 4), 1),
    )

    for x_shape, y_shape, z_shape, b_value in scenarios:
        # Values are drawn in x, y, z order to keep the random stream stable.
        x_val = rng.random(x_shape).astype(config.floatX)
        y_val = rng.random(y_shape).astype(config.floatX)
        a_val = np.asarray(0.5, dtype=config.floatX)
        z_val = rng.random(z_shape).astype(config.floatX)
        b_val = np.asarray(b_value, dtype=config.floatX)

        self._compile_and_check(
            [x, y, a, z, b],
            [gemm(z, a, x, y, b)],
            [x_val, y_val, a_val, z_val, b_val],
            Gemm,
        )
def test_gemv(self):
rng = np.random.default_rng(unittest_tools.fetch_seed())
A = matrix("A")
......
Markdown 格式
0%
您将要添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论