提交 ccf3cef1 作者: Pascal Lamblin

Fix Ger C code when output is a row or column.

Some BLAS versions did not accept (1,1) strides in that case.
上级 c8f8dba5
...@@ -128,14 +128,20 @@ def ger_c_code(A, a, x, y, Z, destructive, fail): ...@@ -128,14 +128,20 @@ def ger_c_code(A, a, x, y, Z, destructive, fail):
} }
{ {
int Nz0 = %(Z)s->dimensions[0]; int Nz0 = %(Z)s->dimensions[0];
int Nz1 = %(Z)s->dimensions[1]; int Nz1 = %(Z)s->dimensions[1];
int Sz0 = %(Z)s->strides[0] / elemsize;
int Sz1 = %(Z)s->strides[1] / elemsize;
int Sx = %(x)s->strides[0] / elemsize; int Sx = %(x)s->strides[0] / elemsize;
int Sy = %(y)s->strides[0] / elemsize; int Sy = %(y)s->strides[0] / elemsize;
/* create appropriate strides for Z, if it is a row or column matrix.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int Sz0 = (Nz0 > 1) ? (%(Z)s->strides[0] / elemsize) : (Nz1 + 1);
int Sz1 = (Nz1 > 1) ? (%(Z)s->strides[1] / elemsize) : (Nz0 + 1);
if (1) if (1)
{ {
if (%(Z)s->strides[0] == elemsize) if (%(Z)s->strides[0] == elemsize)
...@@ -198,7 +204,7 @@ class CGer(BaseBLAS, Ger): ...@@ -198,7 +204,7 @@ class CGer(BaseBLAS, Ger):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (4,)
@local_optimizer([ger, ger_destructive]) @local_optimizer([ger, ger_destructive])
......
...@@ -1378,6 +1378,12 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin): ...@@ -1378,6 +1378,12 @@ class TestGer(TestCase, unittest_tools.TestOptimizationMixin):
def test_f32_4_4(self): def test_f32_4_4(self):
return self.given_dtype('float32', 4, 4) return self.given_dtype('float32', 4, 4)
def test_f32_7_1(self):
return self.given_dtype('float32', 7, 1)
def test_f32_1_2(self):
return self.given_dtype('float32', 1, 2)
def test_f64_4_5(self): def test_f64_4_5(self):
return self.given_dtype('float64', 4, 5) return self.given_dtype('float64', 4, 5)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论