提交 8c878ce3 authored 作者: Frederic Bastien's avatar Frederic Bastien

check gemm, dot22* rank of inputs before using them.

上级 e9c19739
......@@ -423,9 +423,12 @@ class GemmRelated(Op):
#setup_z_Nz_Sz = None
check_xyz_rank2 = """
if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(_zout)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
if (%(_x)s->nd != 2) {
PyErr_Format(PyExc_NotImplementedError, "rank(x) != 2. rank(x) is %%d.", %(_x)s->nd); %(fail)s;}
if (%(_y)s->nd != 2) {
PyErr_Format(PyExc_NotImplementedError, "rank(y) != 2. rank(y) is %%d.", %(_y)s->nd); %(fail)s;}
if (%(_zout)s && %(_zout)s->nd != 2) {
PyErr_Format(PyExc_NotImplementedError, "rank(z) != 2. rank(z) is %%d.", %(_zout)s->nd); %(fail)s;}
"""
check_xyz_double_or_float = """
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
......@@ -626,8 +629,8 @@ class GemmRelated(Op):
return reduce(str.__add__, (
self.declare_NS,
self.setup_z_Nz_Sz,
self.check_xyz_rank2,
self.setup_z_Nz_Sz,
self.check_xyz_double_or_float,
self.check_ab_double_or_float,
self.check_dims,
......@@ -644,7 +647,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '')
def build_gemm_version(self):
return (9,)
return (10,)
class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论