testgroup / pytensor / Commits

Commit 92ed9f15 authored Jul 15, 2010 by Frederic Bastien

    added an example file that shows how to make a Theano op that uses a PyCUDA-generated function.

Parent: d606816a
Showing 2 changed files, with 242 additions and 0 deletions:

    theano/misc/pycuda_example.py    +183  -0
    theano/misc/test_pycuda.py        +59  -0
theano/misc/pycuda_example.py  (new file, mode 100644)
"""
This file show how we can use Pycuda compiled fct in a Theano Op. Do no use them in production code. See the TODO.
You can use them as a guide to use your pycuda code into a Theano op.
The PycudaElemwiseSourceModule op use pycuda code generated with pycuda.compiler.SourceModule
The PycudaElemwiseKernel op use pycuda code generated with pycuda.elementwise.ElementwiseKernel
Their is a test in test_pycuda.py.
This don't work with broadcast and non-contiguous memory as pycuda don't support that, but we make sure we don't introduce problem.
"""
import
numpy
import
theano
import
theano.tensor
as
T
from
theano.gof
import
Op
,
Apply
,
local_optimizer
,
EquilibriumDB
from
theano.sandbox.cuda
import
GpuElemwise
,
CudaNdarrayType
,
CudaNdarray
from
theano.sandbox.cuda.basic_ops
import
as_cuda_ndarray_variable
,
gpu_contiguous
,
host_from_gpu
from
theano.sandbox.cuda.opt
import
gpu_seqopt
from
pycuda.elementwise
import
ElementwiseKernel
from
pycuda.compiler
import
SourceModule
from
pycuda.gpuarray
import
splay
import
pycuda.autoinit
class
PycudaElemwiseSourceModule
(
Op
):
nin
=
property
(
lambda
self
:
self
.
scalar_op
.
nin
)
nout
=
property
(
lambda
self
:
self
.
scalar_op
.
nout
)
def
__init__
(
self
,
scalar_op
,
inplace_pattern
=
{},
name
=
None
):
self
.
name
=
name
self
.
scalar_op
=
scalar_op
self
.
inplace_pattern
=
None
def
__str__
(
self
):
if
self
.
name
is
None
:
if
self
.
inplace_pattern
:
items
=
self
.
inplace_pattern
.
items
()
items
.
sort
()
return
"PycudaElemwiseSourceModule{
%
s}
%
s"
%
(
self
.
scalar_op
,
str
(
items
))
else
:
return
"PycudaElemwiseSourceModule{
%
s}"
%
(
self
.
scalar_op
)
else
:
return
self
.
name
def
make_node
(
self
,
*
inputs
):
_inputs
=
[
gpu_contiguous
(
as_cuda_ndarray_variable
(
i
))
for
i
in
inputs
]
if
self
.
nin
>
0
and
len
(
_inputs
)
!=
self
.
nin
:
raise
TypeError
(
'Wrong argument count'
,
(
self
.
nin
,
len
(
_inputs
)))
for
i
in
_inputs
[
1
:]:
if
i
.
type
.
ndim
!=
inputs
[
0
]
.
type
.
ndim
:
raise
TypeError
(
'different ranks among inputs'
)
assert
not
any
([
any
(
i
.
type
.
broadcastable
)
for
i
in
inputs
])
assert
len
(
inputs
)
==
2
#TODO remove
otype
=
CudaNdarrayType
(
broadcastable
=
[
False
]
*
_inputs
[
0
]
.
type
.
ndim
)
assert
self
.
nout
==
1
#TODO change the scalar op with the good c_code!
fct_name
=
"pycuda_elemwise_
%
s"
%
str
(
self
.
scalar_op
)
out_node
=
Apply
(
self
,
_inputs
,
[
otype
()
for
o
in
xrange
(
self
.
nout
)])
in_name
=
[
"i"
+
str
(
id
)
for
id
in
range
(
len
(
inputs
))]
out_name
=
[
"o"
+
str
(
id
)
for
id
in
range
(
self
.
nout
)]
c_code
=
self
.
scalar_op
.
c_code
(
out_node
,
"some_name"
,
tuple
([
n
+
"[i]"
for
n
in
in_name
]),
tuple
(
n
+
"[i]"
for
n
in
out_name
),
{})
c_code_param
=
", "
.
join
([
var
.
type
.
dtype_specs
()[
1
]
+
" *"
+
name
for
var
,
name
in
zip
(
inputs
,
in_name
)
+
zip
(
out_node
.
outputs
,
out_name
)])
mod
=
SourceModule
(
"""
#include<Python.h>
#include <numpy/arrayobject.h>
__global__ void
%
s(
%
s)
{
int i = threadIdx.x + threadIdx.y*blockDim.x;
%
s
}
"""
%
(
fct_name
,
c_code_param
,
c_code
))
self
.
pycuda_fct
=
mod
.
get_function
(
fct_name
)
return
out_node
def
perform
(
self
,
node
,
inputs
,
(
z
,)):
#TODO support broadcast!
#TODO assert all input have the same shape
if
z
[
0
]
is
None
or
z
[
0
]
.
shape
!=
inputs
[
0
]
.
shape
:
z
[
0
]
=
theano
.
sandbox
.
cuda
.
CudaNdarray
.
zeros
(
inputs
[
0
]
.
shape
)
self
.
pycuda_fct
(
inputs
[
0
],
inputs
[
1
],
z
[
0
],
block
=
(
inputs
[
0
]
.
shape
[
0
],
inputs
[
0
]
.
shape
[
1
],
1
))
class
PycudaElemwiseKernel
(
Op
):
nin
=
property
(
lambda
self
:
self
.
scalar_op
.
nin
)
nout
=
property
(
lambda
self
:
self
.
scalar_op
.
nout
)
def
__init__
(
self
,
scalar_op
,
inplace_pattern
=
{},
name
=
None
):
self
.
name
=
name
self
.
scalar_op
=
scalar_op
self
.
inplace_pattern
=
None
def
__str__
(
self
):
if
self
.
name
is
None
:
if
self
.
inplace_pattern
:
items
=
self
.
inplace_pattern
.
items
()
items
.
sort
()
return
"PycudaElemwiseKernel{
%
s}
%
s"
%
(
self
.
scalar_op
,
str
(
items
))
else
:
return
"PycudaElemwiseKernel{
%
s}"
%
(
self
.
scalar_op
)
else
:
return
self
.
name
def
make_node
(
self
,
*
inputs
):
_inputs
=
[
gpu_contiguous
(
as_cuda_ndarray_variable
(
i
))
for
i
in
inputs
]
if
self
.
nin
>
0
and
len
(
_inputs
)
!=
self
.
nin
:
raise
TypeError
(
'Wrong argument count'
,
(
self
.
nin
,
len
(
_inputs
)))
for
i
in
_inputs
[
1
:]:
if
i
.
type
.
ndim
!=
inputs
[
0
]
.
type
.
ndim
:
raise
TypeError
(
'different ranks among inputs'
)
assert
not
any
([
any
(
i
.
type
.
broadcastable
)
for
i
in
inputs
])
assert
len
(
inputs
)
==
2
#TODO remove
# output is broadcastable only along dimensions where all inputs are broadcastable
broadcastable
=
[]
for
d
in
xrange
(
_inputs
[
0
]
.
type
.
ndim
):
bcast_d
=
True
for
i
in
_inputs
:
if
not
i
.
type
.
broadcastable
[
d
]:
bcast_d
=
False
break
broadcastable
.
append
(
bcast_d
)
assert
len
(
broadcastable
)
==
_inputs
[
0
]
.
type
.
ndim
otype
=
CudaNdarrayType
(
broadcastable
=
broadcastable
)
assert
self
.
nout
==
1
out_node
=
Apply
(
self
,
_inputs
,
[
otype
()
for
o
in
xrange
(
self
.
nout
)])
in_name
=
[
"i"
+
str
(
id
)
for
id
in
range
(
len
(
inputs
))]
out_name
=
[
"o"
+
str
(
id
)
for
id
in
range
(
self
.
nout
)]
c_code
=
self
.
scalar_op
.
c_code
(
out_node
,
"some_name"
,
tuple
([
n
+
"[i]"
for
n
in
in_name
]),
tuple
(
n
+
"[i]"
for
n
in
out_name
),
{})
self
.
pycuda_fct
=
ElementwiseKernel
(
", "
.
join
([
var
.
type
.
dtype_specs
()[
1
]
+
" *"
+
name
for
var
,
name
in
zip
(
inputs
,
in_name
)
+
zip
(
out_node
.
outputs
,
out_name
)]),
c_code
,
"pycuda_elemwise_kernel_
%
s"
%
str
(
self
.
scalar_op
),
preamble
=
"""#include<Python.h>
#include <numpy/arrayobject.h>"""
)
return
out_node
def
perform
(
self
,
node
,
inputs
,
(
z
,)):
#TODO assert all input have the same shape
if
z
[
0
]
is
None
or
z
[
0
]
.
shape
!=
inputs
[
0
]
.
shape
:
z
[
0
]
=
theano
.
sandbox
.
cuda
.
CudaNdarray
.
zeros
(
inputs
[
0
]
.
shape
)
i
=
inputs
+
z
sp
=
splay
(
i
[
0
]
.
mem_size
)
self
.
pycuda_fct
(
*
i
,
grid
=
sp
[
0
],
block
=
sp
[
1
])
pycuda_optimizer
=
EquilibriumDB
()
gpu_seqopt
.
register
(
"pycuda_optimizer"
,
pycuda_optimizer
,
1.5
,
"fast_run"
)
@local_optimizer
([])
def
local_pycuda_gpu_elemwise
(
node
):
"""
GpuElemwise -> PycudaElemwiseSourceModule
"""
if
isinstance
(
node
.
op
,
GpuElemwise
):
if
not
any
([
any
(
i
.
type
.
broadcastable
)
for
i
in
node
.
inputs
])
and
all
([
i
.
ndim
<=
2
for
i
in
node
.
inputs
]):
new_op
=
PycudaElemwiseSourceModule
(
node
.
op
.
scalar_op
,
node
.
op
.
inplace_pattern
)(
*
node
.
inputs
)
return
[
new_op
]
pycuda_optimizer
.
register
(
"local_pycuda_gpu_elemwise"
,
local_pycuda_gpu_elemwise
)
@local_optimizer
([])
def
local_pycuda_gpu_elemwise_kernel
(
node
):
"""
GpuElemwise -> PycudaElemwiseKernel
"""
if
isinstance
(
node
.
op
,
GpuElemwise
):
if
not
any
([
any
(
i
.
type
.
broadcastable
)
for
i
in
node
.
inputs
]):
new_op
=
PycudaElemwiseKernel
(
node
.
op
.
scalar_op
,
node
.
op
.
inplace_pattern
)(
*
node
.
inputs
)
return
[
new_op
]
pycuda_optimizer
.
register
(
"local_pycuda_gpu_elemwise_kernel"
,
local_pycuda_gpu_elemwise_kernel
,
1.5
)
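The module docstring points at test_pycuda.py for usage through the optimizer, but the ops can also be applied by hand. Below is a minimal sketch, not part of the commit, assuming a working CUDA device and Theano configured for the GPU (e.g. THEANO_FLAGS=device=gpu,floatX=float32); theano.scalar.mul stands in for any elementwise scalar op.

# Minimal usage sketch -- not part of this commit. Assumes a CUDA device
# and Theano configured with e.g. THEANO_FLAGS=device=gpu,floatX=float32.
import theano
import theano.scalar
import theano.tensor as T
from theano.misc.pycuda_example import PycudaElemwiseSourceModule
from theano.sandbox.cuda.basic_ops import host_from_gpu

x = T.fmatrix('x')
y = T.fmatrix('y')
# Wrap an existing Theano scalar op (elementwise multiply) in the pycuda op.
mul_op = PycudaElemwiseSourceModule(theano.scalar.mul)
# make_node moves the inputs to the GPU; host_from_gpu copies the result back.
f = theano.function([x, y], host_from_gpu(mul_op(x, y)))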
theano/misc/test_pycuda.py  (new file, mode 100644)
import numpy

import theano
import theano.tensor as T
from theano.misc.pycuda_example import PycudaElemwiseSourceModule, \
    PycudaElemwiseKernel
from theano.sandbox.cuda import GpuContiguous
import theano.misc.pycuda_example


def test_pycuda_elemwise_source_module():
    x = T.fmatrix('x')
    y = T.fmatrix('y')
    f = theano.function([x, y], x * y)
    print f.maker.env.toposort()
    f2 = theano.function([x, y], x * y,
                         mode=theano.compile.mode.get_default_mode().including(
                             "local_pycuda_gpu_elemwise"))
    print f2.maker.env.toposort()
    assert any([isinstance(node.op, theano.sandbox.cuda.GpuElemwise)
                for node in f.maker.env.toposort()])
    assert any([isinstance(node.op, PycudaElemwiseSourceModule)
                for node in f2.maker.env.toposort()])

    val1 = numpy.random.rand(5, 5)
    val2 = numpy.random.rand(5, 5)
    #val1 = numpy.ones((5,5))
    #val2 = numpy.arange(25).reshape(5,5)
    assert (f(val1, val2) == f2(val1, val2)).all()
    print f(val1, val2)
    print f2(val1, val2)


def test_pycuda_elemwise_kernel():
    x = T.fmatrix('x')
    y = T.fmatrix('y')
    f = theano.function([x, y], x + y)
    print f.maker.env.toposort()
    f2 = theano.function([x, y], x + y,
                         mode=theano.compile.mode.get_default_mode().including(
                             "local_pycuda_gpu_elemwise_kernel"))
    print f2.maker.env.toposort()
    assert any([isinstance(node.op, theano.sandbox.cuda.GpuElemwise)
                for node in f.maker.env.toposort()])
    assert any([isinstance(node.op, PycudaElemwiseKernel)
                for node in f2.maker.env.toposort()])

    val1 = numpy.random.rand(5, 5)
    val2 = numpy.random.rand(5, 5)
    #val1 = numpy.ones((5,5))
    #val2 = numpy.arange(25).reshape(5,5)
    assert (f(val1, val2) == f2(val1, val2)).all()
    print f(val1, val2)
    print f2(val1, val2)

    # The kernel version also handles rank-3 inputs.
    x3 = T.ftensor3('x')
    y3 = T.ftensor3('y')
    z3 = T.ftensor3('z')
    f4 = theano.function([x3, y3, z3], x3 * y3 + z3,
                         mode=theano.compile.mode.get_default_mode().including(
                             "local_pycuda_gpu_elemwise_kernel"))
    print f4.maker.env.toposort()
    assert any([isinstance(node.op, PycudaElemwiseKernel)
                for node in f4.maker.env.toposort()])
    val1 = numpy.random.rand(2, 2, 2)
    print val1
    print f4(val1, val1, val1)
    assert numpy.allclose(f4(val1, val1, val1), val1 * val1 + val1)
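These tests print the optimized graphs and compare the pycuda ops against the plain GpuElemwise results. A hypothetical driver for running them outside a test runner, assuming the same GPU setup as above:

# Hypothetical driver -- not part of this commit. Run with e.g.
#   THEANO_FLAGS=device=gpu,floatX=float32 python this_file.py
import theano.misc.test_pycuda as t

t.test_pycuda_elemwise_source_module()
t.test_pycuda_elemwise_kernel()
print "pycuda op tests passed"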