Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1da867d8
提交
1da867d8
authored
11月 30, 2011
作者:
Olivier Delalleau
浏览文件
操作
浏览文件
下载
差异文件
Merge remote-tracking branch 'delallea/surban-master'
上级
24b5cff9
18b95657
隐藏空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
680 行增加
和
311 行删除
+680
-311
.gitignore
.gitignore
+1
-0
Theano.pyproj
Theano.pyproj
+232
-0
Theano.sln
Theano.sln
+18
-0
compiledir.py
theano/gof/compiledir.py
+13
-2
__init__.py
theano/sandbox/cuda/__init__.py
+6
-2
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+4
-1
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+288
-0
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+64
-287
elemwise.py
theano/sandbox/cuda/elemwise.py
+44
-16
nvcc_compiler.py
theano/sandbox/cuda/nvcc_compiler.py
+10
-3
没有找到文件。
.gitignore
浏览文件 @
1da867d8
...
...
@@ -33,3 +33,4 @@ theano/version.py
theano/version.py.out
distribute-*.egg
distribute-*.tar.gz
Theano.suo
Theano.pyproj
0 → 100644
浏览文件 @
1da867d8
<?xml version="1.0" encoding="utf-8"?>
<Project
DefaultTargets=
"Build"
xmlns=
"http://schemas.microsoft.com/developer/msbuild/2003"
>
<PropertyGroup>
<Configuration
Condition=
" '$(Configuration)' == '' "
>
Debug
</Configuration>
<SchemaVersion>
2.0
</SchemaVersion>
<ProjectGuid>
{b67d762d-0020-4e02-9ddf-7db4f89b1dd3}
</ProjectGuid>
<ProjectHome>
.
</ProjectHome>
<StartupFile>
</StartupFile>
<SearchPath>
</SearchPath>
<WorkingDirectory>
.
</WorkingDirectory>
<OutputPath>
.
</OutputPath>
<Name>
Theano
</Name>
<RootNamespace>
Theano
</RootNamespace>
<IsWindowsApplication>
False
</IsWindowsApplication>
<InterpreterId>
2af0f10d-7135-4994-9156-5d01c9c11b7e
</InterpreterId>
<InterpreterVersion>
2.7
</InterpreterVersion>
</PropertyGroup>
<PropertyGroup
Condition=
" '$(Configuration)' == 'Debug' "
>
<DebugSymbols>
true
</DebugSymbols>
<EnableUnmanagedDebugging>
false
</EnableUnmanagedDebugging>
</PropertyGroup>
<PropertyGroup
Condition=
" '$(Configuration)' == 'Release' "
>
<DebugSymbols>
true
</DebugSymbols>
<EnableUnmanagedDebugging>
false
</EnableUnmanagedDebugging>
</PropertyGroup>
<ItemGroup>
<Compile
Include=
"theano\compile\builders.py"
/>
<Compile
Include=
"theano\compile\debugmode.py"
/>
<Compile
Include=
"theano\compile\function.py"
/>
<Compile
Include=
"theano\compile\function_module.py"
/>
<Compile
Include=
"theano\compile\io.py"
/>
<Compile
Include=
"theano\compile\mode.py"
/>
<Compile
Include=
"theano\compile\module.py"
/>
<Compile
Include=
"theano\compile\pfunc.py"
/>
<Compile
Include=
"theano\compile\profilemode.py"
/>
<Compile
Include=
"theano\compile\profiling.py"
/>
<Compile
Include=
"theano\compile\sandbox\__init__.py"
/>
<Compile
Include=
"theano\compile\sharedvalue.py"
/>
<Compile
Include=
"theano\compile\tests\test_builders.py"
/>
<Compile
Include=
"theano\compile\tests\test_debugmode.py"
/>
<Compile
Include=
"theano\compile\tests\test_function_module.py"
/>
<Compile
Include=
"theano\compile\tests\test_inplace_opt_for_value.py"
/>
<Compile
Include=
"theano\compile\tests\test_misc.py"
/>
<Compile
Include=
"theano\compile\tests\test_modes.py"
/>
<Compile
Include=
"theano\compile\tests\test_module.py"
/>
<Compile
Include=
"theano\compile\tests\test_pfunc.py"
/>
<Compile
Include=
"theano\compile\tests\test_shared.py"
/>
<Compile
Include=
"theano\compile\tests\__init__.py"
/>
<Compile
Include=
"theano\compile\__init__.py"
/>
<Compile
Include=
"theano\configdefaults.py"
/>
<Compile
Include=
"theano\configparser.py"
/>
<Compile
Include=
"theano\gof\callcache.py"
/>
<Compile
Include=
"theano\gof\cc.py"
/>
<Compile
Include=
"theano\gof\cmodule.py"
/>
<Compile
Include=
"theano\gof\compiledir.py"
/>
<Compile
Include=
"theano\gof\compilelock.py"
/>
<Compile
Include=
"theano\gof\cutils.py"
/>
<Compile
Include=
"theano\gof\destroyhandler.py"
/>
<Compile
Include=
"theano\gof\env.py"
/>
<Compile
Include=
"theano\gof\graph.py"
/>
<Compile
Include=
"theano\gof\lazylinker_c.py"
/>
<Compile
Include=
"theano\gof\link.py"
/>
<Compile
Include=
"theano\gof\op.py"
/>
<Compile
Include=
"theano\gof\opt.py"
/>
<Compile
Include=
"theano\gof\optdb.py"
/>
<Compile
Include=
"theano\gof\python25.py"
/>
<Compile
Include=
"theano\gof\sandbox\equilibrium.py"
/>
<Compile
Include=
"theano\gof\tests\test_cc.py"
/>
<Compile
Include=
"theano\gof\tests\test_compute_test_value.py"
/>
<Compile
Include=
"theano\gof\tests\test_destroyhandler.py"
/>
<Compile
Include=
"theano\gof\tests\test_graph.py"
/>
<Compile
Include=
"theano\gof\tests\test_lazy.py"
/>
<Compile
Include=
"theano\gof\tests\test_link.py"
/>
<Compile
Include=
"theano\gof\tests\test_op.py"
/>
<Compile
Include=
"theano\gof\tests\test_opt.py"
/>
<Compile
Include=
"theano\gof\tests\test_optdb.py"
/>
<Compile
Include=
"theano\gof\tests\test_toolbox.py"
/>
<Compile
Include=
"theano\gof\tests\test_types.py"
/>
<Compile
Include=
"theano\gof\tests\test_vm.py"
/>
<Compile
Include=
"theano\gof\tests\__init__.py"
/>
<Compile
Include=
"theano\gof\toolbox.py"
/>
<Compile
Include=
"theano\gof\type.py"
/>
<Compile
Include=
"theano\gof\unify.py"
/>
<Compile
Include=
"theano\gof\utils.py"
/>
<Compile
Include=
"theano\gof\vm.py"
/>
<Compile
Include=
"theano\gof\__init__.py"
/>
<Compile
Include=
"theano\gradient.py"
/>
<Compile
Include=
"theano\ifelse.py"
/>
<Compile
Include=
"theano\misc\buildbot_filter.py"
/>
<Compile
Include=
"theano\misc\check_blas.py"
/>
<Compile
Include=
"theano\misc\check_duplicate_key.py"
/>
<Compile
Include=
"theano\misc\cudamat_utils.py"
/>
<Compile
Include=
"theano\misc\gnumpy_utils.py"
/>
<Compile
Include=
"theano\misc\hooks\argparse.py"
/>
<Compile
Include=
"theano\misc\hooks\check_whitespace.py"
/>
<Compile
Include=
"theano\misc\hooks\reindent.py"
/>
<Compile
Include=
"theano\misc\latence_gpu_transfert.py"
/>
<Compile
Include=
"theano\misc\may_share_memory.py"
/>
<Compile
Include=
"theano\misc\pycuda_example.py"
/>
<Compile
Include=
"theano\misc\pycuda_init.py"
/>
<Compile
Include=
"theano\misc\pycuda_utils.py"
/>
<Compile
Include=
"theano\misc\safe_asarray.py"
/>
<Compile
Include=
"theano\misc\strutil.py"
/>
<Compile
Include=
"theano\misc\tests\test_cudamat_utils.py"
/>
<Compile
Include=
"theano\misc\tests\test_gnumpy_utils.py"
/>
<Compile
Include=
"theano\misc\tests\test_may_share_memory.py"
/>
<Compile
Include=
"theano\misc\tests\test_pycuda_example.py"
/>
<Compile
Include=
"theano\misc\tests\test_pycuda_theano_simple.py"
/>
<Compile
Include=
"theano\misc\tests\test_pycuda_utils.py"
/>
<Compile
Include=
"theano\misc\__init__.py"
/>
<Compile
Include=
"theano\printing.py"
/>
<Compile
Include=
"theano\raise_op.py"
/>
<Compile
Include=
"theano\sandbox\conv.py"
/>
<Compile
Include=
"theano\sandbox\cuda\basic_ops.py"
/>
<Compile
Include=
"theano\sandbox\cuda\blas.py"
/>
<Compile
Include=
"theano\sandbox\cuda\elemwise.py"
/>
<Compile
Include=
"theano\sandbox\cuda\GpuConv3D.py"
/>
<Compile
Include=
"theano\sandbox\cuda\GpuConvGrad3D.py"
/>
<Compile
Include=
"theano\sandbox\cuda\GpuConvTransp3D.py"
/>
<Compile
Include=
"theano\sandbox\cuda\kernel_codegen.py"
/>
<Compile
Include=
"theano\sandbox\cuda\nnet.py"
/>
<Compile
Include=
"theano\sandbox\cuda\nvcc_compiler.py"
/>
<Compile
Include=
"theano\sandbox\cuda\opt.py"
/>
<Compile
Include=
"theano\sandbox\cuda\rng_curand.py"
/>
<Compile
Include=
"theano\sandbox\cuda\type.py"
/>
<Compile
Include=
"theano\sandbox\cuda\var.py"
/>
<Compile
Include=
"theano\sandbox\cuda\__init__.py"
/>
<Compile
Include=
"theano\sandbox\debug.py"
/>
<Compile
Include=
"theano\sandbox\downsample.py"
/>
<Compile
Include=
"theano\sandbox\fourier.py"
/>
<Compile
Include=
"theano\sandbox\linalg\ops.py"
/>
<Compile
Include=
"theano\sandbox\linalg\__init__.py"
/>
<Compile
Include=
"theano\sandbox\minimal.py"
/>
<Compile
Include=
"theano\sandbox\multinomial.py"
/>
<Compile
Include=
"theano\sandbox\neighbourhoods.py"
/>
<Compile
Include=
"theano\sandbox\neighbours.py"
/>
<Compile
Include=
"theano\sandbox\rng_mrg.py"
/>
<Compile
Include=
"theano\sandbox\softsign.py"
/>
<Compile
Include=
"theano\sandbox\solve.py"
/>
<Compile
Include=
"theano\sandbox\symbolic_module.py"
/>
<Compile
Include=
"theano\sandbox\test_multinomial.py"
/>
<Compile
Include=
"theano\sandbox\test_neighbourhoods.py"
/>
<Compile
Include=
"theano\sandbox\test_neighbours.py"
/>
<Compile
Include=
"theano\sandbox\test_rng_mrg.py"
/>
<Compile
Include=
"theano\sandbox\test_theano_object.py"
/>
<Compile
Include=
"theano\sandbox\theano_object.py"
/>
<Compile
Include=
"theano\sandbox\__init__.py"
/>
<Compile
Include=
"theano\scalar\basic.py"
/>
<Compile
Include=
"theano\scalar\basic_scipy.py"
/>
<Compile
Include=
"theano\scalar\sharedvar.py"
/>
<Compile
Include=
"theano\scalar\__init__.py"
/>
<Compile
Include=
"theano\scan_module\scan.py"
/>
<Compile
Include=
"theano\scan_module\scan_op.py"
/>
<Compile
Include=
"theano\scan_module\scan_opt.py"
/>
<Compile
Include=
"theano\scan_module\scan_perform_ext.py"
/>
<Compile
Include=
"theano\scan_module\scan_utils.py"
/>
<Compile
Include=
"theano\scan_module\scan_views.py"
/>
<Compile
Include=
"theano\scan_module\__init__.py"
/>
<Compile
Include=
"theano\sparse\basic.py"
/>
<Compile
Include=
"theano\sparse\sandbox\sp.py"
/>
<Compile
Include=
"theano\sparse\sandbox\test_sp.py"
/>
<Compile
Include=
"theano\sparse\sandbox\truedot.py"
/>
<Compile
Include=
"theano\sparse\sandbox\__init__.py"
/>
<Compile
Include=
"theano\sparse\sharedvar.py"
/>
<Compile
Include=
"theano\sparse\__init__.py"
/>
<Compile
Include=
"theano\tensor\basic.py"
/>
<Compile
Include=
"theano\tensor\blas.py"
/>
<Compile
Include=
"theano\tensor\blas_headers.py"
/>
<Compile
Include=
"theano\tensor\blas_scipy.py"
/>
<Compile
Include=
"theano\tensor\deprecated\rmodule.py"
/>
<Compile
Include=
"theano\tensor\deprecated\test_rmodule.py"
/>
<Compile
Include=
"theano\tensor\deprecated\__init__.py"
/>
<Compile
Include=
"theano\tensor\elemwise.py"
/>
<Compile
Include=
"theano\tensor\elemwise_cgen.py"
/>
<Compile
Include=
"theano\tensor\inplace.py"
/>
<Compile
Include=
"theano\tensor\nnet\conv.py"
/>
<Compile
Include=
"theano\tensor\nnet\Conv3D.py"
/>
<Compile
Include=
"theano\tensor\nnet\ConvGrad3D.py"
/>
<Compile
Include=
"theano\tensor\nnet\ConvTransp3D.py"
/>
<Compile
Include=
"theano\tensor\nnet\nnet.py"
/>
<Compile
Include=
"theano\tensor\nnet\sigm.py"
/>
<Compile
Include=
"theano\tensor\nnet\__init__.py"
/>
<Compile
Include=
"theano\tensor\opt.py"
/>
<Compile
Include=
"theano\tensor\opt_uncanonicalize.py"
/>
<Compile
Include=
"theano\tensor\randomstreams.py"
/>
<Compile
Include=
"theano\tensor\raw_random.py"
/>
<Compile
Include=
"theano\tensor\sharedvar.py"
/>
<Compile
Include=
"theano\tensor\shared_randomstreams.py"
/>
<Compile
Include=
"theano\tensor\signal\conv.py"
/>
<Compile
Include=
"theano\tensor\signal\downsample.py"
/>
<Compile
Include=
"theano\tensor\signal\__init__.py"
/>
<Compile
Include=
"theano\tensor\tensor_grad.py"
/>
<Compile
Include=
"theano\tensor\xlogx.py"
/>
<Compile
Include=
"theano\tensor\__init__.py"
/>
<Compile
Include=
"theano\updates.py"
/>
<Compile
Include=
"theano\__init__.py"
/>
</ItemGroup>
<ItemGroup>
<Folder
Include=
"theano\"
/>
<Folder
Include=
"theano\compile\"
/>
<Folder
Include=
"theano\compile\sandbox\"
/>
<Folder
Include=
"theano\compile\tests\"
/>
<Folder
Include=
"theano\gof\"
/>
<Folder
Include=
"theano\gof\sandbox\"
/>
<Folder
Include=
"theano\gof\tests\"
/>
<Folder
Include=
"theano\misc\"
/>
<Folder
Include=
"theano\misc\hooks\"
/>
<Folder
Include=
"theano\misc\tests\"
/>
<Folder
Include=
"theano\sandbox\"
/>
<Folder
Include=
"theano\sandbox\cuda\"
/>
<Folder
Include=
"theano\sandbox\linalg\"
/>
<Folder
Include=
"theano\scalar\"
/>
<Folder
Include=
"theano\scan_module\"
/>
<Folder
Include=
"theano\sparse\"
/>
<Folder
Include=
"theano\sparse\sandbox\"
/>
<Folder
Include=
"theano\tensor\"
/>
<Folder
Include=
"theano\tensor\deprecated\"
/>
<Folder
Include=
"theano\tensor\nnet\"
/>
<Folder
Include=
"theano\tensor\signal\"
/>
</ItemGroup>
<ItemGroup>
<Content
Include=
"theano\sandbox\cuda\conv.cu"
/>
<Content
Include=
"theano\sandbox\cuda\conv_full_kernel.cu"
/>
<Content
Include=
"theano\sandbox\cuda\conv_kernel.cu"
/>
<Content
Include=
"theano\sandbox\cuda\cuda_ndarray.cu"
/>
<Content
Include=
"theano\sandbox\cuda\cuda_ndarray.cuh"
/>
</ItemGroup>
<Import
Project=
"$(MSBuildToolsPath)\Microsoft.Common.targets"
/>
</Project>
\ No newline at end of file
Theano.sln
0 → 100644
浏览文件 @
1da867d8
Microsoft Visual Studio Solution File, Format Version 11.00
# Visual Studio 2010
Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "Theano", "Theano.pyproj", "{B67D762D-0020-4E02-9DDF-7DB4F89B1DD3}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Release|Any CPU = Release|Any CPU
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
{B67D762D-0020-4E02-9DDF-7DB4F89B1DD3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{B67D762D-0020-4E02-9DDF-7DB4F89B1DD3}.Release|Any CPU.ActiveCfg = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
EndGlobalSection
EndGlobal
theano/gof/compiledir.py
浏览文件 @
1da867d8
...
...
@@ -4,6 +4,7 @@ import errno
import
os
import
platform
import
re
import
sys
import
theano
from
theano.configparser
import
config
,
AddConfigVar
,
ConfigParam
,
StrParam
...
...
@@ -14,7 +15,7 @@ def default_compiledirname():
platform
.
platform
(),
platform
.
processor
(),
platform
.
python_version
()])
platform_id
=
re
.
sub
(
"[
\
(
\
)
\
s]+"
,
"_"
,
platform_id
)
platform_id
=
re
.
sub
(
"[
\
(
\
)
\
s
,
]+"
,
"_"
,
platform_id
)
return
'compiledir_'
+
platform_id
...
...
@@ -50,9 +51,19 @@ def filter_compiledir(path):
return
path
# TODO Using the local user profile on Windows is currently disabled as it
# is not documented yet, and may break some existing code. It will be enabled
# in a future code update.
if
False
and
sys
.
platform
==
'win32'
:
# On Windows we should not write temporary files to a directory
# that is part of the roaming part of the user profile. Instead
# we use the local part of the user profile.
basecompiledir
=
os
.
path
.
join
(
os
.
environ
[
'LOCALAPPDATA'
],
'theano'
)
else
:
basecompiledir
=
os
.
path
.
join
(
config
.
home
,
'.theano'
)
AddConfigVar
(
'base_compiledir'
,
"arch-independent cache directory for compiled modules"
,
StrParam
(
os
.
path
.
join
(
config
.
home
,
'.theano'
)
,
allow_override
=
False
))
StrParam
(
basecompiledir
,
allow_override
=
False
))
AddConfigVar
(
'compiledir'
,
"arch-dependent cache directory for compiled modules"
,
...
...
theano/sandbox/cuda/__init__.py
浏览文件 @
1da867d8
import
atexit
,
logging
,
os
,
stat
,
sys
import
atexit
,
logging
,
os
,
s
hutil
,
s
tat
,
sys
from
theano.compile
import
optdb
from
theano.gof.cmodule
import
get_lib_extension
from
theano.configparser
import
config
,
AddConfigVar
,
StrParam
...
...
@@ -122,7 +122,11 @@ if cuda_available:
try
:
open
(
libcuda_ndarray_so
)
.
close
()
except
IOError
:
os
.
symlink
(
cuda_ndarray_so
,
libcuda_ndarray_so
)
if
sys
.
platform
==
"win32"
:
# The Python `os` module does not support symlinks on win32.
shutil
.
copyfile
(
cuda_ndarray_so
,
libcuda_ndarray_so
)
else
:
os
.
symlink
(
cuda_ndarray_so
,
libcuda_ndarray_so
)
try
:
gpu_init
()
...
...
theano/sandbox/cuda/basic_ops.py
浏览文件 @
1da867d8
...
...
@@ -471,7 +471,10 @@ class GpuSum(Op):
)
{
"""
%
locals
()
print
>>
sio
,
"int new_dims[
%(nd_out)
s]; "
%
locals
()
if
nd_out
>
0
:
print
>>
sio
,
"int new_dims[
%(nd_out)
s]; "
%
locals
()
else
:
print
>>
sio
,
"int *new_dims=NULL; "
j
=
0
for
i
in
xrange
(
nd_in
):
...
...
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
1da867d8
#define _CUDA_NDARRAY_C
#include <Python.h>
#include <structmember.h>
...
...
@@ -3420,6 +3422,292 @@ CudaNdarray_Dimshuffle(PyObject* _unused, PyObject* args)
return
NULL
;
}
int
cnda_structure_size
(
int
nd
)
{
// dim0, dim1, ...
// str0, str1, ...
// log2(dim0), log2(dim1), ...
return
nd
+
nd
+
nd
;
}
const
int
*
CudaNdarray_HOST_DIMS
(
const
CudaNdarray
*
self
)
{
return
self
->
host_structure
;
}
const
int
*
CudaNdarray_HOST_STRIDES
(
const
CudaNdarray
*
self
)
{
return
self
->
host_structure
+
self
->
nd
;
}
const
int
*
CudaNdarray_HOST_LOG2DIMS
(
const
CudaNdarray
*
self
)
{
return
self
->
host_structure
+
2
*
self
->
nd
;
}
void
cnda_mark_dev_structure_dirty
(
CudaNdarray
*
self
)
{
self
->
dev_structure_fresh
=
0
;
}
int
CudaNdarray_EqualAndIgnore
(
CudaNdarray
*
cnda1
,
CudaNdarray
*
cnda2
,
int
ignoreSync
,
int
ignoreBase
)
{
int
verbose
=
1
;
if
(
!
ignoreSync
&&
cnda1
->
dev_structure_fresh
!=
cnda2
->
dev_structure_fresh
)
{
if
(
verbose
)
fprintf
(
stdout
,
"CUDANDARRAY_EQUAL FAILED : 1
\n
"
);
return
0
;
}
if
(
cnda1
->
nd
!=
cnda2
->
nd
)
{
if
(
verbose
)
fprintf
(
stdout
,
"CUDANDARRAY_EQUAL FAILED : 2
\n
"
);
return
0
;
}
for
(
int
i
=
0
;
i
<
2
*
cnda1
->
nd
;
i
++
)
{
if
(
cnda1
->
host_structure
[
i
]
!=
cnda2
->
host_structure
[
i
])
{
if
(
verbose
)
fprintf
(
stdout
,
"CUDANDARRAY_EQUAL : host_structure : %d, %d, %d
\n
"
,
i
,
cnda1
->
host_structure
[
i
],
cnda2
->
host_structure
[
i
]);
return
0
;
}
}
if
(
!
ignoreBase
&&
cnda1
->
base
!=
cnda2
->
base
)
{
if
(
verbose
)
fprintf
(
stdout
,
"CUDANDARRAY_EQUAL FAILED : 4"
);
return
0
;
}
else
if
(
cnda1
->
data_allocated
!=
cnda2
->
data_allocated
)
{
if
(
verbose
)
fprintf
(
stdout
,
"CUDANDARRAY_EQUAL FAILED : 5"
);
return
0
;
}
else
if
(
cnda1
->
data_allocated
&&
cnda1
->
devdata
!=
cnda2
->
devdata
)
{
if
(
verbose
)
fprintf
(
stdout
,
"CUDANDARRAY_EQUAL FAILED : 6"
);
// no need to check devdata if data is not allocated
return
0
;
}
return
1
;
}
int
CudaNdarray_Equal
(
CudaNdarray
*
cnda1
,
CudaNdarray
*
cnda2
)
{
return
CudaNdarray_EqualAndIgnore
(
cnda1
,
cnda2
,
0
,
0
);
}
void
CudaNdarray_set_dim
(
CudaNdarray
*
self
,
int
idx
,
int
d
)
{
if
((
idx
>=
self
->
nd
)
||
(
idx
<
0
)
||
(
d
<
0
))
{
fprintf
(
stderr
,
"WARNING: probably bad CudaNdarray_set_dim arguments: %i %i
\n
"
,
idx
,
d
);
}
if
(
d
!=
self
->
host_structure
[
idx
])
{
self
->
host_structure
[
idx
]
=
d
;
int
log2d
=
(
int
)
log2
((
double
)
d
);
self
->
host_structure
[
idx
+
2
*
self
->
nd
]
=
(
d
==
(
1
<<
log2d
))
?
log2d
:
-
1
;
cnda_mark_dev_structure_dirty
(
self
);
}
}
void
CudaNdarray_set_stride
(
CudaNdarray
*
self
,
int
idx
,
int
s
)
{
if
((
idx
>=
self
->
nd
)
||
(
idx
<
0
))
{
fprintf
(
stderr
,
"WARNING: probably bad CudaNdarray_set_stride arguments: %i %i
\n
"
,
idx
,
s
);
}
if
(
s
!=
CudaNdarray_HOST_STRIDES
(
self
)[
idx
])
{
self
->
host_structure
[
idx
+
self
->
nd
]
=
s
;
cnda_mark_dev_structure_dirty
(
self
);
}
}
int
cnda_copy_structure_to_device
(
CudaNdarray
*
self
)
{
cublasSetVector
(
cnda_structure_size
(
self
->
nd
),
sizeof
(
int
),
self
->
host_structure
,
1
,
self
->
dev_structure
,
1
);
CNDA_THREAD_SYNC
;
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasGetError
())
{
PyErr_SetString
(
PyExc_RuntimeError
,
"error copying structure to device memory"
);
return
-
1
;
}
self
->
dev_structure_fresh
=
1
;
return
0
;
}
const
int
*
CudaNdarray_DEV_DIMS
(
CudaNdarray
*
self
)
{
if
(
!
self
->
dev_structure_fresh
)
{
if
(
cnda_copy_structure_to_device
(
self
))
return
NULL
;
}
return
self
->
dev_structure
;
}
const
int
*
CudaNdarray_DEV_STRIDES
(
CudaNdarray
*
self
)
{
if
(
!
self
->
dev_structure_fresh
)
{
if
(
cnda_copy_structure_to_device
(
self
))
return
NULL
;
}
return
self
->
dev_structure
+
self
->
nd
;
}
const
int
*
CudaNdarray_DEV_LOG2DIMS
(
CudaNdarray
*
self
)
{
if
(
!
self
->
dev_structure_fresh
)
{
if
(
cnda_copy_structure_to_device
(
self
))
return
NULL
;
}
return
self
->
dev_structure
+
2
*
self
->
nd
;
}
float
*
CudaNdarray_DEV_DATA
(
const
CudaNdarray
*
self
)
{
return
self
->
devdata
;
}
/**
* Return the number of elements in the ndarray (product of the dimensions)
*/
int
CudaNdarray_SIZE
(
const
CudaNdarray
*
self
)
{
if
(
self
->
nd
==
-
1
)
return
0
;
int
size
=
1
;
for
(
int
i
=
0
;
i
<
self
->
nd
;
++
i
)
{
size
*=
CudaNdarray_HOST_DIMS
(
self
)[
i
];
}
return
size
;
}
PyObject
*
CudaNdarray_SIZE_Object
(
const
CudaNdarray
*
self
,
void
*
closure
)
{
return
PyInt_FromLong
(
CudaNdarray_SIZE
(
self
));
}
int
CudaNdarray_set_nd
(
CudaNdarray
*
self
,
const
int
nd
)
{
if
(
nd
!=
self
->
nd
)
{
if
(
self
->
dev_structure
)
{
if
(
device_free
(
self
->
dev_structure
))
{
return
-
1
;
}
self
->
dev_structure
=
NULL
;
}
if
(
self
->
host_structure
)
{
free
(
self
->
host_structure
);
self
->
host_structure
=
NULL
;
self
->
nd
=
-
1
;
}
if
(
nd
==
-
1
)
return
0
;
self
->
host_structure
=
(
int
*
)
malloc
(
cnda_structure_size
(
nd
)
*
sizeof
(
int
));
if
(
NULL
==
self
->
host_structure
)
{
PyErr_SetString
(
PyExc_MemoryError
,
"Failed to allocate dim or str"
);
return
-
1
;
}
//initialize all dimensions and strides to 0
for
(
int
i
=
0
;
i
<
cnda_structure_size
(
nd
);
++
i
)
{
self
->
host_structure
[
i
]
=
0
;
}
int
struct_size
=
cnda_structure_size
(
nd
);
if
(
struct_size
)
{
self
->
dev_structure
=
(
int
*
)
device_malloc
(
struct_size
*
sizeof
(
int
));
if
(
NULL
==
self
->
dev_structure
)
{
free
(
self
->
host_structure
);
self
->
host_structure
=
NULL
;
self
->
dev_structure
=
NULL
;
return
-
1
;
}
}
self
->
nd
=
nd
;
self
->
dev_structure_fresh
=
0
;
}
return
0
;
}
int
CudaNdarray_set_device_data
(
CudaNdarray
*
self
,
float
*
data
,
CudaNdarray
*
base
)
{
return
CudaNdarray_set_device_data
(
self
,
data
,
(
PyObject
*
)
base
);
}
PyObject
*
CudaNdarray_IS_C_Contiguous
(
CudaNdarray
*
self
)
{
return
PyBool_FromLong
(
CudaNdarray_is_c_contiguous
(
self
));
}
void
fprint_CudaNdarray
(
FILE
*
fd
,
const
CudaNdarray
*
self
)
{
fprintf
(
fd
,
"CudaNdarray <%p, %p> nd=%i dev_structure_fresh=%d data_allocated=%d
\n
"
,
self
,
self
->
devdata
,
self
->
nd
,
self
->
dev_structure_fresh
,
self
->
data_allocated
);
fprintf
(
fd
,
"
\t
HOST_DIMS: "
);
for
(
int
i
=
0
;
i
<
self
->
nd
;
++
i
)
{
fprintf
(
fd
,
"%i
\t
"
,
CudaNdarray_HOST_DIMS
(
self
)[
i
]);
}
fprintf
(
fd
,
"
\n\t
HOST_STRIDES: "
);
for
(
int
i
=
0
;
i
<
self
->
nd
;
++
i
)
{
fprintf
(
fd
,
"%i
\t
"
,
CudaNdarray_HOST_STRIDES
(
self
)[
i
]);
}
int
data
=
0
;
fprintf
(
fd
,
"
\n\t
DEV_DIMS: "
);
for
(
int
i
=
0
;
i
<
self
->
nd
;
++
i
)
{
cublasGetVector
(
1
,
sizeof
(
int
),
self
->
dev_structure
+
i
,
1
,
&
data
,
1
);
fprintf
(
fd
,
"%i
\t
"
,
data
);
}
fprintf
(
fd
,
"
\n\t
DEV_STRIDES: "
);
for
(
int
i
=
0
;
i
<
self
->
nd
;
++
i
)
{
cublasGetVector
(
1
,
sizeof
(
int
),
self
->
dev_structure
+
self
->
nd
+
i
,
1
,
&
data
,
1
);
fprintf
(
fd
,
"%i
\t
"
,
data
);
}
fprintf
(
fd
,
"
\n
"
);
}
/*
Local Variables:
mode:c++
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
1da867d8
...
...
@@ -6,6 +6,16 @@
#include <cublas.h>
#ifdef _WIN32
#ifdef _CUDA_NDARRAY_C
#define DllExport __declspec( dllexport )
#else
#define DllExport __declspec( dllimport )
#endif
#else
#define DllExport
#endif
typedef
float
real
;
#define REAL_TYPENUM 11
...
...
@@ -36,8 +46,8 @@ typedef float real;
* device_malloc will set the Python error message before returning None.
* device_free will return nonzero on failure (after setting the python error message)
*/
void
*
device_malloc
(
size_t
size
);
int
device_free
(
void
*
ptr
);
DllExport
void
*
device_malloc
(
size_t
size
);
DllExport
int
device_free
(
void
*
ptr
);
template
<
typename
T
>
static
T
ceil_intdiv
(
T
a
,
T
b
)
...
...
@@ -81,114 +91,50 @@ struct CudaNdarray
* Return a CudaNdarray whose 'nd' dimensions are all 0.
* if nd==-1, it is not initialized.
*/
PyObject
*
DllExport
PyObject
*
CudaNdarray_New
(
int
nd
=-
1
);
/**
* Return 1 for a CudaNdarray otw 0
*/
int
DllExport
int
CudaNdarray_Check
(
const
PyObject
*
ob
);
/**
* Return 1 for a CudaNdarray otw 0
*/
int
DllExport
int
CudaNdarray_CheckExact
(
const
PyObject
*
ob
);
/**
* Return true for a C-contiguous CudaNdarray, else false
*/
bool
DllExport
bool
CudaNdarray_is_c_contiguous
(
const
CudaNdarray
*
self
);
/****
* Returns the number of elements necessary in host_structure and dev_structure for a given number of dimensions.
*/
int
cnda_structure_size
(
int
nd
)
{
// dim0, dim1, ...
// str0, str1, ...
// log2(dim0), log2(dim1), ...
return
nd
+
nd
+
nd
;
}
DllExport
int
cnda_structure_size
(
int
nd
);
const
int
*
CudaNdarray_HOST_DIMS
(
const
CudaNdarray
*
self
)
{
return
self
->
host_structure
;
}
const
int
*
CudaNdarray_HOST_STRIDES
(
const
CudaNdarray
*
self
)
{
return
self
->
host_structure
+
self
->
nd
;
}
const
int
*
CudaNdarray_HOST_LOG2DIMS
(
const
CudaNdarray
*
self
)
{
return
self
->
host_structure
+
2
*
self
->
nd
;
}
DllExport
const
int
*
CudaNdarray_HOST_DIMS
(
const
CudaNdarray
*
self
);
void
cnda_mark_dev_structure_dirty
(
CudaNdarray
*
self
)
{
self
->
dev_structure_fresh
=
0
;
}
DllExport
const
int
*
CudaNdarray_HOST_STRIDES
(
const
CudaNdarray
*
self
);
int
CudaNdarray_EqualAndIgnore
(
CudaNdarray
*
cnda1
,
CudaNdarray
*
cnda2
,
int
ignoreSync
,
int
ignoreBase
)
{
int
verbose
=
1
;
if
(
!
ignoreSync
&&
cnda1
->
dev_structure_fresh
!=
cnda2
->
dev_structure_fresh
)
{
if
(
verbose
)
fprintf
(
stdout
,
"CUDANDARRAY_EQUAL FAILED : 1
\n
"
);
return
0
;
}
if
(
cnda1
->
nd
!=
cnda2
->
nd
)
{
if
(
verbose
)
fprintf
(
stdout
,
"CUDANDARRAY_EQUAL FAILED : 2
\n
"
);
return
0
;
}
for
(
int
i
=
0
;
i
<
2
*
cnda1
->
nd
;
i
++
)
{
if
(
cnda1
->
host_structure
[
i
]
!=
cnda2
->
host_structure
[
i
])
{
if
(
verbose
)
fprintf
(
stdout
,
"CUDANDARRAY_EQUAL : host_structure : %d, %d, %d
\n
"
,
i
,
cnda1
->
host_structure
[
i
],
cnda2
->
host_structure
[
i
]);
return
0
;
}
}
DllExport
const
int
*
CudaNdarray_HOST_LOG2DIMS
(
const
CudaNdarray
*
self
);
if
(
!
ignoreBase
&&
cnda1
->
base
!=
cnda2
->
base
)
{
if
(
verbose
)
fprintf
(
stdout
,
"CUDANDARRAY_EQUAL FAILED : 4"
);
return
0
;
}
else
if
(
cnda1
->
data_allocated
!=
cnda2
->
data_allocated
)
{
if
(
verbose
)
fprintf
(
stdout
,
"CUDANDARRAY_EQUAL FAILED : 5"
);
return
0
;
}
else
if
(
cnda1
->
data_allocated
&&
cnda1
->
devdata
!=
cnda2
->
devdata
)
{
if
(
verbose
)
fprintf
(
stdout
,
"CUDANDARRAY_EQUAL FAILED : 6"
);
// no need to check devdata if data is not allocated
return
0
;
}
DllExport
void
cnda_mark_dev_structure_dirty
(
CudaNdarray
*
self
);
return
1
;
}
DllExport
int
CudaNdarray_EqualAndIgnore
(
CudaNdarray
*
cnda1
,
CudaNdarray
*
cnda2
,
int
ignoreSync
,
int
ignoreBase
);
// Default: do not ignore sync of dev and host structures in comparing, and do not ignore difference in base pointers
int
CudaNdarray_Equal
(
CudaNdarray
*
cnda1
,
CudaNdarray
*
cnda2
)
{
return
CudaNdarray_EqualAndIgnore
(
cnda1
,
cnda2
,
0
,
0
);
}
DllExport
int
CudaNdarray_Equal
(
CudaNdarray
*
cnda1
,
CudaNdarray
*
cnda2
);
/****
* Set the idx'th dimension to value d.
...
...
@@ -197,173 +143,44 @@ CudaNdarray_Equal(CudaNdarray *cnda1, CudaNdarray *cnda2)
*
* Does not sync structure to host.
*/
void
CudaNdarray_set_dim
(
CudaNdarray
*
self
,
int
idx
,
int
d
)
{
if
((
idx
>=
self
->
nd
)
||
(
idx
<
0
)
||
(
d
<
0
))
{
fprintf
(
stderr
,
"WARNING: probably bad CudaNdarray_set_dim arguments: %i %i
\n
"
,
idx
,
d
);
}
DllExport
void
CudaNdarray_set_dim
(
CudaNdarray
*
self
,
int
idx
,
int
d
);
if
(
d
!=
self
->
host_structure
[
idx
])
{
self
->
host_structure
[
idx
]
=
d
;
int
log2d
=
(
int
)
log2
((
double
)
d
);
self
->
host_structure
[
idx
+
2
*
self
->
nd
]
=
(
d
==
(
1
<<
log2d
))
?
log2d
:
-
1
;
cnda_mark_dev_structure_dirty
(
self
);
}
}
void
CudaNdarray_set_stride
(
CudaNdarray
*
self
,
int
idx
,
int
s
)
{
if
((
idx
>=
self
->
nd
)
||
(
idx
<
0
))
{
fprintf
(
stderr
,
"WARNING: probably bad CudaNdarray_set_stride arguments: %i %i
\n
"
,
idx
,
s
);
}
DllExport
void
CudaNdarray_set_stride
(
CudaNdarray
*
self
,
int
idx
,
int
s
);
if
(
s
!=
CudaNdarray_HOST_STRIDES
(
self
)[
idx
])
{
self
->
host_structure
[
idx
+
self
->
nd
]
=
s
;
cnda_mark_dev_structure_dirty
(
self
);
}
}
/***
* Update dependent variables from the contents of CudaNdarray_HOST_DIMS(self) and CudaNdarray_HOST_STRIDES(self)
*
* This means: recalculate the log2dims and transfer structure to the card
*/
int
cnda_copy_structure_to_device
(
CudaNdarray
*
self
)
{
cublasSetVector
(
cnda_structure_size
(
self
->
nd
),
sizeof
(
int
),
self
->
host_structure
,
1
,
self
->
dev_structure
,
1
);
CNDA_THREAD_SYNC
;
if
(
CUBLAS_STATUS_SUCCESS
!=
cublasGetError
())
{
PyErr_SetString
(
PyExc_RuntimeError
,
"error copying structure to device memory"
);
return
-
1
;
}
self
->
dev_structure_fresh
=
1
;
return
0
;
}
DllExport
int
cnda_copy_structure_to_device
(
CudaNdarray
*
self
);
const
int
*
CudaNdarray_DEV_DIMS
(
CudaNdarray
*
self
)
{
if
(
!
self
->
dev_structure_fresh
)
{
if
(
cnda_copy_structure_to_device
(
self
))
return
NULL
;
}
return
self
->
dev_structure
;
}
const
int
*
CudaNdarray_DEV_STRIDES
(
CudaNdarray
*
self
)
{
if
(
!
self
->
dev_structure_fresh
)
{
if
(
cnda_copy_structure_to_device
(
self
))
return
NULL
;
}
return
self
->
dev_structure
+
self
->
nd
;
}
const
int
*
CudaNdarray_DEV_LOG2DIMS
(
CudaNdarray
*
self
)
{
if
(
!
self
->
dev_structure_fresh
)
{
if
(
cnda_copy_structure_to_device
(
self
))
return
NULL
;
}
return
self
->
dev_structure
+
2
*
self
->
nd
;
}
float
*
CudaNdarray_DEV_DATA
(
const
CudaNdarray
*
self
)
{
return
self
->
devdata
;
}
DllExport
const
int
*
CudaNdarray_DEV_DIMS
(
CudaNdarray
*
self
);
DllExport
const
int
*
CudaNdarray_DEV_STRIDES
(
CudaNdarray
*
self
);
DllExport
const
int
*
CudaNdarray_DEV_LOG2DIMS
(
CudaNdarray
*
self
);
DllExport
float
*
CudaNdarray_DEV_DATA
(
const
CudaNdarray
*
self
);
/**
* Return the number of elements in the ndarray (product of the dimensions)
*/
int
CudaNdarray_SIZE
(
const
CudaNdarray
*
self
)
{
if
(
self
->
nd
==
-
1
)
return
0
;
int
size
=
1
;
for
(
int
i
=
0
;
i
<
self
->
nd
;
++
i
)
{
size
*=
CudaNdarray_HOST_DIMS
(
self
)[
i
];
}
return
size
;
}
static
PyObject
*
CudaNdarray_SIZE_Object
(
const
CudaNdarray
*
self
,
void
*
closure
)
{
return
PyInt_FromLong
(
CudaNdarray_SIZE
(
self
));
}
DllExport
int
CudaNdarray_SIZE
(
const
CudaNdarray
*
self
);
static
PyObject
*
CudaNdarray_SIZE_Object
(
const
CudaNdarray
*
self
,
void
*
closure
);
/**
* Allocate a new CudaNdarray with room for given number of dimensions
*
* No Storage space is allocated (and all dimensions are 0)
*/
PyObject
*
CudaNdarray_new_nd
(
const
int
nd
);
DllExport
PyObject
*
CudaNdarray_new_nd
(
const
int
nd
);
/**
* [Re]allocate a CudaNdarray with access to 'nd' dimensions.
*
* Note: This does not allocate storage for data.
*/
int
CudaNdarray_set_nd
(
CudaNdarray
*
self
,
const
int
nd
)
{
if
(
nd
!=
self
->
nd
)
{
if
(
self
->
dev_structure
)
{
if
(
device_free
(
self
->
dev_structure
))
{
return
-
1
;
}
self
->
dev_structure
=
NULL
;
}
if
(
self
->
host_structure
)
{
free
(
self
->
host_structure
);
self
->
host_structure
=
NULL
;
self
->
nd
=
-
1
;
}
if
(
nd
==
-
1
)
return
0
;
self
->
host_structure
=
(
int
*
)
malloc
(
cnda_structure_size
(
nd
)
*
sizeof
(
int
));
if
(
NULL
==
self
->
host_structure
)
{
PyErr_SetString
(
PyExc_MemoryError
,
"Failed to allocate dim or str"
);
return
-
1
;
}
//initialize all dimensions and strides to 0
for
(
int
i
=
0
;
i
<
cnda_structure_size
(
nd
);
++
i
)
{
self
->
host_structure
[
i
]
=
0
;
}
int
struct_size
=
cnda_structure_size
(
nd
);
if
(
struct_size
)
{
self
->
dev_structure
=
(
int
*
)
device_malloc
(
struct_size
*
sizeof
(
int
));
if
(
NULL
==
self
->
dev_structure
)
{
free
(
self
->
host_structure
);
self
->
host_structure
=
NULL
;
self
->
dev_structure
=
NULL
;
return
-
1
;
}
}
self
->
nd
=
nd
;
self
->
dev_structure_fresh
=
0
;
}
return
0
;
}
DllExport
int
CudaNdarray_set_nd
(
CudaNdarray
*
self
,
const
int
nd
);
/**
* CudaNdarray_alloc_contiguous
...
...
@@ -373,7 +190,7 @@ int CudaNdarray_set_nd(CudaNdarray * self, const int nd)
* Note: CudaNdarray_alloc_contiguous is templated to work for both int dimensions and npy_intp dimensions
*/
template
<
typename
inttype
>
int
CudaNdarray_alloc_contiguous
(
CudaNdarray
*
self
,
const
int
nd
,
const
inttype
*
dim
)
static
int
CudaNdarray_alloc_contiguous
(
CudaNdarray
*
self
,
const
int
nd
,
const
inttype
*
dim
)
{
// allocate an empty ndarray with c_contiguous access
// return 0 on success
...
...
@@ -434,9 +251,8 @@ int CudaNdarray_alloc_contiguous(CudaNdarray *self, const int nd, const inttype
/*
* Return a CudaNdarray whose 'nd' dimensions are set to dims, and allocated.
*/
template
<
typename
inttype
>
PyObject
*
CudaNdarray_NewDims
(
int
nd
,
const
inttype
*
dims
)
template
<
typename
inttype
>
static
PyObject
*
CudaNdarray_NewDims
(
int
nd
,
const
inttype
*
dims
)
{
CudaNdarray
*
rval
=
(
CudaNdarray
*
)
CudaNdarray_New
();
if
(
rval
)
...
...
@@ -456,103 +272,64 @@ CudaNdarray_NewDims(int nd, const inttype * dims)
*
* Set self to be a view of given `data`, owned by existing CudaNdarray `base`.
*/
int
CudaNdarray_set_device_data
(
CudaNdarray
*
self
,
float
*
data
,
PyObject
*
base
);
int
CudaNdarray_set_device_data
(
CudaNdarray
*
self
,
float
*
data
,
CudaNdarray
*
base
)
{
return
CudaNdarray_set_device_data
(
self
,
data
,
(
PyObject
*
)
base
);
}
DllExport
int
CudaNdarray_set_device_data
(
CudaNdarray
*
self
,
float
*
data
,
PyObject
*
base
);
DllExport
int
CudaNdarray_set_device_data
(
CudaNdarray
*
self
,
float
*
data
,
CudaNdarray
*
base
);
/**
* Return an independent copy of self
*/
PyObject
*
CudaNdarray_DeepCopy
(
CudaNdarray
*
self
,
PyObject
*
memo
);
DllExport
PyObject
*
CudaNdarray_DeepCopy
(
CudaNdarray
*
self
,
PyObject
*
memo
);
/**
* Return an independent copy of self
*/
PyObject
*
CudaNdarray_Copy
(
CudaNdarray
*
self
);
DllExport
PyObject
*
CudaNdarray_Copy
(
CudaNdarray
*
self
);
/**
* Return a new object obtained by summing over the dimensions for which there is a 1 in the mask.
*/
PyObject
*
CudaNdarray_ReduceSum
(
CudaNdarray
*
self
,
PyObject
*
py_reduce_mask
);
DllExport
PyObject
*
CudaNdarray_ReduceSum
(
CudaNdarray
*
self
,
PyObject
*
py_reduce_mask
);
/**
* Transfer the contents of numpy array `obj` to `self`.
*
* self is reallocated to have the correct dimensions if necessary.
*/
int
CudaNdarray_CopyFromArray
(
CudaNdarray
*
self
,
PyArrayObject
*
obj
);
DllExport
int
CudaNdarray_CopyFromArray
(
CudaNdarray
*
self
,
PyArrayObject
*
obj
);
/**
* Transfer the contents of CudaNdarray `other` to `self`.
*
* self is reallocated to have the correct dimensions if necessary.
*/
int
CudaNdarray_CopyFromCudaNdarray
(
CudaNdarray
*
self
,
CudaNdarray
*
other
,
bool
unbroadcast
=
false
);
DllExport
int
CudaNdarray_CopyFromCudaNdarray
(
CudaNdarray
*
self
,
CudaNdarray
*
other
,
bool
unbroadcast
=
false
);
/**
* Transfer the contents of CudaNdarray `self` to a new numpy ndarray.
*/
PyObject
*
DllExport
PyObject
*
CudaNdarray_CreateArrayObj
(
CudaNdarray
*
self
);
PyObject
*
DllExport
PyObject
*
CudaNdarray_ZEROS
(
int
n
,
int
*
dims
);
/**
* True iff the strides look like [dim[nd-2], dim[nd-3], ... , dim[0], 1]
*/
bool
CudaNdarray_is_c_contiguous
(
const
CudaNdarray
*
self
);
PyObject
*
CudaNdarray_IS_C_Contiguous
(
CudaNdarray
*
self
)
{
return
PyBool_FromLong
(
CudaNdarray_is_c_contiguous
(
self
));
}
int
CudaNdarray_gemm
(
float
alpha
,
const
CudaNdarray
*
A
,
const
CudaNdarray
*
B
,
float
beta
,
CudaNdarray
*
C
);
int
CudaNdarray_sger
(
float
alpha
,
CudaNdarray
*
x
,
CudaNdarray
*
y
,
CudaNdarray
*
A
);
DllExport
bool
CudaNdarray_is_c_contiguous
(
const
CudaNdarray
*
self
);
DllExport
PyObject
*
CudaNdarray_IS_C_Contiguous
(
CudaNdarray
*
self
);
int
CudaNdarray_reduce_sum
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
int
CudaNdarray_reduce_prod
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
int
CudaNdarray_reduce_min
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
int
CudaNdarray_reduce_max
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_gemm
(
float
alpha
,
const
CudaNdarray
*
A
,
const
CudaNdarray
*
B
,
float
beta
,
CudaNdarray
*
C
);
DllExport
int
CudaNdarray_sger
(
float
alpha
,
CudaNdarray
*
x
,
CudaNdarray
*
y
,
CudaNdarray
*
A
);
int
CudaNdarray_dimshuffle
(
CudaNdarray
*
self
,
unsigned
int
len
,
const
int
*
pattern
);
DllExport
int
CudaNdarray_reduce_sum
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_reduce_prod
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_reduce_min
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
DllExport
int
CudaNdarray_reduce_max
(
CudaNdarray
*
self
,
CudaNdarray
*
A
);
void
fprint_CudaNdarray
(
FILE
*
fd
,
const
CudaNdarray
*
self
)
{
fprintf
(
fd
,
"CudaNdarray <%p, %p> nd=%i dev_structure_fresh=%d data_allocated=%d
\n
"
,
self
,
self
->
devdata
,
self
->
nd
,
self
->
dev_structure_fresh
,
self
->
data_allocated
);
fprintf
(
fd
,
"
\t
HOST_DIMS: "
);
for
(
int
i
=
0
;
i
<
self
->
nd
;
++
i
)
{
fprintf
(
fd
,
"%i
\t
"
,
CudaNdarray_HOST_DIMS
(
self
)[
i
]);
}
fprintf
(
fd
,
"
\n\t
HOST_STRIDES: "
);
for
(
int
i
=
0
;
i
<
self
->
nd
;
++
i
)
{
fprintf
(
fd
,
"%i
\t
"
,
CudaNdarray_HOST_STRIDES
(
self
)[
i
]);
}
DllExport
int
CudaNdarray_dimshuffle
(
CudaNdarray
*
self
,
unsigned
int
len
,
const
int
*
pattern
);
int
data
=
0
;
fprintf
(
fd
,
"
\n\t
DEV_DIMS: "
);
for
(
int
i
=
0
;
i
<
self
->
nd
;
++
i
)
{
cublasGetVector
(
1
,
sizeof
(
int
),
self
->
dev_structure
+
i
,
1
,
&
data
,
1
);
fprintf
(
fd
,
"%i
\t
"
,
data
);
}
fprintf
(
fd
,
"
\n\t
DEV_STRIDES: "
);
for
(
int
i
=
0
;
i
<
self
->
nd
;
++
i
)
{
cublasGetVector
(
1
,
sizeof
(
int
),
self
->
dev_structure
+
self
->
nd
+
i
,
1
,
&
data
,
1
);
fprintf
(
fd
,
"%i
\t
"
,
data
);
}
fprintf
(
fd
,
"
\n
"
);
}
static
void
fprint_CudaNdarray
(
FILE
*
fd
,
const
CudaNdarray
*
self
);
#endif
/*
...
...
theano/sandbox/cuda/elemwise.py
浏览文件 @
1da867d8
...
...
@@ -534,33 +534,44 @@ class NaiveAlgo(object):
# collapse dimension that are broadcast in all inputs.
# need to be done before contiguous collapse as it will break it.
# do the dimensions and the strides
if
nd
>
0
:
print
>>
sio
,
"int local_dims[
%(nd)
s];"
%
locals
()
else
:
print
>>
sio
,
"int *local_dims=NULL;"
if
nb_inputs
>
0
and
nd
>
0
:
print
>>
sio
,
"""
int local_str[
%(nb_inputs)
s][
%(nd)
s];
int local_ostr[
%(nb_inputs)
s][
%(nd)
s];
"""
%
locals
()
else
:
print
>>
sio
,
"""
int local_str[1][1];
int local_ostr[1][1];
"""
print
>>
sio
,
"""
int local_dims[
%(nd)
s];
int local_str[
%(nb_inputs)
s][
%(nd)
s];
int local_ostr[
%(nb_inputs)
s][
%(nd)
s];
int nd_collapse =
%(nd)
s;
for(int i=0;i<
%(nd)
s;i++){//init new dim
local_dims[i]=dims[i];
}
"""
%
locals
()
"""
%
locals
()
for
ipos
in
xrange
(
len
(
node
.
inputs
)):
print
>>
sio
,
"""
for(int i=0;i<
%(nd)
s;i++){//init new strides
local_str[
%(ipos)
s][i]=i
%(ipos)
s_str[i];
}
"""
%
locals
()
"""
%
locals
()
for
ipos
in
xrange
(
len
(
node
.
outputs
)):
print
>>
sio
,
"""
for(int i=0;i<
%(nd)
s;i++){//init new strides
local_ostr[
%(ipos)
s][i]=o
%(ipos)
s_str[i];
}
"""
%
locals
()
"""
%
locals
()
if
self
.
verbose
>
2
:
print
>>
sio
,
'std::cerr <<"before broadcast collapse
\\
n";'
print
>>
sio
,
'std::cerr<< "nd_collapse "<< nd_collapse << "
\\
n"; '
print
>>
sio
,
'std::cerr << "local_dims";'
for
d
in
xrange
(
nd
):
print
>>
sio
,
'std::cerr << " " << local_dims[
%(d)
s]; '
%
locals
()
print
>>
sio
,
'std::cerr << " " << local_dims[
%(d)
s]; '
%
locals
()
print
>>
sio
,
'std::cerr << "
\\
n";'
for
ipos
in
xrange
(
len
(
node
.
inputs
)):
...
...
@@ -611,11 +622,18 @@ class NaiveAlgo(object):
# collapse contiguous dimensions (ignoring scalars, generic version(collapse any dimensions, right, left, middle))
# this is a good idea because we make less index calculation in the gpu.
print
>>
sio
,
"int nd_collapse_[
%(nd)
s] = {"
%
locals
()
+
','
.
join
([
'1'
for
x
in
xrange
(
nd
)])
+
"};"
if
nd
>
0
:
print
>>
sio
,
"int nd_collapse_[
%(nd)
s] = {"
%
locals
()
+
','
.
join
([
'1'
for
x
in
xrange
(
nd
)])
+
"};"
else
:
print
>>
sio
,
"int *nd_collapse_ = NULL;"
for
ipos
in
xrange
(
len
(
node
.
inputs
)):
if
not
_logical_scalar
(
node
.
inputs
[
ipos
]):
print
>>
sio
,
"""
int nd_collapse_
%(ipos)
s[
%(nd)
s] = {"""
%
locals
()
+
','
.
join
([
'1'
for
x
in
xrange
(
nd
)])
+
"};"
if
nd
>
0
:
print
>>
sio
,
"""
int nd_collapse_
%(ipos)
s[
%(nd)
s] = {"""
%
locals
()
+
','
.
join
([
'1'
for
x
in
xrange
(
nd
)])
+
"};"
else
:
print
>>
sio
,
"""
int *nd_collapse_
%(ipos)
s = NULL;"""
%
locals
()
print
>>
sio
,
"""
can_collapse_
%(nodename)
s(nd_collapse, local_dims, local_str[
%(ipos)
s], nd_collapse_
%(ipos)
s);
for(int i=0;i<nd_collapse;i++){
...
...
@@ -839,9 +857,14 @@ nd_collapse_[i]=0;
//std::cerr << "C_CODE
%(opname)
s START
\\
n";
//standard elemwise size checks
"""
%
locals
()
print
>>
sio
,
"""
int dims[
%(nd)
s] = {
%(initial_dims)
s};
"""
%
locals
()
if
nd
>
0
:
print
>>
sio
,
"""
int dims[
%(nd)
s] = {
%(initial_dims)
s};
"""
%
locals
()
else
:
print
>>
sio
,
"""
int *dims = NULL;
"""
#check that all inputs have valid dimensions
emitted_inames
=
{}
...
...
@@ -851,9 +874,14 @@ nd_collapse_[i]=0;
continue
broadcasts
=
', '
.
join
(
map
(
str
,
map
(
int
,
node
.
inputs
[
id
]
.
broadcastable
)))
nd
=
node
.
inputs
[
id
]
.
ndim
print
>>
sio
,
"""
int broadcasts_
%(iname)
s[
%(nd)
s] = {
%(broadcasts)
s};
"""
%
locals
()
if
nd
>
0
:
print
>>
sio
,
"""
int broadcasts_
%(iname)
s[
%(nd)
s] = {
%(broadcasts)
s};
"""
%
locals
()
else
:
print
>>
sio
,
"""
int *broadcasts_
%(iname)
s = NULL;
"""
%
locals
()
emitted_inames
[
iname
]
=
node
.
inputs
[
id
]
#check that all inputs have valid dimensions
emitted_inames
=
{}
...
...
theano/sandbox/cuda/nvcc_compiler.py
浏览文件 @
1da867d8
...
...
@@ -164,7 +164,12 @@ def nvcc_module_compile_str(
if
config
.
nvcc
.
compiler_bindir
:
cmd
.
extend
([
'--compiler-bindir'
,
config
.
nvcc
.
compiler_bindir
])
if
sys
.
platform
!=
'win32'
:
if
sys
.
platform
==
'win32'
:
# add flags for Microsoft compiler to create .pdb files
preargs2
.
append
(
'/Zi'
)
cmd
.
extend
([
'-Xlinker'
,
'/DEBUG'
])
if
sys
.
platform
!=
'win32'
:
if
local_bitwidth
()
==
64
:
cmd
.
append
(
'-m64'
)
preargs2
.
append
(
'-m64'
)
...
...
@@ -180,8 +185,10 @@ def nvcc_module_compile_str(
if
sys
.
platform
!=
'darwin'
:
# the 64bit CUDA libs are in the same files as are named by the function above
rpaths
.
append
(
os
.
path
.
join
(
config
.
cuda
.
root
,
'lib64'
))
for
rpath
in
rpaths
:
cmd
.
extend
([
'-Xlinker'
,
','
.
join
([
'-rpath'
,
rpath
])])
if
sys
.
platform
!=
'win32'
:
# the -rpath option is not understood by the Microsoft linker
for
rpath
in
rpaths
:
cmd
.
extend
([
'-Xlinker'
,
','
.
join
([
'-rpath'
,
rpath
])])
cmd
.
extend
([
flag
for
flag
in
config
.
nvcc
.
flags
.
split
(
' '
)
if
flag
])
cmd
.
extend
(
'-I
%
s'
%
idir
for
idir
in
include_dirs
)
cmd
.
extend
([
'-o'
,
lib_filename
])
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论