提交 2eaa8466 authored 作者: Frederic's avatar Frederic

Be more resistent to NumPy bad strides.

上级 039f7b5c
...@@ -721,9 +721,9 @@ class GemmRelated(Op): ...@@ -721,9 +721,9 @@ class GemmRelated(Op):
/* /*
encode the stride structure of _x,_y,_zout into a single integer encode the stride structure of _x,_y,_zout into a single integer
*/ */
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8; unit |= ((Sx[1] == type_size || Nx[1]==1) ? 0x0 : (Sx[0] == type_size || Nx[0]==1) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4; unit |= ((Sy[1] == type_size || Ny[1]==1) ? 0x0 : (Sy[0] == type_size || Ny[0]==1) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0; unit |= ((Sz[1] == type_size || Nz[1]==1) ? 0x0 : (Sz[0] == type_size || Nz[0]==1) ? 0x1 : 0x2) << 0;
""" """
compute_strides = """ compute_strides = """
...@@ -856,7 +856,7 @@ class GemmRelated(Op): ...@@ -856,7 +856,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '') self.end_switch_typenum), '')
def build_gemm_version(self): def build_gemm_version(self):
return (12, blas_header_version()) return (13, blas_header_version())
class Gemm(GemmRelated): class Gemm(GemmRelated):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论