Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6a2ee199
提交
6a2ee199
authored
7月 23, 2011
作者:
Frederic Bastien
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added function to create pycuda.gpuarray.GPUArray from CudaNdarray and the other way.
上级
53105ba7
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
142 行增加
和
0 行删除
+142
-0
pycuda_utils.py
theano/misc/pycuda_utils.py
+65
-0
test_pycuda_utils.py
theano/misc/tests/test_pycuda_utils.py
+77
-0
没有找到文件。
theano/misc/pycuda_utils.py
0 → 100644
浏览文件 @
6a2ee199
import
numpy
import
pycuda.gpuarray
import
theano.sandbox.cuda
as
cuda
if
cuda
.
cuda_available
==
False
:
raise
ImportError
(
'Optional theano package cuda disabled'
)
def
to_gpuarray
(
x
,
copyif
=
False
):
""" take a CudaNdarray and return a pycuda.gpuarray.GPUArray
:type x: CudaNdarray
:param x: The array to transform to pycuda.gpuarray.GPUArray.
:type copyif: bool
:param copyif: If False, raise an error if x is not c contiguous.
If it is c contiguous, we return a GPUArray that share
the same memory region as x.
If True, copy x if it is no c contiguous, so the return won't
shape the same memory region. If c contiguous, the return
will share the same memory region.
We need to do this as GPUArray don't fully support strided memory.
:return type: pycuda.gpuarray.GPUArray
"""
if
not
isinstance
(
x
,
cuda
.
CudaNdarray
):
raise
ValueError
(
"We can transfer only CudaNdarray to pycuda.gpuarray.GPUArray"
)
else
:
# Check if it is c contiguous
size
=
1
c_contiguous
=
True
for
i
in
range
(
x
.
ndim
-
1
,
-
1
,
-
1
):
if
x
.
shape
[
i
]
==
1
:
continue
if
x
.
_strides
[
i
]
!=
size
:
c_contiguous
=
False
break
size
*=
x
.
shape
[
i
]
if
not
c_contiguous
:
if
copyif
:
x
=
x
.
copy
()
else
:
raise
ValueError
(
"We where asked to don't copy memory, but the memory is not c contiguous."
)
# Now x is always c contiguous
px
=
pycuda
.
gpuarray
.
GPUArray
(
x
.
shape
,
x
.
dtype
,
base
=
x
,
gpudata
=
x
.
gpudata
)
return
px
def
to_cudandarray
(
x
):
""" take a pycuda.gpuarray.GPUArray and make a CudaNdarray that point to its memory
:note: CudaNdarray support only float32, so only float32 GPUArray are accepted
"""
if
not
isinstance
(
x
,
pycuda
.
gpuarray
.
GPUArray
):
raise
ValueError
(
"We can transfer only pycuda.gpuarray.GPUArray to CudaNdarray"
)
elif
x
.
dtype
!=
"float32"
:
raise
ValueError
(
"CudaNdarray support only float32"
)
else
:
strides
=
[
1
]
for
i
in
x
.
shape
[::
-
1
][:
-
1
]:
strides
.
append
(
strides
[
-
1
]
*
i
)
strides
=
tuple
(
strides
[::
-
1
])
ptr
=
int
(
x
.
gpudata
)
# in pycuda trunk, y.ptr also works, which is a little cleaner
z
=
cuda
.
from_gpu_pointer
(
ptr
,
x
.
shape
,
strides
,
x
)
return
z
theano/misc/tests/test_pycuda_utils.py
0 → 100644
浏览文件 @
6a2ee199
import
numpy
import
theano.sandbox.cuda
as
cuda
import
theano.misc.pycuda_init
from
theano.misc.pycuda_utils
import
to_gpuarray
,
to_cudandarray
if
not
theano
.
misc
.
pycuda_init
.
pycuda_available
:
from
nose.plugins.skip
import
SkipTest
raise
SkipTest
(
"Pycuda not installed. Skip test of theano op with pycuda code."
)
if
cuda
.
cuda_available
==
False
:
from
nose.plugins.skip
import
SkipTest
raise
SkipTest
(
'Optional theano package cuda disabled'
)
import
pycuda.gpuarray
def
test_to_gpuarray
():
cx
=
cuda
.
CudaNdarray
.
zeros
((
5
,
4
))
px
=
to_gpuarray
(
cx
)
assert
isinstance
(
px
,
pycuda
.
gpuarray
.
GPUArray
)
cx
[
0
,
0
]
=
numpy
.
asarray
(
1
,
dtype
=
"float32"
)
# Check that they share the same memory space
assert
px
.
gpudata
==
cx
.
gpudata
assert
numpy
.
asarray
(
cx
[
0
,
0
])
==
1
assert
numpy
.
allclose
(
numpy
.
asarray
(
cx
),
px
.
get
())
assert
px
.
dtype
==
cx
.
dtype
assert
px
.
shape
==
cx
.
shape
assert
all
(
numpy
.
asarray
(
cx
.
_strides
)
*
4
==
px
.
strides
)
# Test when the CudaNdarray is strided
cx
=
cx
[::
2
,::]
px
=
to_gpuarray
(
cx
,
copyif
=
True
)
assert
isinstance
(
px
,
pycuda
.
gpuarray
.
GPUArray
)
cx
[
0
,
0
]
=
numpy
.
asarray
(
2
,
dtype
=
"float32"
)
# Check that they do not share the same memory space
assert
px
.
gpudata
!=
cx
.
gpudata
assert
numpy
.
asarray
(
cx
[
0
,
0
])
==
2
assert
not
numpy
.
allclose
(
numpy
.
asarray
(
cx
),
px
.
get
())
assert
px
.
dtype
==
cx
.
dtype
assert
px
.
shape
==
cx
.
shape
assert
not
all
(
numpy
.
asarray
(
cx
.
_strides
)
*
4
==
px
.
strides
)
# Test that we return an error
try
:
px
=
to_gpuarray
(
cx
)
assert
False
except
ValueError
:
pass
def
test_to_cudandarray
():
px
=
pycuda
.
gpuarray
.
zeros
((
3
,
4
,
5
),
'float32'
)
cx
=
to_cudandarray
(
px
)
assert
isinstance
(
cx
,
cuda
.
CudaNdarray
)
assert
numpy
.
allclose
(
px
.
get
(),
numpy
.
asarray
(
cx
))
assert
px
.
dtype
==
cx
.
dtype
assert
px
.
shape
==
cx
.
shape
assert
all
(
numpy
.
asarray
(
cx
.
_strides
)
*
4
==
px
.
strides
)
try
:
px
=
pycuda
.
gpuarray
.
zeros
((
3
,
4
,
5
),
'float64'
)
to_cudandarray
(
px
)
assert
False
except
ValueError
:
pass
try
:
to_cudandarray
(
numpy
.
zeros
(
4
))
assert
False
except
ValueError
:
pass
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论