Commit 96ef4da1, authored by Tim Cooijmans

GpuBatchedDot: fix typo

Parent 6059a547
@@ -53,7 +53,7 @@ class GpuBatchedDot(GpuOp):
 // use parallel cublasSgemm calls rather than cublasSgemmBatched for large products
 // (compute products in double because they can be large and we don't need to be exact)
 bool use_cublas_sgemm_batched = (
-    double(Nx[1]) * double(Nx[2]) * double(Nx[2]) <
+    double(Nx[1]) * double(Nx[2]) * double(Ny[2]) <
     double(%(threshold)s) * double(%(threshold)s) * double(%(threshold)s));
 if (Nx[0] != Ny[0]) {
......
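
The changed line in the hunk above compares the size of each per-batch product against the cube of a threshold. A minimal Python sketch of that comparison follows, assuming `x` has shape (batch, M, K) and `y` has shape (batch, K, N), so that `Nx[1]`, `Nx[2]` and `Ny[2]` correspond to M, K and N. The function names and the threshold value are hypothetical and chosen only for illustration; they are not part of the original code.

```python
def use_batched_old(x_shape, y_shape, threshold):
    """Pre-fix comparison: Nx[2] appears twice and Ny[2] is ignored."""
    _, m, k = x_shape
    # float() stands in for the double() casts in the C code
    return float(m) * float(k) * float(k) < float(threshold) ** 3


def use_batched_new(x_shape, y_shape, threshold):
    """Post-fix comparison: uses M * K * N, the size of one matrix product."""
    _, m, k = x_shape
    _, _, n = y_shape
    return float(m) * float(k) * float(n) < float(threshold) ** 3


# Hypothetical shapes: x is (8, 4, 4), y is (8, 4, 100), threshold is 10.
x_shape, y_shape, threshold = (8, 4, 4), (8, 4, 100), 10
old_choice = use_batched_old(x_shape, y_shape, threshold)
new_choice = use_batched_new(x_shape, y_shape, threshold)
```

With these illustrative shapes the old expression evaluates 4 * 4 * 4 = 64 against 10^3 = 1000 and selects the batched path, while the corrected expression evaluates 4 * 4 * 100 = 1600 and does not. The two only disagree when the last dimension of `y` differs enough from the shared inner dimension.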