提交 0550874d authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a get_work_mem helper function to have working memory in an op.

上级 6f346499
...@@ -298,6 +298,25 @@ outstanding_mallocs(PyObject* self, PyObject * args) ...@@ -298,6 +298,25 @@ outstanding_mallocs(PyObject* self, PyObject * args)
return PyInt_FromLong(_outstanding_mallocs[0]); return PyInt_FromLong(_outstanding_mallocs[0]);
} }
static void *work_mem = NULL;
static size_t work_size = 0;
/*
* Returns a chunk of memory for temporary work inside of an op. You can only
* request a single chunk of memory at a time since it is reused.
*/
void *get_work_mem(size_t sz) {
if (sz < work_size)
return work_mem;
device_free(work_mem);
work_mem = device_malloc(sz);
work_size = sz;
if (work_mem == NULL)
work_size = 0;
return work_mem;
}
///////////////////////// /////////////////////////
// Static helper methods // Static helper methods
///////////////////////// /////////////////////////
......
...@@ -88,7 +88,8 @@ typedef float real; ...@@ -88,7 +88,8 @@ typedef float real;
extern DllExport cublasHandle_t handle; extern DllExport cublasHandle_t handle;
/** /**
* Allocation and freeing of device memory should go through these functions so that the lib can track memory usage. * Allocation and freeing of device memory should go through these functions so
* that the lib can track memory usage.
* *
* device_malloc will set the Python error message before returning None. * device_malloc will set the Python error message before returning None.
* device_free will return nonzero on failure (after setting the python error message) * device_free will return nonzero on failure (after setting the python error message)
...@@ -98,6 +99,7 @@ extern DllExport cublasHandle_t handle; ...@@ -98,6 +99,7 @@ extern DllExport cublasHandle_t handle;
DllExport void * device_malloc(size_t size); DllExport void * device_malloc(size_t size);
DllExport void * device_malloc(size_t size, int verbose); DllExport void * device_malloc(size_t size, int verbose);
DllExport int device_free(void * ptr); DllExport int device_free(void * ptr);
DllExport void *get_work_mem(size_t sz);
template <typename T> template <typename T>
static T ceil_intdiv(T a, T b) static T ceil_intdiv(T a, T b)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论