提交 2733557d 作者: nouiz

Merge pull request #394 from lamblin/fix_ger_test

Fix Ger C code when output is a row or column.
......@@ -128,14 +128,20 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
}
{
int Nz0 = %(Z)s->dimensions[0];
int Nz1 = %(Z)s->dimensions[1];
int Sz0 = %(Z)s->strides[0] / elemsize;
int Sz1 = %(Z)s->strides[1] / elemsize;
int Sx = %(x)s->strides[0] / elemsize;
int Sy = %(y)s->strides[0] / elemsize;
/* create appropriate strides for Z, if it is a row or column matrix.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int Sz0 = (Nz0 > 1) ? (%(Z)s->strides[0] / elemsize) : (Nz1 + 1);
int Sz1 = (Nz1 > 1) ? (%(Z)s->strides[1] / elemsize) : (Nz0 + 1);
if (1)
{
if (%(Z)s->strides[0] == elemsize)
......@@ -198,7 +204,7 @@ class CGer(BaseBLAS, Ger):
return code
def c_code_cache_version(self):
return (3,)
return (4,)
@local_optimizer([ger, ger_destructive])
......
......@@ -1378,6 +1378,12 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
def test_f32_4_4(self):
return self.given_dtype('float32', 4, 4)
def test_f32_7_1(self):
return self.given_dtype('float32', 7, 1)
def test_f32_1_2(self):
return self.given_dtype('float32', 1, 2)
def test_f64_4_5(self):
return self.given_dtype('float64', 4, 5)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论