提交 0c02c4be authored 作者: Frederic's avatar Frederic

check the exact number of gemm in the tests.

上级 f444bf75
......@@ -464,22 +464,23 @@ class Failure(Exception):
pass
def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0):
def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()],
max_graphlen=0, expected_nb_gemm=1):
try:
f = inplace_func(
[Param(ii, mutable=True, allow_downcast=True) for ii in i],
o,
mode='FAST_RUN',
on_unused_input='ignore')
at_least_one_gemm = False
nb_gemm = 0
for node in f.maker.env.nodes:
if node.op == T.dot:
raise Failure('dot not changed to gemm_inplace in graph')
if node.op == _dot22:
raise Failure('_dot22 not changed to gemm_inplace in graph')
if node.op == gemm_inplace:
at_least_one_gemm = True
assert at_least_one_gemm
nb_gemm += 1
assert nb_gemm == expected_nb_gemm
g = inplace_func(i, o, mode=compile.Mode(linker='py', optimizer=None),
allow_input_downcast=True, on_unused_input='ignore')
for node in g.maker.env.nodes:
......@@ -545,7 +546,8 @@ def test_gemm_opt_double_gemm():
just_gemm([X, Y, Z, a, b, R, S, c],
[Z * c + a * T.dot(X, Y) + b * T.dot(R, S).T],
ishapes=[(4, 3), (3, 5), (4, 5), (), (), (5, 9), (9, 4), ()])
ishapes=[(4, 3), (3, 5), (4, 5), (), (), (5, 9), (9, 4), ()],
expected_nb_gemm=2)
ishapes = [(4, 3), (3, 5), (4, 5), (), (), (5, 9), (9, 4), ()]
i = [X, Y, Z, a, b, R, S, c]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论