提交 a6eb05aa authored 作者: Tim Cooijmans's avatar Tim Cooijmans

GpuBatchedDot: avoid incorrectness due to overflow in threshold test

上级 5fdf304d
...@@ -63,7 +63,10 @@ class GpuBatchedDot(GpuOp): ...@@ -63,7 +63,10 @@ class GpuBatchedDot(GpuOp):
y_dim2 = CudaNdarray_HOST_DIMS(%(by)s)[2]; y_dim2 = CudaNdarray_HOST_DIMS(%(by)s)[2];
// use parallel cublasSgemm calls rather than cublasSgemmBatched for large products // use parallel cublasSgemm calls rather than cublasSgemmBatched for large products
bool use_cublas_sgemm_batched = x_dim1 * x_dim2 * y_dim2 < %(threshold)s * %(threshold)s * %(threshold)s; // (compute products in double because they can be large and we don't need to be exact)
bool use_cublas_sgemm_batched = (
double(x_dim1) * double(x_dim2) * double(y_dim2) <
double(%(threshold)s) * double(%(threshold)s) * double(%(threshold)s));
if (x_dim0 != y_dim0) if (x_dim0 != y_dim0)
{ {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论