提交 0550874d authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a get_work_mem helper function to have working memory in an op.

上级 6f346499
......@@ -298,6 +298,25 @@ outstanding_mallocs(PyObject* self, PyObject * args)
return PyInt_FromLong(_outstanding_mallocs[0]);
}
/* Cached scratch buffer handed out by get_work_mem(), and its capacity in bytes. */
static void *work_mem = NULL;
static size_t work_size = 0;

/*
 * Returns a chunk of device memory for temporary work inside of an op.
 *
 * A single buffer is cached and reused across calls, so you can only hold one
 * chunk at a time: a later call asking for more space frees the previous
 * buffer. The buffer stays owned by this module; callers must not free it.
 *
 * sz: minimum number of bytes needed.
 *
 * Returns a pointer to a buffer of at least sz bytes, or NULL when
 * device_malloc fails to provide one.
 */
void *get_work_mem(size_t sz) {
    /*
     * Reuse the cached buffer whenever it is already big enough. The
     * comparison is inclusive so that repeated requests for the same size do
     * not free and reallocate the buffer on every call.
     */
    if (work_mem != NULL && sz <= work_size) {
        return work_mem;
    }

    /* Forget the old buffer before replacing it so the cache never describes
     * memory that has already been released. */
    device_free(work_mem);
    work_mem = NULL;
    work_size = 0;

    work_mem = device_malloc(sz);
    if (work_mem != NULL) {
        work_size = sz;
    }
    return work_mem;
}
/////////////////////////
// Static helper methods
/////////////////////////
......
......@@ -88,7 +88,8 @@ typedef float real;
extern DllExport cublasHandle_t handle;
/**
* Allocation and freeing of device memory should go through these functions so that the lib can track memory usage.
* Allocation and freeing of device memory should go through these functions so
* that the lib can track memory usage.
*
* device_malloc will set the Python error message before returning NULL.
* device_free will return nonzero on failure (after setting the Python error message).
......@@ -98,6 +99,7 @@ extern DllExport cublasHandle_t handle;
DllExport void * device_malloc(size_t size);
DllExport void * device_malloc(size_t size, int verbose);
DllExport int device_free(void * ptr);
/**
 * Returns a scratch buffer of device memory holding at least sz bytes, for
 * temporary work inside an op. The buffer is cached and reused across calls,
 * so only one chunk can be in use at a time and the caller must not free it.
 */
DllExport void *get_work_mem(size_t sz);
template <typename T>
static T ceil_intdiv(T a, T b)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论