提交 a5010fe7 作者: Arnaud Bergeron

Fix error message for sgemv to say "Sgemv" rather than "Sgemm"

上级 5e69ec44
......@@ -286,7 +286,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1],
%(name)s_iIdx, PyArray_DIM(%(inputIdx)s, 1),
%(name)s_oIdx, PyArray_DIM(%(outputIdx)s, 1));
}
{ /* Run SgemmBatched */
{ /* Run SgemvBatched */
float alpha = 1.0f;
float beta = 1.0f;
cublasStatus_t err;
......@@ -308,7 +308,7 @@ CudaNdarray_HOST_STRIDES(%(out)s)[0], CudaNdarray_HOST_STRIDES(%(out)s)[1],
CudaNdarray_HOST_DIMS(%(h)s)[1] *
CudaNdarray_HOST_DIMS(%(o)s)[0]);
if (err != CUBLAS_STATUS_SUCCESS) {
PyErr_SetString(PyExc_RuntimeError, "SgemmBatched failed");
PyErr_SetString(PyExc_RuntimeError, "SgemvBatched failed");
%(fail)s
}
}
......@@ -448,8 +448,6 @@ static cublasStatus_t SgerBatched(cublasHandle_t handle, int m, int n,
}
}
if (grid.x * grid.y * grid.z > 65535) {
// If grid.x * grid.y is bigger than 65535 you deserve the error
// you'll get later because that is way too big for this op.
grid.z = (65535 / (grid.x * grid.y));
}
cublasGetPointerMode(handle, &mode);
......@@ -701,7 +699,7 @@ def sparse_block_dot_SS(W, h, inputIdx, b, outputIdx):
b: (oBlocks, oSize), bias vector
outputIdx: (batch, oWin), indexes of the output blocks
returns (oBlocks, oSize), dot(W[i, j], h[i]) + b[j]
returns (batch, oBlocks, oSize), dot(W[i, j], h[i]) + b[j]
but b[j] is only added once
"""
assert inputIdx.ndim == h.ndim - 1
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论