Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
56d152c8
提交
56d152c8
authored
7月 21, 2009
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test_elemwise0 passed
上级
5d16a644
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
373 行增加
和
281 行删除
+373
-281
__init__.py
__init__.py
+6
-0
basic_ops.py
basic_ops.py
+263
-0
test_basic_ops.py
tests/test_basic_ops.py
+35
-17
type.py
type.py
+65
-259
type_support.cu
type_support.cu
+4
-5
没有找到文件。
__init__.py
浏览文件 @
56d152c8
from
.type
import
CudaNdarrayType
from
.var
import
(
CudaNdarrayVariable
,
CudaNdarrayConstant
,
CudaNdarraySharedVariable
,
shared_constructor
)
basic_ops.py
浏览文件 @
56d152c8
import
StringIO
import
numpy
from
theano
import
Op
,
Type
,
Apply
,
Variable
,
Constant
from
theano
import
tensor
,
scalar
from
.type
import
CudaNdarrayType
from
.type_support
import
filter
as
type_support_filter
def
as_cuda_ndarray_variable
(
x
):
if
hasattr
(
x
,
'_as_CudaNdarrayVariable'
):
return
x
.
_as_CudaNdarrayVariable
()
tensor_x
=
tensor
.
as_tensor_variable
(
x
)
return
GpuFromHost
()(
tensor_x
)
class
HostFromGpu
(
Op
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
):
if
not
isinstance
(
x
.
type
,
CudaNdarrayType
):
raise
TypeError
(
x
)
return
Apply
(
self
,
[
x
],
[
tensor
.
TensorType
(
dtype
=
x
.
dtype
,
broadcastable
=
x
.
broadcastable
)()])
def
perform
(
self
,
node
,
(
x
,),
(
z
,)):
z
[
0
]
=
numpy
.
asarray
(
x
)
def
grad
(
self
,
inputs
,
(
gz
,)):
return
[
GpuFromHost
()(
gz
)]
class
GpuFromHost
(
Op
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
):
if
not
isinstance
(
x
.
type
,
tensor
.
TensorType
):
raise
TypeError
(
x
)
return
Apply
(
self
,
[
x
],
[
CudaNdarrayType
(
broadcastable
=
x
.
broadcastable
)()])
def
perform
(
self
,
node
,
(
x
,),
(
z
,)):
z
[
0
]
=
type_support_filter
(
numpy
.
asarray
(
x
,
dtype
=
'float32'
),
tuple
([
0
]
*
x
.
ndim
),
0
)
def
grad
(
self
,
inputs
,
(
gz
,)):
return
[
HostFromGpu
()(
gz
)]
class
GpuAdd
(
Op
):
def
__eq__
(
self
,
other
):
self
.
scalar_op
=
scalar
.
add
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
a
,
b
):
_a
=
as_cuda_ndarray_variable
(
a
)
_b
=
as_cuda_ndarray_variable
(
b
)
if
_a
.
type
.
broadcastable
!=
_b
.
type
.
broadcastable
:
raise
NotImplementedError
(
'different bcastable'
)
return
Apply
(
self
,
[
_a
,
_b
],
[
CudaNdarrayType
(
broadcastable
=
_a
.
broadcastable
)()])
def
perform
(
self
,
node
,
(
a
,
b
),
(
z
,)):
aval
=
numpy
.
asarray
(
a
,
dtype
=
'float32'
)
bval
=
numpy
.
asarray
(
b
,
dtype
=
'float32'
)
z
[
0
]
=
type_support_filter
(
aval
+
bval
,
(
0
,)
*
len
(
zval
.
shape
),
0
)
def
grad
(
self
,
inputs
,
(
gz
,)):
return
[
gz
for
i
in
inputs
]
def
c_support_code
(
self
):
return
"""
#define INTDIV_POW2(a, b) (a >> b)
#define INTMOD_POW2(a, b) (a & ((1<<b)-1))
"""
def
c_src_kernel
(
self
,
node
,
nodename
):
nd
=
node
.
outputs
[
0
]
.
type
.
ndim
sio
=
StringIO
.
StringIO
()
#TODO: optimize by passing the log2 of each dim, as well as the mask of 1s that we need
# to compute the modulo
print
>>
sio
,
"static __global__ void kernel_
%
s(unsigned int numEls,"
%
nodename
print
>>
sio
,
"
\t
"
,
", "
.
join
(
"unsigned int log2_dim
%
i"
%
i
for
i
in
xrange
(
nd
))
#declare inputs
for
ipos
,
i
in
enumerate
(
node
.
inputs
):
print
>>
sio
,
"
\t
,"
,
", "
.
join
(
"int i
%
i_str_
%
i"
%
(
ipos
,
d
)
for
d
in
xrange
(
nd
))
print
>>
sio
,
"
\t
,"
,
"const float * i
%
i_data"
%
ipos
#declare outputs
for
ipos
,
i
in
enumerate
(
node
.
outputs
):
print
>>
sio
,
"
\t
,"
,
", "
.
join
(
"int o
%
i_str_
%
i"
%
(
ipos
,
d
)
for
d
in
xrange
(
nd
))
print
>>
sio
,
"
\t
,"
,
"float * o
%
i_data"
%
ipos
print
>>
sio
,
"
\t
)
\n
{"
print
>>
sio
,
" const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;"
print
>>
sio
,
" const unsigned int numThreads = blockDim.x * gridDim.x;"
#TODO: insert code to check for strides of 1, and use a different loop
#loop over the elements to be treated by this kernel call
print
>>
sio
,
" for (unsigned int i = idx; i < numEls; i += numThreads) {"
# calculate the data pointers for all arguments
print
>>
sio
,
" unsigned int ii = i;"
for
ipos
,
i
in
enumerate
(
node
.
inputs
):
print
>>
sio
,
" const float * ii_i
%
i_data = i
%
i_data;"
%
(
ipos
,
ipos
)
for
ipos
,
i
in
enumerate
(
node
.
outputs
):
print
>>
sio
,
" float * ii_o
%
i_data = o
%
i_data;"
%
(
ipos
,
ipos
)
for
d
in
xrange
(
nd
-
1
,
-
1
,
-
1
):
if
d
>
0
:
print
>>
sio
,
" unsigned int pos
%
i = INTMOD_POW2(ii, log2_dim
%
i);"
%
(
d
,
d
)
print
>>
sio
,
" ii = INTDIV_POW2(ii, log2_dim
%
i);"
%
d
else
:
print
>>
sio
,
" unsigned int pos
%
i = ii;"
%
d
for
ipos
,
i
in
enumerate
(
node
.
inputs
):
print
>>
sio
,
" ii_i
%
i_data += pos
%
i * i
%
i_str_
%
i;"
%
(
ipos
,
d
,
ipos
,
d
)
for
ipos
,
i
in
enumerate
(
node
.
outputs
):
print
>>
sio
,
" ii_o
%
i_data += pos
%
i * o
%
i_str_
%
i;"
%
(
ipos
,
d
,
ipos
,
d
)
# perform the scalar operation on the input and output references
print
>>
sio
,
" "
,
self
.
scalar_op
.
c_code
(
None
,
None
,
[
'ii_i
%
i_data[0]'
%
ipos
for
ipos
,
i
in
enumerate
(
node
.
inputs
)],
[
'ii_o
%
i_data[0]'
%
ipos
for
ipos
,
i
in
enumerate
(
node
.
outputs
)],
sub
=
dict
(
fail
=
'return;'
))
#TODO: set a failure code somehow!!!
print
>>
sio
,
" }"
#TODO: insert runtime stride checks that select the best loop order either here, or in
# the host code that launched the kernel (host code probably better spot)
#indent = " "*(4*d+7)
#for ipos, i in enumerate(node.inputs):
#print >> sio, indent, "const float * i%i" % ipos, '= i%i_data', ''
print
>>
sio
,
"}"
print
sio
.
getvalue
()
return
sio
.
getvalue
()
def
c_support_code_apply
(
self
,
node
,
nodename
):
return
self
.
c_src_kernel
(
node
,
nodename
)
+
\
self
.
c_src_callkernel
(
node
,
nodename
)
def
c_src_callkernel
(
self
,
node
,
nodename
):
nd
=
node
.
outputs
[
0
]
.
type
.
ndim
d
=
dict
()
assert
nd
==
2
kernel_call_args
=
(
"numEls, log2_dims[0], log2_dims[1]"
", a_str[0], a_str[1], a_data"
", b_str[0], b_str[1], b_data"
", z_str[0], z_str[1], z_data"
)
d
.
update
(
locals
())
return
"""
static void callkernel_
%(nodename)
s(const unsigned int numEls, const int d,
const int * dims, int * log2_dims,
const float * a_data, const int * a_str,
const float * b_data, const int * b_str,
float * z_data, const int * z_str)
{
if (d ==
%(nd)
s)
{
int threads_per_block = std::min(numEls, (unsigned int)NUM_VECTOR_OP_THREADS_PER_BLOCK);
//a ceil would be better here
int n_blocks = std::min(numEls/threads_per_block + 1, (unsigned int)NUM_VECTOR_OP_BLOCKS);
kernel_
%(nodename)
s<<<n_blocks, threads_per_block>>>(
%(kernel_call_args)
s);
std::cerr << "ADDCALL a str" << a_str[0] << " "<< a_str[1] << "
\\
n";
std::cerr << "ADDCALL a data" << a_data << "
\\
n";
std::cerr << "ADDCALL b str" << b_str[0] << " "<< b_str[1] << "
\\
n";
std::cerr << "ADDCALL b data" << b_data << "
\\
n";
std::cerr << "ADDCALL z str" << z_str[0] << " "<< z_str[1] << "
\\
n";
std::cerr << "ADDCALL z data" << z_data << "
\\
n";
}
else
{
std::cerr << "_ADDCALL d " << d << "
\\
n";
unsigned int dim_d = dims[d];
std::cerr << "_ADDCALL dim_d " << dim_d << "
\\
n";
int log2_dim = 0;
while(dim_d)
{
std::cerr << "___ADDCALL d " << d << " " << dim_d << "
\\
n";
if (dim_d&1)
{
log2_dims[d] = log2_dim;
std::cerr << "___ADDCALL a str" << a_str[0] << " "<< a_str[1] << "
\\
n";
std::cerr << "___ADDCALL a data" << a_data << "
\\
n";
std::cerr << "___ADDCALL b str" << b_str[0] << " "<< b_str[1] << "
\\
n";
std::cerr << "___ADDCALL b data" << b_data << "
\\
n";
std::cerr << "___ADDCALL z str" << z_str[0] << " "<< z_str[1] << "
\\
n";
std::cerr << "___ADDCALL z data" << z_data << "
\\
n";
callkernel_
%(nodename)
s(numEls * (1<<log2_dim), d+1,
dims, log2_dims,
a_data, a_str,
b_data, b_str,
z_data, z_str);
a_data += (1 << log2_dim) * a_str[d];
b_data += (1 << log2_dim) * b_str[d];
z_data += (1 << log2_dim) * z_str[d];
}
log2_dim += 1;
dim_d >>= 1;
}
}
}
"""
%
d
def
c_code
(
self
,
node
,
nodename
,
(
a
,
b
),
(
z
,),
sub
):
d
=
dict
(
sub
)
nd
=
node
.
outputs
[
0
]
.
type
.
ndim
d
.
update
(
locals
())
return
"""
std::cerr << "ADD start
\\
n";
//standard elemwise size checks
if (cnda_
%(a)
s->nd != cnda_
%(b)
s->nd)
{
PyErr_SetString(PyExc_TypeError, "need same number of dims");
return NULL;
}
//standard elemwise dim checks
unsigned int size = 1;
for (int i = 0; i< cnda_
%(a)
s->nd; ++i)
{
if (cnda_
%(a)
s->dim[i] != cnda_
%(b)
s->dim[i])
{
PyErr_SetString(PyExc_TypeError, "need same dimensions");
return NULL;
}
size *= (unsigned int) cnda_
%(a)
s->dim[i];
}
std::cerr << "ADD size " << size << "
\\
n";
if (cnda_
%(z)
s){
//TODO: check if we can maybe use existing storage
Py_XDECREF(cnda_
%(z)
s);
cnda_
%(z)
s = NULL;
std::cerr << "ADD decref z
\\
n";
}
if (NULL == cnda_
%(z)
s)
{
cnda_
%(z)
s = (CudaNdarray*)CudaNdarray_new_null();
if (!cnda_
%(z)
s)
{
%(fail)
s;
}
if (CudaNdarray_alloc_contiguous(cnda_
%(z)
s, cnda_
%(a)
s->nd, cnda_
%(a)
s->dim))
{
Py_XDECREF(cnda_
%(z)
s);
cnda_
%(z)
s = NULL;
%(fail)
s;
}
}
std::cerr << "ADD z nd" << cnda_
%(z)
s->nd << "
\\
n";
std::cerr << "ADD z str" << cnda_
%(z)
s->str[0] << " "<< cnda_
%(z)
s->str[1] << "
\\
n";
std::cerr << "ADD z data" << cnda_
%(z)
s->devdata << "
\\
n";
{ //new block so that failure gotos don't skip over variable initialization
int log2_dims[
%(nd)
s];
callkernel_
%(nodename)
s(1, 0, CudaNdarray_DIMS(cnda_
%(z)
s), log2_dims,
CudaNdarray_DEV_DATA(cnda_
%(a)
s), CudaNdarray_STRIDES(cnda_
%(a)
s),
CudaNdarray_DEV_DATA(cnda_
%(b)
s), CudaNdarray_STRIDES(cnda_
%(b)
s),
CudaNdarray_DEV_DATA(cnda_
%(z)
s), CudaNdarray_STRIDES(cnda_
%(z)
s));
cudaThreadSynchronize();
cudaError_t err = cudaGetLastError();
if( cudaSuccess != err)
{
PyErr_Format(PyExc_RuntimeError, "Cuda error:
%%
s:
%%
s.
\\
n", "kExp", cudaGetErrorString(err));
Py_XDECREF(cnda_
%(z)
s);
cnda_
%(z)
s = NULL;
%(fail)
s;
}
}
"""
%
d
def
c_code_cache_version
(
self
):
return
()
tests/test_basic_ops.py
浏览文件 @
56d152c8
...
@@ -4,35 +4,53 @@ from theano import tensor
...
@@ -4,35 +4,53 @@ from theano import tensor
import
numpy
import
numpy
import
gputensor
as
gpt
import
theano_cuda_ndarray
as
tcn
def
test0
():
def
test
_elemwise
0
():
a
=
gpt
.
gpu_tensor_shared_constructor
(
numpy
.
random
.
rand
(
3
,
4
),
'a'
)
a
=
tcn
.
shared_constructor
(
numpy
.
random
.
rand
(
4
,
4
),
'a'
)
b
=
tensor
.
dmatrix
()
b
=
tensor
.
dmatrix
()
f
=
pfunc
([
b
],
[],
updates
=
[(
a
,
a
+
b
)])
f
=
pfunc
([
b
],
[],
updates
=
[(
a
,
a
+
b
)])
a0
=
a
.
value
*
1.0
a0
=
a
.
value
*
1.0
f
(
numpy
.
ones
((
3
,
4
)))
print
'BEFORE ADD'
,
a
.
value
f
(
numpy
.
ones
((
4
,
4
)))
print
f
.
maker
.
env
.
toposort
()
print
f
.
maker
.
env
.
toposort
()
print
'AFTER ADD'
,
a
.
value
assert
numpy
.
all
(
a0
+
1.0
==
a
.
value
)
assert
numpy
.
all
(
a0
+
1.0
==
a
.
value
)
def
test1
():
def
test_elemwise1
():
""" Several kinds of elemwise expressions with no broadcasting, non power-of-two shape """
a
=
gpt
.
gpu_tensor_shared_constructor
(
numpy
.
random
.
rand
(
3
,
4
),
'a'
)
shape
=
(
3
,
4
)
a
=
tcn
.
shared_constructor
(
numpy
.
random
.
rand
(
*
shape
),
'a'
)
b
=
tensor
.
dmatrix
()
b
=
tensor
.
dmatrix
()
f
=
pfunc
([
b
],
[],
updates
=
[(
a
,
a
+
b
*
tensor
.
exp
(
b
**
a
))])
f
=
pfunc
([
b
],
[],
updates
=
[(
a
,
a
+
b
)])
#let debugmode catch any mistakes
for
i
,
node
in
enumerate
(
f
.
maker
.
env
.
toposort
()):
f
(
numpy
.
ones
(
shape
))
print
'test1 toposort'
,
i
,
node
def
test_elemwise2
():
a0
=
a
.
value
*
1.0
""" Several kinds of elemwise expressions with dimension permutations """
f
(
numpy
.
ones
((
3
,
4
)))
shape
=
(
3
,
4
,
5
,
6
)
assert
numpy
.
all
(
a0
+
1.0
==
a
.
value
)
a
=
tcn
.
shared_constructor
(
numpy
.
random
.
rand
(
*
shape
),
'a'
)
b
=
tensor
.
Tensor
(
dtype
=
'float32'
,
broadcastable
=
[
0
]
*
len
(
shape
))()
f
=
pfunc
([
b
],
[],
updates
=
[(
a
,
(
a
+
b
)
.
dimshuffle
([
2
,
0
,
3
,
1
])
*
tensor
.
exp
(
b
**
a
)
.
dimshuffle
([
2
,
0
,
3
,
1
]))])
#let debugmode catch errors
f
(
numpy
.
ones
(
shape
))
def
test_elemwise3
():
""" Several kinds of elemwise expressions with dimension permutations and broadcasting"""
shape
=
(
3
,
4
,
5
,
6
)
a
=
tcn
.
shared_constructor
(
numpy
.
random
.
rand
(
*
shape
),
'a'
)
b
=
tensor
.
dvector
()
f
=
pfunc
([
b
],
[],
updates
=
[(
a
,
(
a
+
b
)
.
dimshuffle
([
2
,
0
,
3
,
1
])
*
tensor
.
exp
(
1
+
b
**
a
)
.
dimshuffle
([
2
,
0
,
3
,
1
]))])
#let debugmode catch errors
f
(
numpy
.
ones
(
6
))
type.py
浏览文件 @
56d152c8
import
sys
import
sys
,
os
import
numpy
import
numpy
from
theano
import
Op
,
Type
,
Apply
,
Variable
,
Constant
from
theano
import
Op
,
Type
,
Apply
,
Variable
,
Constant
from
theano
import
tensor
from
theano
import
tensor
from
theano.compile.sandbox.sharedvalue
import
shared
,
SharedVariable
,
shared_constructor
import
cuda_ndarray
# the module
import
cuda_ndarray
class
_tensor_operators
(
object
):
from
.type_support
import
filter
as
type_support_filter
def
_as_TensorVariable
(
self
):
return
HostFromGpu
()(
self
)
def
_as_CudaNdarrayVariable
(
self
):
return
self
dtype
=
property
(
lambda
s
:
s
.
type
.
dtype
)
from
.nvcc_compiler
import
nvcc_module_compile_str
broadcastable
=
property
(
lambda
s
:
s
.
type
.
broadcastable
)
ndim
=
property
(
lambda
s
:
s
.
type
.
ndim
)
class
CudaNdarrayType
(
Type
):
class
CudaNdarrayType
(
Type
):
def
__init__
(
self
,
dtype
,
broadcastable
,
name
=
None
):
typenum
=
11
# Until hardware improves, this class deals with floats.
self
.
typenum
=
numpy
.
dtype
(
dtype
)
.
num
self
.
dtype
=
str
(
dtype
)
dtype
=
'float32'
Variable
=
None
""" This will be set to the Variable type corresponding to this class.
That variable type is `CudaNdarrayVariable` defined in the ``var.py`` file beside this one.
:note:
The var file depends on the file basic_ops.py, which depends on this file.
A cyclic dependency is avoided by not hardcoding ``Variable = CudaNdarrayVariable``.
"""
Constant
=
None
""" This will be set to `CudaNdarrayConstant` defined in ``var.py``
:note:
The var file depends on the file basic_ops.py, which depends on this file.
A cyclic dependency is avoided by not hardcoding this class.
"""
SharedVariable
=
None
""" This will be set to `CudaNdarraySharedVariable` defined in ``var.py``
:note:
The var file depends on the file basic_ops.py, which depends on this file.
A cyclic dependency is avoided by not hardcoding this class.
"""
def
__init__
(
self
,
broadcastable
,
name
=
None
):
self
.
broadcastable
=
tuple
(
broadcastable
)
self
.
broadcastable
=
tuple
(
broadcastable
)
self
.
name
=
name
self
.
name
=
name
self
.
dtype_specs
()
# error checking is done there
self
.
dtype_specs
()
# error checking is done there
def
filter
(
self
,
data
,
strict
=
False
):
def
filter
(
self
,
data
,
strict
=
False
):
typenum
=
numpy
.
dtype
(
self
.
dtype
)
.
num
return
type_support_filter
(
data
,
self
.
broadcastable
,
strict
)
print
>>
sys
.
stderr
,
"bcastable"
,
self
.
broadcastable
return
tensorview_module
.
filter
(
data
,
typenum
,
self
.
broadcastable
,
strict
)
def
dtype_specs
(
self
):
def
dtype_specs
(
self
):
"""Return a tuple (python type, c type, numpy typenum) that corresponds to
"""Return a tuple (python type, c type, numpy typenum) that corresponds to
...
@@ -57,59 +76,11 @@ class CudaNdarrayType(Type):
...
@@ -57,59 +76,11 @@ class CudaNdarrayType(Type):
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
"""Compare True iff other is the same kind of CudaNdarrayType"""
"""Compare True iff other is the same kind of CudaNdarrayType"""
return
type
(
self
)
==
type
(
other
)
and
other
.
typenum
==
self
.
typenum
and
other
.
broadcastable
==
self
.
broadcastable
return
type
(
self
)
==
type
(
other
)
and
other
.
broadcastable
==
self
.
broadcastable
def
values_eq_approx
(
self
,
a
,
b
):
if
type
(
a
)
is
numpy
.
ndarray
and
type
(
b
)
is
numpy
.
ndarray
:
if
a
.
shape
!=
b
.
shape
:
return
False
if
a
.
dtype
!=
b
.
dtype
:
return
False
if
'int'
in
str
(
a
.
dtype
):
return
numpy
.
all
(
a
==
b
)
elif
a
.
shape
==
():
#for comparing scalars, use broadcasting.
# Note: according to James B, there was a reason for the
# following two lines, that may seem weird at first glance.
# If someone can figure out what it is, please say it here!
ones
=
numpy
.
ones
(
2
)
return
numpy
.
allclose
(
ones
*
a
,
ones
*
b
)
#elif str(a.dtype).startswith('complex'):
# print >> sys.stderr, 'WARNING: skipping comparison of complex'
# return True
else
:
cmp
=
numpy
.
allclose
(
a
,
b
)
if
cmp
:
# Numpy claims they are close, this is good enough for us.
return
True
# Numpy is unhappy, but it does not necessarily mean that a and
# b are different. Indeed, Numpy does not like missing values
# and will return False whenever some are found in a or b.
# The proper way would be to use the MaskArray stuff available
# in Numpy. However, it looks like it has been added to Numpy's
# core recently, so it may not be available to everyone. Thus,
# for now we use a home-made recipe, that should probably be
# revisited in the future.
a_missing
=
numpy
.
isnan
(
a
)
if
not
a_missing
.
any
():
# There are no missing values in a, thus this is not the
# reason why numpy.allclose(a, b) returned False.
return
False
# The following line is what numpy.allclose bases its decision
# upon, according to its documentation.
rtol
=
1.0000000000000001e-05
atol
=
1e-8
cmp_elemwise
=
(
numpy
.
absolute
(
a
-
b
)
<=
(
atol
+
rtol
*
numpy
.
absolute
(
b
)))
# Find places where both a and b have missing values.
both_missing
=
a_missing
*
numpy
.
isnan
(
b
)
# Combine all information.
return
(
cmp_elemwise
+
both_missing
)
.
all
()
return
False
def
__hash__
(
self
):
def
__hash__
(
self
):
"""Hash equal for same kinds of CudaNdarrayType"""
"""Hash equal for same kinds of CudaNdarrayType"""
return
hash
(
type
(
self
))
^
hash
(
self
.
typenum
)
^
hash
(
self
.
broadcastable
)
return
hash
(
type
(
self
))
^
hash
(
self
.
broadcastable
)
ndim
=
property
(
lambda
self
:
len
(
self
.
broadcastable
),
doc
=
"number of dimensions"
)
ndim
=
property
(
lambda
self
:
len
(
self
.
broadcastable
),
doc
=
"number of dimensions"
)
"""Number of dimensions
"""Number of dimensions
...
@@ -127,7 +98,7 @@ class CudaNdarrayType(Type):
...
@@ -127,7 +98,7 @@ class CudaNdarrayType(Type):
A pretty name to identify this `Variable` when printing and debugging
A pretty name to identify this `Variable` when printing and debugging
"""
"""
return
CudaNdarray
Variable
(
self
,
name
=
name
)
return
self
.
Variable
(
self
,
name
=
name
)
def
__str__
(
self
):
def
__str__
(
self
):
if
self
.
name
:
if
self
.
name
:
...
@@ -149,18 +120,21 @@ class CudaNdarrayType(Type):
...
@@ -149,18 +120,21 @@ class CudaNdarrayType(Type):
def
c_declare
(
self
,
name
,
sub
):
def
c_declare
(
self
,
name
,
sub
):
ndim
=
self
.
ndim
ndim
=
self
.
ndim
c_typename
=
self
.
dtype_specs
()[
1
]
c_typename
=
self
.
dtype_specs
()[
1
]
return
""" CudaNdarray
Type::VoidTensor* vt
_
%(name)
s;"""
%
locals
()
return
""" CudaNdarray
* cnda
_
%(name)
s;"""
%
locals
()
def
c_init
(
self
,
name
,
sub
):
def
c_init
(
self
,
name
,
sub
):
return
"
vt
_
%(name)
s = NULL;"
%
locals
()
return
"
cnda
_
%(name)
s = NULL;"
%
locals
()
def
c_extract
(
self
,
name
,
sub
):
def
c_extract
(
self
,
name
,
sub
):
return
"""
return
"""
vt_
%(name)
s = CudaNdarrayType::voidtensor_from_cobject(py_
%(name)
s);
if (CudaNdarray_Check(py_
%(name)
s))
std::cerr << "extract "<< py_
%(name)
s << " " << vt_
%(name)
s << "
\\
n";
{
if (!vt_
%(name)
s)
cnda_
%(name)
s = (CudaNdarray*)py_
%(name)
s;
}
else
{
{
PyErr_SetString(PyExc_TypeError, "Failed to extract VoidTensor");
PyErr_SetString(PyExc_TypeError, "Argument not a CudaNdarray");
cnda_
%(name)
s = NULL;
%(fail)
s;
%(fail)
s;
}
}
"""
%
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
])
"""
%
dict
(
sub
,
name
=
name
,
type_num
=
self
.
dtype_specs
()[
2
])
...
@@ -174,214 +148,46 @@ class CudaNdarrayType(Type):
...
@@ -174,214 +148,46 @@ class CudaNdarrayType(Type):
"""Override `CLinkerOp.c_sync` """
"""Override `CLinkerOp.c_sync` """
return
"""
return
"""
std::cerr << "sync
\\
n";
std::cerr << "sync
\\
n";
if (
!vt
_
%(name)
s) {
if (
NULL == cnda
_
%(name)
s) {
// failure: sync None to storage
// failure: sync None to storage
Py_XDECREF(py_
%(name)
s);
Py_XDECREF(py_
%(name)
s);
py_
%(name)
s = Py_None;
py_
%(name)
s = Py_None;
Py_XINCREF(py_
%(name)
s);
Py_XINCREF(py_
%(name)
s);
}
}
else if (PyCObject_AsVoidPtr(py_
%(name)
s) != (void*)vt_
%(name)
s) {
else
// success, but a new gtt was allocated for us
{
// we trust that the op code deleted the old gtt
py_
%(name)
s = (PyObject*)cnda_
%(name)
s;
// we just pack the new gtt into a CObject
Py_XDECREF(py_
%(name)
s);
py_
%(name)
s = CudaNdarrayType::cobject_from_voidtensor(vt_
%(name)
s);
std::cerr << "sync packing " << vt_
%(name)
s << " into new CObject "<< py_
%(name)
s << " "<< PyCObject_Check(py_
%(name)
s) << "
\\
n";
}
}
"""
%
locals
()
"""
%
locals
()
def
c_headers
(
self
):
def
c_headers
(
self
):
"""Override `CLinkerOp.c_headers` """
"""Override `CLinkerOp.c_headers` """
return
[]
return
[
'cuda_ndarray.cuh'
]
def
c_header_dirs
(
self
):
"""Override `CLinkerOp.c_headers` """
return
[
os
.
path
.
dirname
(
cuda_ndarray
.
__file__
),
os
.
path
.
join
(
os
.
getenv
(
"CUDA_ROOT"
),
'include'
)]
def
c_lib_dirs
(
self
):
return
[
os
.
path
.
dirname
(
cuda_ndarray
.
__file__
),
os
.
path
.
join
(
os
.
getenv
(
"CUDA_ROOT"
),
'lib'
)]
def
c_libraries
(
self
):
def
c_libraries
(
self
):
return
[]
return
[
'cuda_ndarray'
,
'cudart'
]
def
c_support_code
(
cls
):
def
c_support_code
(
cls
):
rval
=
file
(
'tensorview.cc'
)
.
read
()
return
""
return
rval
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
()
#do not cache this stuff until it matures
return
()
#do not cache this stuff until it matures
class
CudaNdarrayVariable
(
Variable
,
_tensor_operators
):
pass
class
CudaNdarrayConstant
(
Constant
,
_tensor_operators
):
pass
class
CudaNdarraySharedVariable
(
SharedVariable
,
_tensor_operators
):
def
__getvalue
(
self
):
return
tensorview_module
.
ndarray_from_voidtensor
(
self
.
container
.
value
)
def
__setvalue
(
self
,
value
):
self
.
container
.
value
=
value
#container does the filtering
value
=
property
(
__getvalue
,
__setvalue
)
def
filter_update
(
self
,
other
):
if
hasattr
(
other
,
'_as_CudaNdarrayVariable'
):
return
other
.
_as_CudaNdarrayVariable
()
if
isinstance
(
other
.
type
,
tensor
.
TensorType
)
and
(
other
.
type
.
dtype
==
self
.
dtype
)
and
(
other
.
broadcastable
==
self
.
broadcastable
):
return
GpuFromHost
()(
other
)
else
:
raise
TypeError
(
other
)
def
gpu_tensor_shared_constructor
(
value
,
name
,
strict
=
False
):
"""SharedVariable Constructor for TensorType"""
if
not
isinstance
(
value
,
numpy
.
ndarray
):
raise
TypeError
bcast
=
[
0
for
b
in
value
.
shape
]
def
c_compiler
(
self
):
return
nvcc_module_compile_str
type
=
CudaNdarrayType
(
value
.
dtype
,
broadcastable
=
bcast
)
return
CudaNdarraySharedVariable
(
type
=
type
,
value
=
value
,
name
=
name
,
strict
=
strict
)
class
HostFromGpu
(
Op
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
):
if
not
isinstance
(
x
.
type
,
CudaNdarrayType
):
raise
TypeError
(
x
)
return
Apply
(
self
,
[
x
],
[
tensor
.
TensorType
(
dtype
=
x
.
dtype
,
broadcastable
=
x
.
broadcastable
)()])
def
perform
(
self
,
node
,
(
x
,),
(
z
,)):
z
[
0
]
=
tensorview_module
.
ndarray_from_voidtensor
(
x
)
def
grad
(
self
,
inputs
,
(
gz
,)):
return
[
GpuFromHost
()(
gz
)]
class
GpuFromHost
(
Op
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
x
):
if
not
isinstance
(
x
.
type
,
tensor
.
TensorType
):
raise
TypeError
(
x
)
return
Apply
(
self
,
[
x
],
[
CudaNdarrayType
(
dtype
=
x
.
dtype
,
broadcastable
=
x
.
broadcastable
)()])
def
perform
(
self
,
node
,
(
x
,),
(
z
,)):
z
[
0
]
=
tensorview_module
.
filter
(
x
,
x
.
dtype
.
num
,
tuple
([
0
]
*
x
.
ndim
),
0
)
def
grad
(
self
,
inputs
,
(
gz
,)):
return
[
HostFromGpu
()(
gz
)]
class
GpuAdd
(
Op
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
def
__hash__
(
self
):
return
hash
(
type
(
self
))
def
make_node
(
self
,
a
,
b
):
if
not
isinstance
(
a
.
type
,
CudaNdarrayType
):
raise
TypeError
(
a
)
if
not
isinstance
(
b
.
type
,
CudaNdarrayType
):
raise
TypeError
(
b
)
if
a
.
type
.
broadcastable
!=
b
.
type
.
broadcastable
:
raise
NotImplementedError
(
'different bcastable'
)
if
a
.
dtype
!=
b
.
dtype
:
raise
NotImplementedError
(
'different dtype'
)
return
Apply
(
self
,
[
a
,
b
],
[
CudaNdarrayType
(
dtype
=
a
.
dtype
,
broadcastable
=
a
.
broadcastable
)()])
def
perform
(
self
,
node
,
(
a
,
b
),
(
z
,)):
aval
=
tensorview_module
.
ndarray_from_voidtensor
(
a
)
bval
=
tensorview_module
.
ndarray_from_voidtensor
(
b
)
zval
=
aval
+
bval
z
[
0
]
=
tensorview_module
.
filter
(
zval
,
zval
.
dtype
.
num
,(
0
,)
*
len
(
zval
.
shape
),
0
)
def
grad
(
self
,
inputs
,
(
gz
,)):
return
[
gz
for
i
in
inputs
]
def
c_support_code
(
self
):
return
"""
template<typename T0, typename T1, typename T2>
void gpu_tensor_add(const int nd, const int * dim,
T0 * __restrict__ z, const int * zstr,
const T1 * __restrict__ a, const int * astr,
const T2 * __restrict__ b, const int * bstr)
{
if (0 == nd) //copy a scalar
{
z[0] = a[0] + b[0];
}
else
{
for (int i = 0; i< dim[0]; ++i)
{
gpu_tensor_add(nd-1, dim+1,
z + i * zstr[0], zstr+1,
a + i * astr[0], astr+1,
b + i * bstr[0], bstr+1);
}
}
}
"""
def
c_code
(
self
,
node
,
nodename
,
(
a
,
b
),
(
z
,),
sub
):
asym
,
bsym
=
node
.
inputs
zsym
,
=
node
.
outputs
nd_a
=
asym
.
ndim
nd_b
=
bsym
.
ndim
nd_z
=
zsym
.
ndim
typename_a
=
asym
.
type
.
dtype_specs
()[
1
]
typename_b
=
bsym
.
type
.
dtype_specs
()[
1
]
typename_z
=
zsym
.
type
.
dtype_specs
()[
1
]
return
"""
std::cerr << "GpuAdd start
\\
n";
if (vt_
%(z)
s) delete vt_
%(z)
s;
vt_
%(z)
s = new CudaNdarrayType::VoidTensor(vt_
%(a)
s->typenum, vt_
%(a)
s->elsize,
%(nd_a)
s, vt_
%(a)
s->dim);
CudaNdarrayType::TensorView<
%(nd_a)
s,
%(typename_a)
s> view_
%(a)
s(vt_
%(a)
s);
CudaNdarrayType::TensorView<
%(nd_b)
s,
%(typename_b)
s> view_
%(b)
s(vt_
%(b)
s);
CudaNdarrayType::TensorView<
%(nd_z)
s,
%(typename_z)
s> view_
%(z)
s(vt_
%(z)
s);
gpu_tensor_add(vt_
%(a)
s->nd, vt_
%(a)
s->dim,
view_
%(z)
s.data, view_
%(z)
s.str,
view_
%(a)
s.data, view_
%(a)
s.str,
view_
%(b)
s.data, view_
%(b)
s.str);
std::cerr << "GpuAdd done
\\
n";
"""
%
locals
()
def
c_code_cache_version
(
self
):
return
()
#compiler = theano.gof.cmodule.nvcc_module_compile_str
@tensor.gof.local_optimizer
([
GpuFromHost
(),
None
])
def
local_gpu_host_gpu
(
node
):
if
not
tensor
.
opt
.
opt
.
check_chain
(
node
,
GpuFromHost
(),
HostFromGpu
()):
return
False
return
[
node
.
inputs
[
0
]
.
owner
.
inputs
[
0
]]
tensor
.
opt
.
register_canonicalize
(
local_gpu_host_gpu
,
'gpu_host_gpu'
)
@tensor.gof.local_optimizer
([
HostFromGpu
(),
None
])
def
local_host_gpu_host
(
node
):
if
not
tensor
.
opt
.
opt
.
check_chain
(
node
,
HostFromGpu
(),
GpuFromHost
()):
return
False
return
[
node
.
inputs
[
0
]
.
owner
.
inputs
[
0
]]
tensor
.
opt
.
register_canonicalize
(
local_host_gpu_host
,
'host_gpu_host'
)
@tensor.gof.local_optimizer
([
GpuFromHost
(),
None
])
def
local_gpu_add
(
node
):
if
node
.
op
==
GpuFromHost
():
if
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
.
op
==
tensor
.
add
:
add_inputs
=
node
.
inputs
[
0
]
.
owner
.
inputs
if
any
(
hasattr
(
i
.
owner
,
'op'
)
and
isinstance
(
i
.
owner
.
op
,
HostFromGpu
)
for
i
in
add_inputs
):
# move the add to a GpuAdd
return
[
GpuAdd
()(
*
(
GpuFromHost
()(
i
)
for
i
in
add_inputs
))]
return
False
tensor
.
opt
.
register_canonicalize
(
local_gpu_add
,
'gpu_add'
)
def
unset_shared_for_numpy
():
raise
NotImplementedError
()
def
set_shared_for_numpy
():
"""
Set the gpu_tensor_constructor as the handler for ndarray
"""
shared_constructor
(
gpu_tensor_shared_constructor
)
type_support.cu
浏览文件 @
56d152c8
...
@@ -7,15 +7,14 @@
...
@@ -7,15 +7,14 @@
#define DECL(s) static PyObject * s(PyObject * self, PyObject *args)
#define DECL(s) static PyObject * s(PyObject * self, PyObject *args)
static PyObject *
static PyObject *
filter(PyObject* self, PyObject *args) // args = (data,
typenum,
broadcastable, strict)
filter(PyObject* self, PyObject *args) // args = (data, broadcastable, strict)
{
{
PyObject *py_data=NULL;
PyObject *py_data=NULL;
PyArrayObject * data = NULL;
PyArrayObject * data = NULL;
int dtype_typenum=-1;
int strict = 0;
int strict = 0;
PyObject * broadcastable=NULL;
PyObject * broadcastable=NULL;
if (!PyArg_ParseTuple(args, "O
iOi", &py_data, &dtype_typenum
, &broadcastable, &strict)) return NULL;
if (!PyArg_ParseTuple(args, "O
Oi", &py_data
, &broadcastable, &strict)) return NULL;
if (!PyTuple_Check(broadcastable)){
if (!PyTuple_Check(broadcastable)){
PyErr_SetString(PyExc_TypeError, "broadcastable arg should be a tuple of int.");
PyErr_SetString(PyExc_TypeError, "broadcastable arg should be a tuple of int.");
...
@@ -99,9 +98,9 @@ static PyMethodDef MyMethods[] = {
...
@@ -99,9 +98,9 @@ static PyMethodDef MyMethods[] = {
PyMODINIT_FUNC
PyMODINIT_FUNC
init
_theano_cuda_ndarray
(void)
init
type_support
(void)
{
{
(void) Py_InitModule("
_theano_cuda_ndarray
", MyMethods);
(void) Py_InitModule("
type_support
", MyMethods);
import_array();
import_array();
}
}
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论