提交 5d652817 authored 作者: James Bergstra's avatar James Bergstra

changed assertions to error-checking in cuda_ndarray gemm, regarding illegal strides

上级 0cc47636
......@@ -2725,7 +2725,8 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
if ((CudaNdarray_HOST_DIMS(A)[1] != CudaNdarray_HOST_DIMS(B)[0])
|| (CudaNdarray_HOST_DIMS(A)[0] != CudaNdarray_HOST_DIMS(C)[0])
|| (CudaNdarray_HOST_DIMS(B)[1] != CudaNdarray_HOST_DIMS(C)[1]))
|| (CudaNdarray_HOST_DIMS(B)[1] != CudaNdarray_HOST_DIMS(C)[1])
|| (CudaNdarray_HOST_DIMS(A)[1] == 0))
{
PyErr_Format(PyExc_ValueError, "dimension mismatch in args to gemm (%i,%i)x(%i,%i)->(%i,%i)",
CudaNdarray_HOST_DIMS(A)[0],
......@@ -2780,12 +2781,22 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
}
// I don't know if cudablas handles negative strides
assert (CudaNdarray_HOST_STRIDES(A)[0] >= 0) ; // for now
assert (CudaNdarray_HOST_STRIDES(A)[1] >= 0) ; // for now
assert (CudaNdarray_HOST_STRIDES(B)[0] >= 0) ; // for now
assert (CudaNdarray_HOST_STRIDES(B)[1] >= 0) ; // for now
assert (CudaNdarray_HOST_STRIDES(C)[0] >= 0) ; // for now
assert (CudaNdarray_HOST_STRIDES(C)[1] >= 0) ; // for now
if ( (CudaNdarray_HOST_STRIDES(A)[0] < 0)
|| (CudaNdarray_HOST_STRIDES(A)[1] < 0)
|| (CudaNdarray_HOST_STRIDES(B)[0] < 0)
|| (CudaNdarray_HOST_STRIDES(B)[1] < 0)
|| (CudaNdarray_HOST_STRIDES(C)[0] < 0)
|| (CudaNdarray_HOST_STRIDES(C)[1] < 0))
{
PyErr_Format(PyExc_ValueError, "illegal strides in args to gemm (%i,%i)x(%i,%i)->(%i,%i)",
CudaNdarray_HOST_STRIDES(A)[0],
CudaNdarray_HOST_STRIDES(A)[1],
CudaNdarray_HOST_STRIDES(B)[0],
CudaNdarray_HOST_STRIDES(B)[1],
CudaNdarray_HOST_STRIDES(C)[0],
CudaNdarray_HOST_STRIDES(C)[1]);
return -1;
}
/* create appropriate strides for malformed matrices that are row or column
* vectors
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论