提交 8c9b612b authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Type context for dnn.py

上级 e4a14f54
......@@ -107,14 +107,14 @@ cudnnHandle_t APPLY_SPECIFIC(_handle);
#section init_code_struct
{
cuda_enter(pygpu_default_context()->ctx);
cuda_enter(CONTEXT->ctx);
cudnnStatus_t err;
APPLY_SPECIFIC(_handle) = NULL;
if ((err = cudnnCreate(&APPLY_SPECIFIC(_handle))) != CUDNN_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError, "could not create cuDNN handle: %s",
cudnnGetErrorString(err));
cuda_exit(pygpu_default_context()->ctx);
cuda_exit(CONTEXT->ctx);
FAIL;
}
cuda_exit(pygpu_default_context()->ctx);
cuda_exit(CONTEXT->ctx);
}
......@@ -5,12 +5,12 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
PyGpuArrayObject *om,
cudnnConvolutionDescriptor_t desc,
double alpha, double beta,
PyGpuArrayObject **output) {
PyGpuArrayObject **output,
PyGpuContextObject *c) {
cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
float af = alpha, bf = beta;
void *alpha_p;
void *beta_p;
PyGpuContextObject *c = pygpu_default_context();
if (PyGpuArray_DIMS(input)[1] != PyGpuArray_DIMS(kerns)[1]) {
PyErr_SetString(PyExc_ValueError,
......
......@@ -4,12 +4,12 @@ int
APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
PyGpuArrayObject *im,
cudnnConvolutionDescriptor_t desc,
double alpha, double beta, PyGpuArrayObject **input) {
double alpha, double beta, PyGpuArrayObject **input,
PyGpuContextObject *c) {
cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
float af = alpha, bf = beta;
void *alpha_p;
void *beta_p;
PyGpuContextObject *c = pygpu_default_context();
if (PyGpuArray_DIMS(im)[1] != PyGpuArray_DIMS(kerns)[1]) {
PyErr_SetString(PyExc_ValueError, "images and kernel must have the same "
......
......@@ -4,12 +4,12 @@ int
APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
PyGpuArrayObject *km,
cudnnConvolutionDescriptor_t desc,
double alpha, double beta, PyGpuArrayObject **kerns) {
double alpha, double beta, PyGpuArrayObject **kerns,
PyGpuContextObject *c) {
cudnnStatus_t err = CUDNN_STATUS_SUCCESS;
float af = alpha, bf = beta;
void *alpha_p;
void *beta_p;
PyGpuContextObject *c = pygpu_default_context();
if (PyGpuArray_DIMS(input)[1] != PyGpuArray_DIMS(km)[1]) {
PyErr_SetString(PyExc_ValueError,
......
......@@ -29,10 +29,10 @@ if (APPLY_SPECIFIC(output) != NULL) { cudnnDestroyTensorDescriptor(APPLY_SPECIFI
int APPLY_SPECIFIC(dnn_pool)(PyGpuArrayObject *img,
cudnnPoolingDescriptor_t desc,
PyGpuArrayObject **out) {
PyGpuArrayObject **out,
PyGpuContextObject *c) {
cudnnStatus_t err;
size_t dims[5];
PyGpuContextObject *c = pygpu_default_context();
if (!GpuArray_IS_C_CONTIGUOUS(&img->ga)) {
PyErr_SetString(PyExc_ValueError, "Only contiguous inputs are supported.");
......
......@@ -53,9 +53,9 @@ int APPLY_SPECIFIC(dnn_pool_grad)(PyGpuArrayObject *inp,
PyGpuArrayObject *out,
PyGpuArrayObject *out_grad,
cudnnPoolingDescriptor_t desc,
PyGpuArrayObject **inp_grad) {
PyGpuArrayObject **inp_grad,
PyGpuContextObject *c) {
cudnnStatus_t err;
PyGpuContextObject *c = pygpu_default_context();
if (!GpuArray_IS_C_CONTIGUOUS(&inp->ga)) {
PyErr_SetString(PyExc_ValueError, "Only contiguous inputs are supported.");
......@@ -81,7 +81,7 @@ int APPLY_SPECIFIC(dnn_pool_grad)(PyGpuArrayObject *inp,
if (theano_prep_output(inp_grad, PyGpuArray_NDIM(inp),
PyGpuArray_DIMS(inp), inp->ga.typecode,
GA_C_ORDER, pygpu_default_context()) != 0) {
GA_C_ORDER, c) != 0) {
return 1;
}
......
......@@ -34,9 +34,9 @@ if (APPLY_SPECIFIC(output) != NULL)
#section support_code_struct
int APPLY_SPECIFIC(softmax)(PyGpuArrayObject *x,
PyGpuArrayObject **out) {
PyGpuArrayObject **out,
PyGpuContextObject *c) {
cudnnStatus_t err;
PyGpuContextObject *c = pygpu_default_context();
if (c_set_tensorNd(x, APPLY_SPECIFIC(input)) != 0)
return 1;
......
......@@ -45,9 +45,9 @@ if (APPLY_SPECIFIC(dx) != NULL)
int APPLY_SPECIFIC(softmax_grad)(PyGpuArrayObject *dy,
PyGpuArrayObject *sm,
PyGpuArrayObject **dx) {
PyGpuArrayObject **dx,
PyGpuContextObject *c) {
cudnnStatus_t err;
PyGpuContextObject *c = pygpu_default_context();
if (c_set_tensorNd(dy, APPLY_SPECIFIC(dy)) != 0)
return 1;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论