Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2febf197
提交
2febf197
authored
6月 04, 2013
作者:
Frédéric Bastien
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1398 from HapeMask/py3k-fixes
Py3k Fixes Part 2
上级
5ee86171
fb445ae6
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
188 行增加
和
72 行删除
+188
-72
callcache.py
theano/gof/callcache.py
+2
-2
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+4
-6
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+145
-58
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+27
-0
nvcc_compiler.py
theano/sandbox/cuda/nvcc_compiler.py
+10
-6
没有找到文件。
theano/gof/callcache.py
浏览文件 @
2febf197
...
...
@@ -8,7 +8,7 @@ class CallCache(object):
try
:
if
filename
is
None
:
raise
IOError
(
'bad filename'
)
#just goes to except
f
=
file
(
filename
,
'r'
)
f
=
open
(
filename
,
'r'
)
self
.
cache
=
cPickle
.
load
(
f
)
f
.
close
()
except
IOError
:
...
...
@@ -20,7 +20,7 @@ class CallCache(object):
#backport
#filename = self.filename if filename is None else filename
f
=
file
(
filename
,
'w'
)
f
=
open
(
filename
,
'w'
)
cPickle
.
dump
(
self
.
cache
,
f
)
f
.
close
()
...
...
theano/sandbox/cuda/basic_ops.py
浏览文件 @
2febf197
...
...
@@ -635,10 +635,8 @@ class GpuCAReduce(GpuOp):
# but tensor.elemwise.CAReduce has this exact same check so I guess
# this is OK to do
if
self
.
scalar_op
in
[
scal
.
minimum
,
scal
.
maximum
]:
conds
=
[]
for
i
in
xrange
(
nd_in
):
if
self
.
reduce_mask
[
i
]:
conds
.
append
(
"(CudaNdarray_HOST_DIMS(
%(x)
s)[
%(i)
s] == 0)"
%
locals
())
conds
=
[
"(CudaNdarray_HOST_DIMS(
%
s)[
%
d] == 0)"
%
(
x
,
i
)
for
i
in
xrange
(
nd_in
)
\
if
self
.
reduce_mask
[
i
]]
assert
len
(
conds
)
>
0
cond
=
"("
+
" || "
.
join
(
conds
)
+
")"
print
>>
sio
,
"""
...
...
@@ -663,7 +661,7 @@ class GpuCAReduce(GpuOp):
j
=
0
for
i
in
xrange
(
nd_in
):
if
not
self
.
reduce_mask
[
i
]:
print
>>
sio
,
" || (CudaNdarray_HOST_DIMS(
%(z)
s)[
%(j)
s] !=CudaNdarray_HOST_DIMS(
%
(x)
s)[
%(i)
s]) "
%
locals
(
)
print
>>
sio
,
" || (CudaNdarray_HOST_DIMS(
%(z)
s)[
%(j)
s] !=CudaNdarray_HOST_DIMS(
%
s)[
%
d]) "
%
(
x
,
i
)
j
+=
1
print
>>
sio
,
"""
...
...
@@ -791,7 +789,7 @@ class GpuCAReduce(GpuOp):
"""
%
locals
()
shapes_format
=
"shape=(
%
s)"
%
","
.
join
([
"
%
d"
]
*
node
.
inputs
[
0
]
.
ndim
)
shapes_data
=
","
.
join
([
"CudaNdarray_HOST_DIMS(
%
(x)
s)[
%(i)
s]"
%
locals
(
)
shapes_data
=
","
.
join
([
"CudaNdarray_HOST_DIMS(
%
s)[
%
d]"
%
(
x
,
i
)
for
i
in
range
(
node
.
inputs
[
0
]
.
ndim
)])
print
>>
sio
,
"""
);
...
...
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
2febf197
...
...
@@ -294,10 +294,10 @@ static void
CudaNdarray_dealloc
(
CudaNdarray
*
self
)
{
if
(
0
)
std
::
cerr
<<
"CudaNdarray dealloc "
<<
self
<<
" "
<<
self
->
devdata
<<
'\n'
;
if
(
self
->
ob_refcnt
>
1
)
if
(
Py_REFCNT
(
self
)
>
1
)
printf
(
"WARNING:CudaNdarray_dealloc called when there is still active reference to it.
\n
"
);
CudaNdarray_uninit
(
self
);
self
->
ob_type
->
tp_free
((
PyObject
*
)
self
);
Py_TYPE
(
self
)
->
tp_free
((
PyObject
*
)
self
);
--
_outstanding_mallocs
[
1
];
if
(
0
)
{
...
...
@@ -461,9 +461,9 @@ PyObject* CudaNdarray_ZEROS(int n, int * dims)
return
(
PyObject
*
)
rval
;
}
// declared as a static method (hence
"dummy"
is not used)
// declared as a static method (hence
1st parameter
is not used)
// Based on _Copy and _dimshuffle
PyObject
*
CudaNdarray_Zeros
(
PyObject
*
dummy
,
PyObject
*
shape
)
PyObject
*
CudaNdarray_Zeros
(
PyObject
*
_unused
,
PyObject
*
shape
)
{
if
(
!
shape
)
{
...
...
@@ -1226,7 +1226,7 @@ static PyMethodDef CudaNdarray_methods[] =
(
PyCFunction
)
CudaNdarray_DeepCopy
,
METH_O
,
"Create a copy of this object"
},
{
"zeros"
,
(
PyCFunction
)
CudaNdarray_Zeros
,
METH_STATIC
,
(
PyCFunction
)
CudaNdarray_Zeros
,
METH_STATIC
|
METH_O
,
"Create a new CudaNdarray with specified shape, filled with zeros."
},
{
"copy"
,
(
PyCFunction
)
CudaNdarray_Copy
,
METH_NOARGS
,
...
...
@@ -1817,56 +1817,100 @@ CudaNdarray_inplace_div(PyObject* py_self, PyObject * py_other)
return
py_self
;
}
// The PyNumberMethods struct layout changed in a non-trivial way from 2 to 3.
#if PY_MAJOR_VERSION == 3
static
PyNumberMethods
CudaNdarrayNumberMethods
=
{
(
binaryfunc
)
CudaNdarray_add
,
//binaryfunc nb_add; __add__
0
,
//binaryfunc nb_subtract; __sub__
0
,
//binaryfunc nb_multiply; __mul__
0
,
//binaryfunc nb_divide; __div__
0
,
//binaryfunc nb_remainder; __mod__
0
,
//binaryfunc nb_divmod; __divmod__
0
,
//ternaryfunc nb_power; __pow__
0
,
//unaryfunc nb_negative; __neg__
0
,
//unaryfunc nb_positive; __pos__
0
,
//unaryfunc nb_absolute; __abs__
0
,
//inquiry nb_nonzero; __nonzero__ /* Used by PyObject_IsTrue */
0
,
//unaryfunc nb_invert; __invert__
0
,
//binaryfunc nb_lshift; __lshift__
0
,
//binaryfunc nb_rshift; __rshift__
0
,
//binaryfunc nb_and; __and__
0
,
//binaryfunc nb_xor; __xor__
0
,
//binaryfunc nb_or; __or__
0
,
//coercion nb_coerce; __coerce__ /* Used by the coerce() function */
0
,
//unaryfunc nb_int; __int__
0
,
//unaryfunc nb_long; __long__
0
,
//unaryfunc nb_float; __float__
0
,
//unaryfunc nb_oct; __oct__
0
,
//unaryfunc nb_hex; __hex__
/* Added in release 2.0 */
(
binaryfunc
)
CudaNdarray_inplace_add
,
//binaryfunc nb_inplace_add; __iadd__
0
,
//binaryfunc nb_inplace_subtract; __isub__
0
,
//binaryfunc nb_inplace_multiply; __imul__
(
binaryfunc
)
CudaNdarray_inplace_div
,
//binaryfunc nb_inplace_divide; __idiv__
0
,
//binaryfunc nb_inplace_remainder; __imod__
0
,
//ternaryfunc nb_inplace_power; __ipow__
0
,
//binaryfunc nb_inplace_lshift; __ilshift__
0
,
//binaryfunc nb_inplace_rshift; __irshift__
0
,
//binaryfunc nb_inplace_and; __iand__
0
,
//binaryfunc nb_inplace_xor; __ixor__
0
,
//binaryfunc nb_inplace_or; __ior__
/* Added in release 2.2 */
0
,
//binaryfunc nb_floor_divide; __floordiv__
0
,
//binaryfunc nb_true_divide; __truediv__
0
,
//binaryfunc nb_inplace_floor_divide; __ifloordiv__
0
,
//binaryfunc nb_inplace_true_divide; __itruediv__
(
binaryfunc
)
CudaNdarray_add
,
//binaryfunc nb_add; __add__
0
,
//binaryfunc nb_subtract;
0
,
//binaryfunc nb_multiply;
0
,
//binaryfunc nb_remainder;
0
,
//binaryfunc nb_divmod;
0
,
//ternaryfunc nb_power;
0
,
//unaryfunc nb_negative;
0
,
//unaryfunc nb_positive;
0
,
//unaryfunc nb_absolute;
0
,
//inquiry nb_bool;
0
,
//unaryfunc nb_invert;
0
,
//binaryfunc nb_lshift;
0
,
//binaryfunc nb_rshift;
0
,
//binaryfunc nb_and;
0
,
//binaryfunc nb_xor;
0
,
//binaryfunc nb_or;
0
,
//unaryfunc nb_int;
0
,
//void *nb_reserved;
0
,
//unaryfunc nb_float;
(
binaryfunc
)
CudaNdarray_inplace_add
,
//binaryfunc nb_inplace_add; __iadd__
0
,
//binaryfunc nb_inplace_subtract;
0
,
//binaryfunc nb_inplace_multiply;
0
,
//binaryfunc nb_inplace_remainder;
0
,
//ternaryfunc nb_inplace_power;
0
,
//binaryfunc nb_inplace_lshift;
0
,
//binaryfunc nb_inplace_rshift;
0
,
//binaryfunc nb_inplace_and;
0
,
//binaryfunc nb_inplace_xor;
0
,
//binaryfunc nb_inplace_or;
0
,
//binaryfunc nb_floor_divide;
0
,
//binaryfunc nb_true_divide;
0
,
//binaryfunc nb_inplace_floor_divide;
(
binaryfunc
)
CudaNdarray_inplace_div
,
//binaryfunc nb_inplace_true_divide; __idiv__
0
,
//unaryfunc nb_index
};
#else
static
PyNumberMethods
CudaNdarrayNumberMethods
=
{
(
binaryfunc
)
CudaNdarray_add
,
//binaryfunc nb_add; __add__
0
,
//binaryfunc nb_subtract; __sub__
0
,
//binaryfunc nb_multiply; __mul__
0
,
//binaryfunc nb_divide; __div__
0
,
//binaryfunc nb_remainder; __mod__
0
,
//binaryfunc nb_divmod; __divmod__
0
,
//ternaryfunc nb_power; __pow__
0
,
//unaryfunc nb_negative; __neg__
0
,
//unaryfunc nb_positive; __pos__
0
,
//unaryfunc nb_absolute; __abs__
0
,
//inquiry nb_nonzero; __nonzero__ /* Used by PyObject_IsTrue */
0
,
//unaryfunc nb_invert; __invert__
0
,
//binaryfunc nb_lshift; __lshift__
0
,
//binaryfunc nb_rshift; __rshift__
0
,
//binaryfunc nb_and; __and__
0
,
//binaryfunc nb_xor; __xor__
0
,
//binaryfunc nb_or; __or__
0
,
//coercion nb_coerce; __coerce__ /* Used by the coerce() function */
0
,
//unaryfunc nb_int; __int__
0
,
//unaryfunc nb_long; __long__
0
,
//unaryfunc nb_float; __float__
0
,
//unaryfunc nb_oct; __oct__
0
,
//unaryfunc nb_hex; __hex__
/* Added in release 2.0 */
(
binaryfunc
)
CudaNdarray_inplace_add
,
//binaryfunc nb_inplace_add; __iadd__
0
,
//binaryfunc nb_inplace_subtract; __isub__
0
,
//binaryfunc nb_inplace_multiply; __imul__
(
binaryfunc
)
CudaNdarray_inplace_div
,
//binaryfunc nb_inplace_divide; __idiv__
0
,
//binaryfunc nb_inplace_remainder; __imod__
0
,
//ternaryfunc nb_inplace_power; __ipow__
0
,
//binaryfunc nb_inplace_lshift; __ilshift__
0
,
//binaryfunc nb_inplace_rshift; __irshift__
0
,
//binaryfunc nb_inplace_and; __iand__
0
,
//binaryfunc nb_inplace_xor; __ixor__
0
,
//binaryfunc nb_inplace_or; __ior__
/* Added in release 2.2 */
0
,
//binaryfunc nb_floor_divide; __floordiv__
0
,
//binaryfunc nb_true_divide; __truediv__
0
,
//binaryfunc nb_inplace_floor_divide; __ifloordiv__
0
,
//binaryfunc nb_inplace_true_divide; __itruediv__
#if PY_MINOR_VERSION > 4
/* Added in release 2.5 */
0
//unaryfunc nb_index; __index__
/* Added in release 2.5 */
0
//unaryfunc nb_index; __index__
#endif
};
#endif
/////////////////////
...
...
@@ -1970,7 +2014,7 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
int
d_dim
=
CudaNdarray_HOST_DIMS
(
self
)[
0
];
Py_ssize_t
start
,
stop
,
step
,
slen
;
if
(
PySlice_GetIndicesEx
(
(
PySliceObject
*
)
key
,
d_dim
,
&
start
,
&
stop
,
&
step
,
&
slen
))
if
(
PySlice_GetIndicesEx
(
SLICE_CAST
(
key
)
,
d_dim
,
&
start
,
&
stop
,
&
step
,
&
slen
))
{
if
(
verbose
)
fprintf
(
stderr
,
"PySlice_GetIndicesEx failed
\n
"
);
...
...
@@ -2067,7 +2111,7 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
if
(
PySlice_Check
(
key_d
))
{
Py_ssize_t
start
,
stop
,
step
,
slen
;
if
(
PySlice_GetIndicesEx
(
(
PySliceObject
*
)
key_d
,
CudaNdarray_HOST_DIMS
(
self
)[
d
],
&
start
,
&
stop
,
&
step
,
&
slen
))
if
(
PySlice_GetIndicesEx
(
SLICE_CAST
(
key_d
)
,
CudaNdarray_HOST_DIMS
(
self
)[
d
],
&
start
,
&
stop
,
&
step
,
&
slen
))
{
Py_DECREF
(
rval
);
return
NULL
;
...
...
@@ -2592,12 +2636,14 @@ static PyGetSetDef CudaNdarray_getset[] = {
{
NULL
,
NULL
,
NULL
,
NULL
}
/* Sentinel */
};
static
PyTypeObject
CudaNdarrayType
=
{
#if PY_MAJOR_VERSION >= 3
PyVarObject_HEAD_INIT
(
NULL
,
0
)
#else
PyObject_HEAD_INIT
(
NULL
)
0
,
/*ob_size*/
#endif
"CudaNdarray"
,
/*tp_name*/
sizeof
(
CudaNdarray
),
/*tp_basicsize*/
0
,
/*tp_itemsize*/
...
...
@@ -2616,7 +2662,12 @@ static PyTypeObject CudaNdarrayType =
0
,
/*tp_getattro*/
0
,
/*tp_setattro*/
0
,
/*tp_as_buffer*/
#if PY_MAJOR_VERSION >= 3
// Py_TPFLAGS_CHECKTYPES is always true and was removed in Python 3.
Py_TPFLAGS_DEFAULT
|
Py_TPFLAGS_BASETYPE
,
/*tp_flags*/
#else
Py_TPFLAGS_DEFAULT
|
Py_TPFLAGS_BASETYPE
|
Py_TPFLAGS_CHECKTYPES
,
/*tp_flags*/
#endif
"CudaNdarray objects"
,
/* tp_doc */
0
,
/* tp_traverse */
0
,
/* tp_clear */
...
...
@@ -3049,21 +3100,53 @@ static PyMethodDef module_methods[] = {
#ifndef PyMODINIT_FUNC
/* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
#define CNDA_MOD_NAME "cuda_ndarray"
#define CNDA_DOCSTRING "CUDA implementation of a numpy ndarray-like object."
#if PY_MAJOR_VERSION == 3
static
struct
PyModuleDef
cuda_ndarray_moduledef
=
{
PyModuleDef_HEAD_INIT
,
CNDA_MOD_NAME
,
CNDA_DOCSTRING
,
-
1
,
/* size of per-interpreter state of the module,
or -1 if the module keeps state in global variables. */
module_methods
};
PyMODINIT_FUNC
PyInit_cuda_ndarray
(
void
)
#else
PyMODINIT_FUNC
initcuda_ndarray
(
void
)
#endif
{
import_array
();
PyObject
*
m
;
if
(
PyType_Ready
(
&
CudaNdarrayType
)
<
0
)
if
(
PyType_Ready
(
&
CudaNdarrayType
)
<
0
)
{
#if PY_MAJOR_VERSION == 3
return
NULL
;
#else
return
;
#endif
}
m
=
Py_InitModule3
(
"cuda_ndarray"
,
module_methods
,
"Example module that creates an extension type."
);
#if PY_MAJOR_VERSION == 3
m
=
PyModule_Create
(
&
cuda_ndarray_moduledef
);
#else
m
=
Py_InitModule3
(
CNDA_MOD_NAME
,
module_methods
,
CNDA_DOCSTRING
);
#endif
if
(
m
==
NULL
)
if
(
m
==
NULL
)
{
#if PY_MAJOR_VERSION == 3
return
NULL
;
#else
return
;
#endif
}
Py_INCREF
(
&
CudaNdarrayType
);
PyModule_AddObject
(
m
,
"CudaNdarray"
,
(
PyObject
*
)
&
CudaNdarrayType
);
...
...
@@ -3088,6 +3171,10 @@ initcuda_ndarray(void)
std
::
cerr
<<
"Error in SetDevice:"
<<
cudaGetErrorString
(
err
)
<<
"
\n
"
;
}
}
#if PY_MAJOR_VERSION == 3
return
m
;
#endif
}
...
...
@@ -3106,7 +3193,7 @@ CudaNdarray_Check(const PyObject * ob)
int
CudaNdarray_CheckExact
(
const
PyObject
*
ob
)
{
return
((
ob
->
ob_type
==
&
CudaNdarrayType
)
?
1
:
0
);
return
((
Py_TYPE
(
ob
)
==
&
CudaNdarrayType
)
?
1
:
0
);
}
PyObject
*
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
2febf197
#ifndef _CUDA_NDARRAY_H
#define _CUDA_NDARRAY_H
// Defines for Python 2/3 compatibility.
#if PY_MAJOR_VERSION == 3
// Py3k treats all ints as longs.
#define PyInt_Check PyLong_Check
#define PyInt_CheckExact PyLong_CheckExact
#define PyInt_AsLong PyLong_AsLong
#define PyInt_FromLong PyLong_FromLong
#define PyNumber_Int PyNumber_Long
// Py3k strings are unicode, these mimic old functionality.
#define PyString_Check PyUnicode_Check
#define PyString_FromString PyUnicode_FromString
#define PyString_AsString PyUnicode_AsUTF8
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize
#define PyString_Size PyUnicode_GET_SIZE
#define PyCObject_AsVoidPtr NpyCapsule_AsVoidPtr
#define PyCObject_GetDesc NpyCapsule_GetDesc
#define PyCObject_Check NpyCapsule_Check
// Python 3 expects a PyObject* as the first argument to PySlice_GetIndicesEx().
#define SLICE_CAST(x) (x)
#else
// Python 2 expects a PySliceObject* as the first argument to PySlice_GetIndicesEx().
#define SLICE_CAST(x) ((PySliceObject*)(x))
#endif
#include <numpy/arrayobject.h>
#include <stdio.h>
...
...
theano/sandbox/cuda/nvcc_compiler.py
浏览文件 @
2febf197
...
...
@@ -9,6 +9,7 @@ import warnings
import
numpy
from
theano.compat
import
decode
,
decode_iter
from
theano.gof
import
local_bitwidth
from
theano.gof.cc
import
hash_from_file
from
theano.gof.cmodule
import
(
std_libs
,
std_lib_dirs
,
...
...
@@ -69,10 +70,13 @@ def is_nvcc_available():
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
p
.
wait
()
s
=
p
.
stdout
.
readlines
()[
-
1
]
.
split
(
','
)[
1
]
.
strip
()
.
split
()
assert
s
[
0
]
==
'release'
ver_line
=
decode
(
p
.
stdout
.
readlines
()[
-
1
])
build
,
version
=
ver_line
.
split
(
','
)[
1
]
.
strip
()
.
split
()
assert
build
==
'release'
global
nvcc_version
nvcc_version
=
s
[
1
]
nvcc_version
=
version
try
:
set_version
()
return
True
...
...
@@ -247,7 +251,7 @@ class NVCC_compiler(object):
lib_dirs
.
append
(
python_lib
)
cppfilename
=
os
.
path
.
join
(
location
,
'mod.cu'
)
cppfile
=
file
(
cppfilename
,
'w'
)
cppfile
=
open
(
cppfilename
,
'w'
)
_logger
.
debug
(
'Writing module C++ code to
%
s'
,
cppfilename
)
...
...
@@ -354,7 +358,7 @@ class NVCC_compiler(object):
os
.
chdir
(
location
)
p
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
nvcc_stdout
,
nvcc_stderr
=
p
.
communicate
()[:
2
]
nvcc_stdout
,
nvcc_stderr
=
decode_iter
(
p
.
communicate
()[:
2
])
finally
:
os
.
chdir
(
orig_dir
)
...
...
@@ -401,7 +405,7 @@ class NVCC_compiler(object):
if
py_module
:
#touch the __init__ file
file
(
os
.
path
.
join
(
location
,
"__init__.py"
),
'w'
)
.
close
()
open
(
os
.
path
.
join
(
location
,
"__init__.py"
),
'w'
)
.
close
()
return
dlimport
(
lib_filename
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论