提交 87a5e2bd authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Actually test gemv, not gemm.

上级 d763c819
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
import theano import theano
from theano import tensor from theano import tensor
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.blas import gemm, gemv_inplace, gemm_inplace, _dot22, batched_dot from theano.tensor.blas import gemv, gemv_inplace, gemm_inplace, _dot22, batched_dot
from theano.tensor.tests.test_blas import TestGer, BaseGemv from theano.tensor.tests.test_blas import TestGer, BaseGemv
from .. import gpuarray_shared_constructor from .. import gpuarray_shared_constructor
...@@ -50,7 +50,7 @@ def test_float16(): ...@@ -50,7 +50,7 @@ def test_float16():
np.asarray(0.5, dtype=np.float32)] np.asarray(0.5, dtype=np.float32)]
float16_shared = [gpuarray_shared_constructor(val, target=test_ctx_name) float16_shared = [gpuarray_shared_constructor(val, target=test_ctx_name)
for val in float16_data] for val in float16_data]
o = gemm(*float16_shared) o = gemv(*float16_shared)
f = theano.function([], o, mode=mode_with_gpu) f = theano.function([], o, mode=mode_with_gpu)
y, alpha, A, x, beta = float16_data y, alpha, A, x, beta = float16_data
out = f() out = f()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论