提交 8c878ce3 authored 作者: Frederic Bastien's avatar Frederic Bastien

check gemm, dot22* rank of inputs before using them.

上级 e9c19739
...@@ -423,9 +423,12 @@ class GemmRelated(Op): ...@@ -423,9 +423,12 @@ class GemmRelated(Op):
#setup_z_Nz_Sz = None #setup_z_Nz_Sz = None
check_xyz_rank2 = """ check_xyz_rank2 = """
if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;} if (%(_x)s->nd != 2) {
if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;} PyErr_Format(PyExc_NotImplementedError, "rank(x) != 2. rank(x) is %%d.", %(_x)s->nd); %(fail)s;}
if (%(_zout)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;} if (%(_y)s->nd != 2) {
PyErr_Format(PyExc_NotImplementedError, "rank(y) != 2. rank(y) is %%d.", %(_y)s->nd); %(fail)s;}
if (%(_zout)s && %(_zout)s->nd != 2) {
PyErr_Format(PyExc_NotImplementedError, "rank(z) != 2. rank(z) is %%d.", %(_zout)s->nd); %(fail)s;}
""" """
check_xyz_double_or_float = """ check_xyz_double_or_float = """
if ((%(_x)s->descr->type_num != PyArray_DOUBLE) if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
...@@ -626,8 +629,8 @@ class GemmRelated(Op): ...@@ -626,8 +629,8 @@ class GemmRelated(Op):
return reduce(str.__add__, ( return reduce(str.__add__, (
self.declare_NS, self.declare_NS,
self.setup_z_Nz_Sz,
self.check_xyz_rank2, self.check_xyz_rank2,
self.setup_z_Nz_Sz,
self.check_xyz_double_or_float, self.check_xyz_double_or_float,
self.check_ab_double_or_float, self.check_ab_double_or_float,
self.check_dims, self.check_dims,
...@@ -644,7 +647,7 @@ class GemmRelated(Op): ...@@ -644,7 +647,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '') self.end_switch_typenum), '')
def build_gemm_version(self): def build_gemm_version(self):
return (9,) return (10,)
class Gemm(GemmRelated): class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation): """In-place version of matrix-matrix multiplication (with accumulation):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论