提交 a3dc0a72 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Broadcast input matrices in Gemm

上级 89353bd2
差异被折叠。
from copy import copy
from itertools import product
from random import shuffle
import numpy as np
import pytest
......@@ -67,6 +68,7 @@ from aesara.tensor.type import (
matrix,
row,
scalar,
scalars,
tensor,
tensor3,
tensor4,
......@@ -1042,6 +1044,41 @@ def test_inplace1():
assert [n.op for n in f.maker.fgraph.apply_nodes] == [gemm_no_inplace]
@pytest.mark.parametrize("linker", ("py", "cvm"))
@pytest.mark.parametrize("inplace", (False, True))
def test_gemm_broadcasting(inplace, linker):
a, b = scalars("a", "b")
z, x, y = matrices("z", "x", "y")
mode = Mode(linker=linker)
if inplace:
out = gemm_inplace(z, a, x, y, b)
f = aesara.function([z, x, y, a, b], out, accept_inplace=True, mode=mode)
assert [node.op for node in f.maker.fgraph.toposort()] == [gemm_inplace]
else:
out = gemm_no_inplace(z, a, x, y, b)
f = aesara.function([z, x, y, a, b], out, mode=mode)
assert [node.op for node in f.maker.fgraph.toposort()] == [gemm_no_inplace]
shapes_z = [(5, 3), (1, 3), (5, 1), (1, 1)]
shapes_x = [(5, 4), (1, 4)]
shapes_y = [(4, 3), (4, 1)]
rng = np.random.default_rng()
shuffle(shapes_z)
shuffle(shapes_x)
shuffle(shapes_y)
for shape_z, shape_x, shape_y in product(shapes_z, shapes_x, shapes_y):
z_v = rng.random(size=shape_z).astype(config.floatX)
x_v = rng.random(size=shape_x).astype(config.floatX)
y_v = rng.random(size=shape_y).astype(config.floatX)
# We have to copy for the inplace case
z_v_np = z_v.copy()
np.testing.assert_allclose(
f(z_v, x_v, y_v, 1, 1), z_v_np + np.dot(x_v, y_v), atol=2e-6
)
def test_dot22():
for dtype1 in ["float32", "float64", "complex64", "complex128"]:
a = matrix(dtype=dtype1)
......@@ -2476,6 +2513,40 @@ class TestInferShape(unittest_tools.InferShapeTester):
Gemm,
)
def test_gemm_broadcast(self):
rng = np.random.default_rng(unittest_tools.fetch_seed())
x, y, z = matrices("xyz")
a = scalar("a")
b = scalar("b")
# Broadcast Z
self._compile_and_check(
[x, y, a, z, b],
[gemm(z, a, x, y, b)],
[
rng.random((2, 3)).astype(config.floatX),
rng.random((3, 4)).astype(config.floatX),
np.asarray(0.5, dtype=config.floatX),
rng.random((1, 4)).astype(config.floatX),
np.asarray(0.5, dtype=config.floatX),
],
Gemm,
)
# Broadcast dot(X, Y)
self._compile_and_check(
[x, y, a, z, b],
[gemm(z, a, x, y, b)],
[
rng.random((1, 3)).astype(config.floatX),
rng.random((3, 4)).astype(config.floatX),
np.asarray(0.5, dtype=config.floatX),
rng.random((5, 4)).astype(config.floatX),
np.asarray(1, dtype=config.floatX),
],
Gemm,
)
def test_gemv(self):
rng = np.random.default_rng(unittest_tools.fetch_seed())
A = matrix("A")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论