提交 a3dc0a72 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Broadcast input matrices in Gemm

上级 89353bd2
差异被折叠。
from copy import copy from copy import copy
from itertools import product from itertools import product
from random import shuffle
import numpy as np import numpy as np
import pytest import pytest
...@@ -67,6 +68,7 @@ from aesara.tensor.type import ( ...@@ -67,6 +68,7 @@ from aesara.tensor.type import (
matrix, matrix,
row, row,
scalar, scalar,
scalars,
tensor, tensor,
tensor3, tensor3,
tensor4, tensor4,
...@@ -1042,6 +1044,41 @@ def test_inplace1(): ...@@ -1042,6 +1044,41 @@ def test_inplace1():
assert [n.op for n in f.maker.fgraph.apply_nodes] == [gemm_no_inplace] assert [n.op for n in f.maker.fgraph.apply_nodes] == [gemm_no_inplace]
@pytest.mark.parametrize("linker", ("py", "cvm"))
@pytest.mark.parametrize("inplace", (False, True))
def test_gemm_broadcasting(inplace, linker):
    """Compare Gemm with NumPy when the matrix inputs have length-1 dimensions.

    A single function computing ``gemm(z, a, x, y, b)`` is compiled for the
    requested linker, using either ``gemm_inplace`` or ``gemm_no_inplace``.
    It is then called with every combination of the ``z``, ``x`` and ``y``
    shapes listed below and with ``a = b = 1``; each result is checked
    against ``z + np.dot(x, y)``.

    Parameters
    ----------
    inplace
        When true, build the graph with ``gemm_inplace`` (and compile with
        ``accept_inplace=True``); otherwise use ``gemm_no_inplace``.
    linker
        Linker name handed to ``Mode``.
    """
    a, b = scalars("a", "b")
    z, x, y = matrices("z", "x", "y")

    mode = Mode(linker=linker)
    if inplace:
        out = gemm_inplace(z, a, x, y, b)
        f = aesara.function([z, x, y, a, b], out, accept_inplace=True, mode=mode)
        # The compiled graph must consist of exactly the requested Gemm op
        assert [node.op for node in f.maker.fgraph.toposort()] == [gemm_inplace]
    else:
        out = gemm_no_inplace(z, a, x, y, b)
        f = aesara.function([z, x, y, a, b], out, mode=mode)
        assert [node.op for node in f.maker.fgraph.toposort()] == [gemm_no_inplace]

    shapes_z = [(5, 3), (1, 3), (5, 1), (1, 1)]
    shapes_x = [(5, 4), (1, 4)]
    shapes_y = [(4, 3), (4, 1)]

    # Seed the generator the same way the other tests in this file do, and use
    # it for the shuffles as well, so that a failing shape order and the
    # inputs that triggered it can be reproduced.
    rng = np.random.default_rng(unittest_tools.fetch_seed())
    rng.shuffle(shapes_z)
    rng.shuffle(shapes_x)
    rng.shuffle(shapes_y)

    for shape_z, shape_x, shape_y in product(shapes_z, shapes_x, shapes_y):
        z_v = rng.random(size=shape_z).astype(config.floatX)
        x_v = rng.random(size=shape_x).astype(config.floatX)
        y_v = rng.random(size=shape_y).astype(config.floatX)
        # We have to copy for the inplace case: `f` may overwrite `z_v`
        z_v_np = z_v.copy()
        np.testing.assert_allclose(
            f(z_v, x_v, y_v, 1, 1), z_v_np + np.dot(x_v, y_v), atol=2e-6
        )
def test_dot22(): def test_dot22():
for dtype1 in ["float32", "float64", "complex64", "complex128"]: for dtype1 in ["float32", "float64", "complex64", "complex128"]:
a = matrix(dtype=dtype1) a = matrix(dtype=dtype1)
...@@ -2476,6 +2513,40 @@ class TestInferShape(unittest_tools.InferShapeTester): ...@@ -2476,6 +2513,40 @@ class TestInferShape(unittest_tools.InferShapeTester):
Gemm, Gemm,
) )
def test_gemm_broadcast(self):
    """Run the shape-inference check on `gemm` for two broadcasting layouts.

    The first case passes a ``z`` with a leading dimension of 1 alongside
    a ``(2, 3) x (3, 4)`` product; the second passes an ``x`` with a
    leading dimension of 1 alongside a ``(5, 4)`` ``z``.
    """
    rng = np.random.default_rng(unittest_tools.fetch_seed())
    x, y, z = matrices("xyz")
    a = scalar("a")
    b = scalar("b")

    dtype = config.floatX
    y_shape = (3, 4)
    # Each entry: (shape of x, shape of z, value of b)
    cases = (
        # Broadcast Z
        ((2, 3), (1, 4), 0.5),
        # Broadcast dot(X, Y)
        ((1, 3), (5, 4), 1),
    )

    for x_shape, z_shape, b_value in cases:
        # Draw in the order x, y, z to keep the random stream unchanged
        x_val = rng.random(x_shape).astype(dtype)
        y_val = rng.random(y_shape).astype(dtype)
        a_val = np.asarray(0.5, dtype=dtype)
        z_val = rng.random(z_shape).astype(dtype)
        b_val = np.asarray(b_value, dtype=dtype)

        self._compile_and_check(
            [x, y, a, z, b],
            [gemm(z, a, x, y, b)],
            [x_val, y_val, a_val, z_val, b_val],
            Gemm,
        )
def test_gemv(self): def test_gemv(self):
rng = np.random.default_rng(unittest_tools.fetch_seed()) rng = np.random.default_rng(unittest_tools.fetch_seed())
A = matrix("A") A = matrix("A")
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论