提交 5d652817 authored 作者: James Bergstra's avatar James Bergstra

changed assertions to error-checking in cuda_ndarray gemm, regarding illegal strides

上级 0cc47636
...@@ -2725,7 +2725,8 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -2725,7 +2725,8 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
if ((CudaNdarray_HOST_DIMS(A)[1] != CudaNdarray_HOST_DIMS(B)[0]) if ((CudaNdarray_HOST_DIMS(A)[1] != CudaNdarray_HOST_DIMS(B)[0])
|| (CudaNdarray_HOST_DIMS(A)[0] != CudaNdarray_HOST_DIMS(C)[0]) || (CudaNdarray_HOST_DIMS(A)[0] != CudaNdarray_HOST_DIMS(C)[0])
|| (CudaNdarray_HOST_DIMS(B)[1] != CudaNdarray_HOST_DIMS(C)[1])) || (CudaNdarray_HOST_DIMS(B)[1] != CudaNdarray_HOST_DIMS(C)[1])
|| (CudaNdarray_HOST_DIMS(A)[1] == 0))
{ {
PyErr_Format(PyExc_ValueError, "dimension mismatch in args to gemm (%i,%i)x(%i,%i)->(%i,%i)", PyErr_Format(PyExc_ValueError, "dimension mismatch in args to gemm (%i,%i)x(%i,%i)->(%i,%i)",
CudaNdarray_HOST_DIMS(A)[0], CudaNdarray_HOST_DIMS(A)[0],
...@@ -2780,12 +2781,22 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B, ...@@ -2780,12 +2781,22 @@ int CudaNdarray_gemm(float alpha, const CudaNdarray * A, const CudaNdarray * B,
} }
// I don't know if cudablas handles negative strides // I don't know if cudablas handles negative strides
assert (CudaNdarray_HOST_STRIDES(A)[0] >= 0) ; // for now if ( (CudaNdarray_HOST_STRIDES(A)[0] < 0)
assert (CudaNdarray_HOST_STRIDES(A)[1] >= 0) ; // for now || (CudaNdarray_HOST_STRIDES(A)[1] < 0)
assert (CudaNdarray_HOST_STRIDES(B)[0] >= 0) ; // for now || (CudaNdarray_HOST_STRIDES(B)[0] < 0)
assert (CudaNdarray_HOST_STRIDES(B)[1] >= 0) ; // for now || (CudaNdarray_HOST_STRIDES(B)[1] < 0)
assert (CudaNdarray_HOST_STRIDES(C)[0] >= 0) ; // for now || (CudaNdarray_HOST_STRIDES(C)[0] < 0)
assert (CudaNdarray_HOST_STRIDES(C)[1] >= 0) ; // for now || (CudaNdarray_HOST_STRIDES(C)[1] < 0))
{
PyErr_Format(PyExc_ValueError, "illegal strides in args to gemm (%i,%i)x(%i,%i)->(%i,%i)",
CudaNdarray_HOST_STRIDES(A)[0],
CudaNdarray_HOST_STRIDES(A)[1],
CudaNdarray_HOST_STRIDES(B)[0],
CudaNdarray_HOST_STRIDES(B)[1],
CudaNdarray_HOST_STRIDES(C)[0],
CudaNdarray_HOST_STRIDES(C)[1]);
return -1;
}
/* create appropriate strides for malformed matrices that are row or column /* create appropriate strides for malformed matrices that are row or column
* vectors * vectors
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论