提交 a5010fe7 作者: Arnaud Bergeron

Fix error message for sgemv to say "Sgemv" rather than "Sgemm"

上级 5e69ec44
...@@ -286,7 +286,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1], ...@@ -286,7 +286,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1],
%(name)s_iIdx, PyArray_DIM(%(inputIdx)s, 1), %(name)s_iIdx, PyArray_DIM(%(inputIdx)s, 1),
%(name)s_oIdx, PyArray_DIM(%(outputIdx)s, 1)); %(name)s_oIdx, PyArray_DIM(%(outputIdx)s, 1));
} }
{ /* Run SgemmBatched */ { /* Run SgemvBatched */
float alpha = 1.0f; float alpha = 1.0f;
float beta = 1.0f; float beta = 1.0f;
cublasStatus_t err; cublasStatus_t err;
...@@ -308,7 +308,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1], ...@@ -308,7 +308,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1],
CudaNdarray_HOST_DIMS(%(h)s)[1] * CudaNdarray_HOST_DIMS(%(h)s)[1] *
CudaNdarray_HOST_DIMS(%(o)s)[0]); CudaNdarray_HOST_DIMS(%(o)s)[0]);
if (err != CUBLAS_STATUS_SUCCESS) { if (err != CUBLAS_STATUS_SUCCESS) {
PyErr_SetString(PyExc_RuntimeError, "SgemmBatched failed"); PyErr_SetString(PyExc_RuntimeError, "SgemvBatched failed");
%(fail)s %(fail)s
} }
} }
...@@ -448,8 +448,6 @@ static cublasStatus_t SgerBatched(cublasHandle_t handle, int m, int n, ...@@ -448,8 +448,6 @@ static cublasStatus_t SgerBatched(cublasHandle_t handle, int m, int n,
} }
} }
if (grid.x * grid.y * grid.z > 65535) { if (grid.x * grid.y * grid.z > 65535) {
// If grid.x * grid.y is bigger than 65535 you deserve the error
// you'll get later because that is way too big for this op.
grid.z = (65535 / (grid.x * grid.y)); grid.z = (65535 / (grid.x * grid.y));
} }
cublasGetPointerMode(handle, &mode); cublasGetPointerMode(handle, &mode);
...@@ -701,7 +699,7 @@ def sparse_block_dot_SS(W, h, inputIdx, b, outputIdx): ...@@ -701,7 +699,7 @@ def sparse_block_dot_SS(W, h, inputIdx, b, outputIdx):
b: (oBlocks, oSize), bias vector b: (oBlocks, oSize), bias vector
outputIdx: (batch, oWin), indexes of the output blocks outputIdx: (batch, oWin), indexes of the output blocks
returns (oBlocks, oSize), dot(W[i, j], h[i]) + b[j] returns (batch, oBlocks, oSize), dot(W[i, j], h[i]) + b[j]
but b[j] is only added once but b[j] is only added once
""" """
assert inputIdx.ndim == h.ndim - 1 assert inputIdx.ndim == h.ndim - 1
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论