testgroup / pytensor · Commits

Commit 31afc498
Authored Jul 23, 2009 by James Bergstra

    elemwise and dimshuffle working, basic profiling.

Parent: 0039353e

Showing 5 changed files with 116 additions and 28 deletions:
    basic_ops.py       +39 -18
    blas.py             +7  -0
    nvcc_compiler.py    +1  -0
    tests/walltime.py  +59  -0
    type.py            +10 -10

basic_ops.py
```diff
@@ -90,6 +90,7 @@ class GpuElemwise(Op):
     def __hash__(self):
         return self._hashval
 
     def __str__(self):
         if self.inplace_pattern:
             items = self.inplace_pattern.items()
```
```diff
@@ -97,6 +98,7 @@ class GpuElemwise(Op):
             return "GpuElemwise{%s}%s" % (self.scalar_op, str(items))
         else:
             return "GpuElemwise{%s}" % (self.scalar_op)
 
     def make_node(self, *inputs):
         _inputs = [as_cuda_ndarray_variable(i) for i in inputs]
         if self.nin > 0 and len(_inputs) != self.nin:
```
```diff
@@ -119,6 +121,7 @@ class GpuElemwise(Op):
         otype = CudaNdarrayType(broadcastable=broadcastable)
         assert self.nout > 0
         return Apply(self, _inputs, [otype() for o in xrange(self.nout)])
 
     def c_support_code(self):
         return """
        #define INTDIV_POW2(a, b) (a >> b)
```
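The `INTDIV_POW2` macro in the support code turns integer division by a power of two into a single right shift. A quick standalone check of the identity it relies on (this snippet is illustrative, not part of the commit):

```python
# For non-negative integers, (a >> b) == a // 2**b, which is what
# INTDIV_POW2(a, b) exploits to avoid an integer division on the GPU.
for a in (0, 1, 7, 8, 1000):
    for b in (0, 1, 3):
        assert a >> b == a // (2 ** b)
```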
```diff
@@ -128,8 +131,10 @@ class GpuElemwise(Op):
     def c_src_kernel(self, node, nodename):
         nd = node.outputs[0].type.ndim
         sio = StringIO.StringIO()
         #TODO: optimize by passing the log2 of each dim, as well as the mask of 1s that we need
         #      to compute the modulo
         print 'C_SRC_KERNEL', sio.getvalue()
 
         def _logical_scalar(x):
             return all(x.type.broadcastable)
 
         print >> sio, "static __global__ void kernel_%s(unsigned int numEls," % nodename
         print >> sio, "\t", ", ".join("unsigned int log2_dim%i" % i for i in xrange(nd))
```
```diff
@@ -144,6 +149,14 @@ class GpuElemwise(Op):
         print >> sio, "\t)\n{"
         print >> sio, "    const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;"
         print >> sio, "    const unsigned int numThreads = blockDim.x * gridDim.x;"
 
+        # For each input that is a scalar which has been broadcasted to a tensor,
+        # load it into a local variable
+        for ipos, i in enumerate(node.inputs):
+            if _logical_scalar(i):
+                print >> sio, "    const float ii_i%i_value = i%i_data[0];" % (ipos, ipos)
 
         #TODO: insert code to check for strides of 1, and use a different loop
         #loop over the elements to be treated by this kernel call
```
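To make the generated source easier to picture, here is a minimal sketch of the kernel skeleton that `c_src_kernel` emits, written as a standalone modern-Python generator; `sketch_src_kernel` and the hard-coded add body are illustrative assumptions, not commit code:

```python
from io import StringIO

def sketch_src_kernel(nodename, nin):
    # Simplified stand-in for GpuElemwise.c_src_kernel: one flat index per
    # element, with a grid-stride loop (idx/numThreads, as in the real
    # generator) so any numEls is covered by any launch configuration.
    sio = StringIO()
    args = ", ".join("const float *i%i_data" % i for i in range(nin))
    print("static __global__ void kernel_%s(unsigned int numEls, %s, float *o0_data)"
          % (nodename, args), file=sio)
    print("{", file=sio)
    print("    const unsigned int idx = blockIdx.x * blockDim.x + threadIdx.x;", file=sio)
    print("    const unsigned int numThreads = blockDim.x * gridDim.x;", file=sio)
    print("    for (unsigned int i = idx; i < numEls; i += numThreads) {", file=sio)
    # the real generator splices self.scalar_op.c_code(...) in here
    print("        o0_data[i] = i0_data[i] + i1_data[i];", file=sio)
    print("    }", file=sio)
    print("}", file=sio)
    return sio.getvalue()

print(sketch_src_kernel("add", 2))
```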
```diff
@@ -151,7 +164,8 @@ class GpuElemwise(Op):
         # calculate the data pointers for all arguments
         print >> sio, "        unsigned int ii = i;"
         for ipos, i in enumerate(node.inputs):
-            print >> sio, "        const float * ii_i%i_data = i%i_data;" % (ipos, ipos)
+            if not _logical_scalar(i):
+                print >> sio, "        const float * ii_i%i_data = i%i_data;" % (ipos, ipos)
         for ipos, i in enumerate(node.outputs):
             print >> sio, "        float * ii_o%i_data = o%i_data;" % (ipos, ipos)
         for d in xrange(nd-1, -1, -1):
```
```diff
@@ -161,16 +175,23 @@ class GpuElemwise(Op):
             else:
                 print >> sio, "        unsigned int pos%i = ii;" % d
             for ipos, i in enumerate(node.inputs):
-                print >> sio, "        ii_i%i_data += pos%i * i%i_str_%i;" % (ipos, d, ipos, d)
+                if not _logical_scalar(i):
+                    print >> sio, "        ii_i%i_data += pos%i * i%i_str_%i;" % (ipos, d, ipos, d)
             for ipos, i in enumerate(node.outputs):
                 print >> sio, "        ii_o%i_data += pos%i * o%i_str_%i;" % (ipos, d, ipos, d)
 
             # perform the scalar operation on the input and output references
             if d == 0:
-                print >> sio, "       ", self.scalar_op.c_code(None, None
-                        , ['ii_i%i_data[0]' % ipos for ipos, i in enumerate(node.inputs)]
-                        , ['ii_o%i_data[0]' % ipos for ipos, i in enumerate(node.outputs)]
-                        , sub=dict(fail='return;')) #TODO: set a failure code somehow!!!
+                #TODO: What if the scalar_op needs support_code??
+                task_code = self.scalar_op.c_code(
+                        Apply(self.scalar_op,
+                            [scalar.Scalar(dtype=input.type.dtype)() for input in node.inputs],
+                            [scalar.Scalar(dtype=output.type.dtype)() for output in node.outputs])
+                        , nodename + '_scalar_'
+                        , [('ii_i%i_value' if _logical_scalar(i) else 'ii_i%i_data[0]') % ipos
+                            for ipos, i in enumerate(node.inputs)]
+                        , ['ii_o%i_data[0]' % ipos for ipos, i in enumerate(node.outputs)]
+                        , sub=dict(fail='return;')) #TODO: set a failure code somehow!!!
+                print >> sio, "       ", task_code
         print >> sio, "    }"
 
         #TODO: insert runtime stride checks that select the best loop order either here, or in
```
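The `_logical_scalar` predicate drives both changes in this hunk: an input whose dimensions are all broadcastable holds a single element, so it is read once into `ii_i%i_value` up front instead of carrying a data pointer that gets re-offset for every dimension. A small self-contained illustration (the `Fake*` classes are stand-ins, not commit code):

```python
def _logical_scalar(x):
    # True when every dimension is broadcastable, i.e. the input holds
    # exactly one element no matter what shape it is broadcast to.
    return all(x.type.broadcastable)

class FakeType(object):
    def __init__(self, broadcastable):
        self.broadcastable = broadcastable

class FakeVar(object):
    def __init__(self, broadcastable):
        self.type = FakeType(broadcastable)

assert _logical_scalar(FakeVar((True, True)))       # 1x1: load once as a value
assert not _logical_scalar(FakeVar((False, True)))  # Nx1: needs strided addressing
```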
```diff
@@ -180,13 +201,14 @@ class GpuElemwise(Op):
         #for ipos, i in enumerate(node.inputs):
             #print >> sio, indent, "const float * i%i" % ipos, '= i%i_data', ''
         print >> sio, "}"
 
-        print sio.getvalue()
+        if 0:
+            print sio.getvalue()
         return sio.getvalue()
 
     def c_support_code_apply(self, node, nodename):
         return self.c_src_kernel(node, nodename) + \
                 self.c_src_callkernel(node, nodename)
 
     def c_src_callkernel(self, node, nodename):
         nd = node.outputs[0].type.ndim
         d = dict()
```
```diff
@@ -289,7 +311,7 @@ class GpuElemwise(Op):
         initial_dims = ','.join('1' for i in xrange(nd))
 
         if 1 or self.scalar_op == scalar.pow:
             print >> sio, """
-        std::cerr << "C_CODE %(opname)s START\\n";
+        //std::cerr << "C_CODE %(opname)s START\\n";
        //standard elemwise size checks
             """ % locals()
         print >> sio, """
```
```diff
@@ -297,7 +319,7 @@ class GpuElemwise(Op):
             """ % locals()
 
         for iname in inputs:
             print >> sio, """
-        std::cerr << "C_CODE %(opname)s checking input %(iname)s\\n";
+        //std::cerr << "C_CODE %(opname)s checking input %(iname)s\\n";
        if (%(nd)s != cnda_%(iname)s->nd)
        {
            PyErr_Format(PyExc_TypeError, "need %(nd)s dims, not %%i", cnda_%(iname)s->nd);
```
```diff
@@ -308,7 +330,7 @@ class GpuElemwise(Op):
            dims[i] = (dims[i] == 1) ? cnda_%(iname)s->dim[i] : dims[i];
            if ((cnda_%(iname)s->dim[i] != 1) && (dims[i] != cnda_%(iname)s->dim[i]))
            {
-                std::cerr << "C_CODE %(opname)s checking input %(iname)s failed\\n";
+                //std::cerr << "C_CODE %(opname)s checking input %(iname)s failed\\n";
                PyErr_Format(PyExc_TypeError, "GpuElemwise input has incompatible dim[%%i] == %%i, where output has size %%i",
                        i,
                        cnda_%(iname)s->dim[i],
```
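The generated C above implements the usual broadcasting rule for the output dims: a dimension of size 1 defers to the other operands, and any other mismatch raises the "incompatible dim" error. A runnable Python model of that rule (`broadcast_dims` is an illustrative helper, not in the commit):

```python
def broadcast_dims(shapes):
    # Model of the generated size check: dims[i] starts at 1; each input
    # either matches dims[i], is 1 (broadcasted), or is an error.
    # Assumes all inputs have the same rank, as GpuElemwise does.
    nd = len(shapes[0])
    dims = [1] * nd
    for shape in shapes:
        for i, d in enumerate(shape):
            if dims[i] == 1:
                dims[i] = d
            elif d != 1 and d != dims[i]:
                raise TypeError("GpuElemwise input has incompatible dim[%i] == %i,"
                                " where output has size %i" % (i, d, dims[i]))
    return dims

print(broadcast_dims([(5, 1), (1, 3)]))  # -> [5, 3]
```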
```diff
@@ -342,14 +364,14 @@ class GpuElemwise(Op):
                    %(fail)s;
                }
            }
-        std::cerr << "ELEMWISE NEW %(oname)s nd" << cnda_%(oname)s->nd << "\\n";
+        //std::cerr << "ELEMWISE NEW %(oname)s nd" << cnda_%(oname)s->nd << "\\n";
        //std::cerr << "ELEMWISE NEW %(oname)s data" << cnda_%(oname)s->devdata << "\\n";
        """ % locals()
         print >> sio, """
        {
            //new block so that failure gotos don't skip over variable initialization
            int log2_dims[%(nd)s];
-            std::cerr << "calling callkernel\\n";
+            //std::cerr << "calling callkernel\\n";
            callkernel_%(nodename)s(1, 0, dims, log2_dims
        """ % locals()
         for iname in inputs:
```
```diff
@@ -363,7 +385,7 @@ class GpuElemwise(Op):
         print >> sio, """
                        );
 
-            std::cerr << "calling callkernel returned\\n";
+            //std::cerr << "calling callkernel returned\\n";
            cudaThreadSynchronize();
            cudaError_t err = cudaGetLastError();
            if( cudaSuccess != err)
```
```diff
@@ -386,7 +408,6 @@ class GpuElemwise(Op):
     def c_code_cache_version(self):
         return ()
 
 
 class GpuDimShuffle(Op):
     def __init__(self, input_broadcastable, new_order):
         input_broadcastable = tuple(input_broadcastable)
```
```diff
@@ -528,7 +549,7 @@ class GpuDimShuffle(Op):
         for i, o in enumerate(self.new_order):
             print >> sio, """
-        std::cerr << "GpuDimShuffle " << cnda_%(res)s << " str[%(i)s] = " << cnda_%(res)s->str[%(i)s] << "\\n";
+        //std::cerr << "GpuDimShuffle " << cnda_%(res)s << " str[%(i)s] = " << cnda_%(res)s->str[%(i)s] << "\\n";
        """ % locals()
 
         # copy the host dims and stride -> device
```
blas.py (new file, mode 100644)

```python
class GpuDot22(Op):
    pass

class GpuGemm(Op):
    pass
```
nvcc_compiler.py

```diff
@@ -52,6 +52,7 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[
             (module_name, get_lib_extension()))
     debug('Generating shared lib', lib_filename)
+    # TODO: Why do these args cause failure on gtx285 that has 1.3 compute capability? '--gpu-architecture=compute_13', '--gpu-code=compute_13',
     cmd = ['nvcc', '-shared', '-g'] + [pa for pa in preargs if pa.startswith('-O')]
     cmd.extend(['-Xcompiler', ','.join(pa for pa in preargs if not pa.startswith('-O'))])
     cmd.extend('-I%s' % idir for idir in include_dirs)
```
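The command assembly here splits compiler flags into two groups: `-O` optimization flags go to nvcc directly, while everything else is forwarded to the host compiler through a single comma-joined `-Xcompiler` argument. A runnable sketch of the same partitioning (`build_nvcc_cmd` and the sample flags are illustrative assumptions):

```python
def build_nvcc_cmd(preargs, include_dirs):
    # nvcc understands -O flags itself; other host-compiler flags must be
    # passed through as one comma-joined -Xcompiler argument.
    cmd = ['nvcc', '-shared', '-g'] + [pa for pa in preargs if pa.startswith('-O')]
    cmd.extend(['-Xcompiler', ','.join(pa for pa in preargs if not pa.startswith('-O'))])
    cmd.extend('-I%s' % idir for idir in include_dirs)
    return cmd

print(build_nvcc_cmd(['-O3', '-fPIC', '-Wall'], ['/usr/include/python2.5']))
# ['nvcc', '-shared', '-g', '-O3', '-Xcompiler', '-fPIC,-Wall', '-I/usr/include/python2.5']
```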
tests/walltime.py (new file, mode 100644)

```python
import sys, time

from theano.compile.sandbox.sharedvalue import shared
from theano.compile.sandbox.pfunc import pfunc
from theano import tensor

import numpy
import theano_cuda_ndarray as tcn
from theano_cuda_ndarray.basic_ops import host_from_gpu, gpu_from_host

def compare_fns(fns, input, reps=10):
    times = {}
    for implname, impl in fns.iteritems():
        try:
            print 'TOPOSORT', implname
            for i, n in enumerate(impl.maker.env.toposort()):
                print i, n
        except:
            pass
        t0 = time.time()
        for i in xrange(reps):
            impl(input)
        dt = time.time() - t0
        times[implname] = dt
    return times

def showtimes(times):
    for impl, dt in times.iteritems():
        print impl, dt

def cmp_sigmoids(shape):
    def numpy_sigmoid(input):
        rval = 1.0 / (1.0 + numpy.exp(-input))
    sinput = tensor.Tensor(dtype='float32', broadcastable=(0,) * len(shape))()
    shared_input = tcn.shared_constructor(numpy.random.rand(*shape), 'shared_input')
    times = compare_fns(
            dict(numpy=numpy_sigmoid,
                theano_cpu=pfunc([sinput], 1.0 / (1.0 + tensor.exp(-sinput))),
                theano_gpu_onboard=pfunc([sinput], [],
                    updates=[(shared_input, 1.0 / (1.0 + tensor.exp(-shared_input)))])),
            input=shared_input.value)
    showtimes(times)

def cmp_sigmoids_T(shape):
    def numpy_sigmoid(input):
        rval = 1.0 / (1.0 + numpy.exp(-input.T))
    sinput = tensor.Tensor(dtype='float32', broadcastable=(0,) * len(shape))()
    shared_input = tcn.shared_constructor(numpy.random.rand(*shape), 'shared_input')
    times = compare_fns(
            dict(numpy=numpy_sigmoid,
                theano_cpu=pfunc([sinput], 1.0 / (1.0 + tensor.exp(-sinput.T))),
                theano_gpu_onboard=pfunc([sinput], [],
                    updates=[(shared_input, 1.0 / (1.0 + tensor.exp(-shared_input.T)))])),
            input=shared_input.value)
    showtimes(times)

if __name__ == '__main__':
    eval(sys.argv[1])
```
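Since the script dispatches on its first command-line argument via `eval`, the benchmark is chosen from the shell; an invocation would look something like this (the shape is illustrative):

```
python tests/walltime.py "cmp_sigmoids((500, 500))"
```

Note that `numpy_sigmoid` assigns `rval` without returning it; that does not affect the timings, since `compare_fns` only measures the calls and ignores return values.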
type.py
```diff
@@ -137,14 +137,14 @@ class CudaNdarrayType(Type):
        if (CudaNdarray_Check(py_%(name)s))
        {
            cnda_%(name)s = (CudaNdarray*)py_%(name)s;
-            std::cerr << "c_extract " << cnda_%(name)s << '\\n';
+            //std::cerr << "c_extract " << cnda_%(name)s << '\\n';
            if (cnda_%(name)s->nd != %(nd)s)
            {
                PyErr_Format(PyExc_RuntimeError, "Some CudaNdarray has rank %%i, it was supposed to have rank %(nd)s", cnda_%(name)s->nd);
                cnda_%(name)s = NULL;
                %(fail)s;
            }
-            std::cerr << "c_extract " << cnda_%(name)s << " nd check passed\\n";
+            //std::cerr << "c_extract " << cnda_%(name)s << " nd check passed\\n";
        """ % locals()
         for i, b in enumerate(self.broadcastable):
             if b:
```
```diff
@@ -155,17 +155,17 @@ class CudaNdarrayType(Type):
                    cnda_%(name)s = NULL;
                    %(fail)s;
                }
-                std::cerr << "c_extract " << cnda_%(name)s << "dim check %(i)s passed\\n";
-                std::cerr << "c_extract " << cnda_%(name)s << "checking bcast %(i)s <" << cnda_%(name)s->str<< ">\\n";
-                std::cerr << "c_extract " << cnda_%(name)s->str[%(i)s] << "\\n";
+                //std::cerr << "c_extract " << cnda_%(name)s << "dim check %(i)s passed\\n";
+                //std::cerr << "c_extract " << cnda_%(name)s << "checking bcast %(i)s <" << cnda_%(name)s->str<< ">\\n";
+                //std::cerr << "c_extract " << cnda_%(name)s->str[%(i)s] << "\\n";
                if (cnda_%(name)s->str[%(i)s])
                {
-                    std::cerr << "c_extract bad stride detected...\\n";
+                    //std::cerr << "c_extract bad stride detected...\\n";
                    PyErr_Format(PyExc_RuntimeError, "Some CudaNdarray has a nonzero stride %%i on a broadcastable dimension %%i", cnda_%(name)s->str[%(i)s], %(i)s);
                    cnda_%(name)s = NULL;
                    %(fail)s;
                }
-                std::cerr << "c_extract " << cnda_%(name)s << "bcast check %(i)s passed\\n";
+                //std::cerr << "c_extract " << cnda_%(name)s << "bcast check %(i)s passed\\n";
        """ % locals()
         print >> sio, """
        assert(cnda_%(name)s);
```
```diff
@@ -177,19 +177,19 @@ class CudaNdarrayType(Type):
            cnda_%(name)s = NULL;
            %(fail)s;
        }
-        std::cerr << "c_extract done " << cnda_%(name)s << '\\n';
+        //std::cerr << "c_extract done " << cnda_%(name)s << '\\n';
        """ % locals()
         #print sio.getvalue()
         return sio.getvalue()
 
     def c_cleanup(self, name, sub):
         return """
-        std::cerr << "cleanup " << py_%(name)s << " " << cnda_%(name)s << "\\n";
+        //std::cerr << "cleanup " << py_%(name)s << " " << cnda_%(name)s << "\\n";
        if (cnda_%(name)s)
        {
            Py_XDECREF(cnda_%(name)s);
        }
-        std::cerr << "cleanup done" << py_%(name)s << "\\n";
+        //std::cerr << "cleanup done" << py_%(name)s << "\\n";
        """ % locals()
 
     def c_sync(self, name, sub):
```
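The stride check whose debug output is silenced above enforces an invariant worth spelling out: a broadcastable dimension holds exactly one element, so its stride must be zero, otherwise the kernel would step past that element when the dimension is virtually enlarged by broadcasting. A Python model of the check (`check_broadcastable_strides` is an illustrative helper, not commit code):

```python
def check_broadcastable_strides(broadcastable, strides):
    # Mirrors the c_extract check: any nonzero stride on a broadcastable
    # dimension is rejected before the array reaches a kernel.
    for i, (b, s) in enumerate(zip(broadcastable, strides)):
        if b and s:
            raise RuntimeError(
                "Some CudaNdarray has a nonzero stride %i on a broadcastable dimension %i"
                % (s, i))

check_broadcastable_strides((True, False), (0, 1))  # passes: stride 0 where broadcastable
```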