提交 8908b832 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

modif to try to make gpu work on windows.

上级 261add21
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "cuda_ndarray.cuh" #include "cuda_ndarray.cuh"
//If true, when there is a gpu malloc or free error, we print the size of allocated memory on the device. //If true, when there is a gpu malloc or free error, we print the size of allocated memory on the device.
#define COMPUTE_GPU_MEM_USED false #define COMPUTE_GPU_MEM_USED 0
///////////////////////// /////////////////////////
// Alloc and Free // Alloc and Free
...@@ -420,7 +420,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape) ...@@ -420,7 +420,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
PyObject * CudaNdarray_Copy(CudaNdarray * self) PyObject * CudaNdarray_Copy(CudaNdarray * self)
{ {
PyObject * rval = CudaNdarray_new_null(); PyObject * rval = CudaNdarray_new_null();
if ((!rval) or (-1 == self->nd)) if ((!rval) || (-1 == self->nd))
{ {
return rval; return rval;
} }
...@@ -1871,7 +1871,7 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s ...@@ -1871,7 +1871,7 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s
CudaNdarray * cnda = (CudaNdarray*)py_data; CudaNdarray * cnda = (CudaNdarray*)py_data;
if (strict or CudaNdarray_Check(py_data)) if (strict || CudaNdarray_Check(py_data))
{ {
//TODO: support non-strict "casting" from a vt to the broadcastable/type/size that we need. //TODO: support non-strict "casting" from a vt to the broadcastable/type/size that we need.
if (!CudaNdarray_Check(py_data)) if (!CudaNdarray_Check(py_data))
...@@ -1890,7 +1890,7 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s ...@@ -1890,7 +1890,7 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s
} }
for (int i = 0; i < cnda->nd; ++i) for (int i = 0; i < cnda->nd; ++i)
{ {
if ((CudaNdarray_HOST_DIMS(cnda)[i] > 1) and PyInt_AsLong(PyTuple_GetItem(broadcastable, Py_ssize_t(i)))) if ((CudaNdarray_HOST_DIMS(cnda)[i] > 1) && PyInt_AsLong(PyTuple_GetItem(broadcastable, Py_ssize_t(i))))
{ {
PyErr_Format(PyExc_TypeError, "Non-unit size in broadcastable vt dimension %i", i); PyErr_Format(PyExc_TypeError, "Non-unit size in broadcastable vt dimension %i", i);
Py_DECREF(py_data); Py_DECREF(py_data);
...@@ -1913,7 +1913,7 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s ...@@ -1913,7 +1913,7 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s
} }
for (int i = 0; i < data->nd; ++i) for (int i = 0; i < data->nd; ++i)
{ {
if ((data->dimensions[i] > 1) and PyInt_AsLong(PyTuple_GetItem(broadcastable, Py_ssize_t(i)))) if ((data->dimensions[i] > 1) && PyInt_AsLong(PyTuple_GetItem(broadcastable, Py_ssize_t(i))))
{ {
PyErr_Format(PyExc_TypeError, "Non-unit size in broadcastable dimension %i", i); PyErr_Format(PyExc_TypeError, "Non-unit size in broadcastable dimension %i", i);
Py_DECREF(data); Py_DECREF(data);
...@@ -2114,7 +2114,7 @@ CudaNdarray_is_c_contiguous(const CudaNdarray * self) ...@@ -2114,7 +2114,7 @@ CudaNdarray_is_c_contiguous(const CudaNdarray * self)
{ {
bool c_contiguous = true; bool c_contiguous = true;
int size = 1; int size = 1;
for (int i = self->nd-1; (i >= 0) and c_contiguous; --i) for (int i = self->nd-1; (i >= 0) && c_contiguous; --i)
{ {
if (CudaNdarray_HOST_DIMS(self)[i] == 1) if (CudaNdarray_HOST_DIMS(self)[i] == 1)
continue; continue;
...@@ -2763,7 +2763,7 @@ CudaNdarray_dimshuffle(CudaNdarray * self, unsigned int len, const int * pattern ...@@ -2763,7 +2763,7 @@ CudaNdarray_dimshuffle(CudaNdarray * self, unsigned int len, const int * pattern
} }
else else
{ {
if ((dims_taken[pattern[i]]) or (pattern[i]>= self->nd)) if ((dims_taken[pattern[i]]) || (pattern[i]>= self->nd))
{ {
PyErr_SetString(PyExc_ValueError, "invalid pattern for Cudandarray_dimshuffle"); PyErr_SetString(PyExc_ValueError, "invalid pattern for Cudandarray_dimshuffle");
free(newdims); free(newdims);
......
import sys, os, subprocess, logging import sys, os, subprocess, logging
from theano.gof.cmodule import (std_libs, std_lib_dirs, std_include_dirs, dlimport, from theano.gof.cmodule import (std_libs, std_lib_dirs, std_include_dirs, dlimport,
get_lib_extension, local_bitwidth) get_lib_extension, local_bitwidth)
from theano import config
import distutils import distutils
import commands import commands
...@@ -67,6 +66,7 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[ ...@@ -67,6 +66,7 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[
if preargs is None: if preargs is None:
preargs= [] preargs= []
else: preargs = list(preargs) else: preargs = list(preargs)
if sys.platform!='win32':
preargs.append('-fPIC') preargs.append('-fPIC')
no_opt = False no_opt = False
cuda_root = config.cuda.root cuda_root = config.cuda.root
...@@ -116,12 +116,18 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[ ...@@ -116,12 +116,18 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[
cmd = [nvcc_path, '-shared', '-g'] + preargs1 cmd = [nvcc_path, '-shared', '-g'] + preargs1
if config.nvcc.compiler_bindir: if config.nvcc.compiler_bindir:
cmd.extend(['--compiler-bindir', config.nvcc.compiler_bindir]) cmd.extend(['--compiler-bindir', config.nvcc.compiler_bindir])
if sys.platform!='win32':
if local_bitwidth() == 64: if local_bitwidth() == 64:
cmd.append('-m64') cmd.append('-m64')
cmd.extend(['-Xcompiler', ','.join(preargs2 +[ '-m64'])]) preargs2.append('-m64')
else: else:
cmd.append('-m32') cmd.append('-m32')
cmd.extend(['-Xcompiler', ','.join(preargs2 +[ '-m32'])]) preargs2.append('-m32')
if len(preargs2)>0:
cmd.extend(['-Xcompiler', ','.join(preargs2)])
if os.path.exists(os.path.join(config.cuda.root,'lib')): if os.path.exists(os.path.join(config.cuda.root,'lib')):
cmd.extend(['-Xlinker',','.join(['-rpath',os.path.join(config.cuda.root,'lib')])]) cmd.extend(['-Xlinker',','.join(['-rpath',os.path.join(config.cuda.root,'lib')])])
if sys.platform != 'darwin': if sys.platform != 'darwin':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论