提交 5df0aeca authored 作者: James Bergstra's avatar James Bergstra

added more informative shape error message to gemm

上级 05c80ca0
......@@ -149,16 +149,32 @@ class GemmRelated(Op):
"""
check_dims_strides = """
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
if (Nx[0] != Nz[0])
{
PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
PyErr_Format(PyExc_ValueError,
"Shape mismatch: x has %%ld rows but z has %%ld rows",
(long int)Nx[0], (long int)Nz[0]);
%(fail)s;
}
if (Nx[1] != Ny[0])
{
PyErr_Format(PyExc_ValueError,
"Shape mismatch: x has %%ld cols but y has %%ld rows",
(long int)Nx[1], (long int)Ny[0]);
%(fail)s;
}
if (Ny[1] != Nz[1])
{
PyErr_Format(PyExc_ValueError,
"Shape mismatch: y has %%ld cols but z has %%ld cols",
(long int)Ny[1], (long int)Nz[1]);
%(fail)s;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{
PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s;
PyErr_SetString(PyExc_NotImplementedError, "stride is not multiple of element size"); %(fail)s;
}
"""
......@@ -275,7 +291,7 @@ class GemmRelated(Op):
self.end_switch_typenum), '')
def build_gemm_version(self):
return (1,)
return (2,)
class Gemm(GemmRelated):
"""In-place version of matrix-matrix multiplication (with accumulation):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论