Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4937b42b
提交
4937b42b
authored
4月 25, 2014
作者:
abergeron
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1817 from nouiz/crash
Remove overflow and detect it if it happen again.
上级
c19a06d7
d1bfcfeb
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
24 行增加
和
10 行删除
+24
-10
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+3
-1
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+21
-9
没有找到文件。
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
4937b42b
...
@@ -915,12 +915,14 @@ __global__ void k_take_3(const int d0, const int d1, const int d2,
...
@@ -915,12 +915,14 @@ __global__ void k_take_3(const int d0, const int d1, const int d2,
npy_int64
idx
=
indices
[
i0
];
npy_int64
idx
=
indices
[
i0
];
if
(
idx
<
0
)
if
(
idx
<
0
)
idx
+=
dB0
;
// To allow negative indexing.
idx
+=
dB0
;
// To allow negative indexing.
if
((
idx
<
0
)
||
(
idx
>=
dB0
))
if
((
idx
<
0
)
||
(
idx
>=
dB0
))
{
// Any value other the 0 probably work. But to be more safe, I want
// Any value other the 0 probably work. But to be more safe, I want
// to change all bits to prevent problem with concurrent write that
// to change all bits to prevent problem with concurrent write that
// could cross cache line. But this should not happen with the
// could cross cache line. But this should not happen with the
// current code and driver.
// current code and driver.
*
err
=
0xFFFF
;
*
err
=
0xFFFF
;
continue
;
}
for
(
int
i1
=
threadIdx
.
x
;
i1
<
d1
;
i1
+=
blockDim
.
x
){
for
(
int
i1
=
threadIdx
.
x
;
i1
<
d1
;
i1
+=
blockDim
.
x
){
for
(
int
i2
=
threadIdx
.
y
;
i2
<
d2
;
i2
+=
blockDim
.
y
){
for
(
int
i2
=
threadIdx
.
y
;
i2
<
d2
;
i2
+=
blockDim
.
y
){
int
a_idx
=
i0
*
sA0
+
i1
*
sA1
+
i2
*
sA2
;
int
a_idx
=
i0
*
sA0
+
i1
*
sA1
+
i2
*
sA2
;
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
4937b42b
...
@@ -34,6 +34,11 @@
...
@@ -34,6 +34,11 @@
#include <numpy/arrayobject.h>
#include <numpy/arrayobject.h>
#include <stdio.h>
#include <stdio.h>
#include <stdint.h>
#ifndef SIZE_MAX
#define SIZE_MAX ((size_t)-1)
#endif
#include <cublas.h>
#include <cublas.h>
...
@@ -342,7 +347,7 @@ static int CudaNdarray_alloc_contiguous(CudaNdarray *self, const int nd,
...
@@ -342,7 +347,7 @@ static int CudaNdarray_alloc_contiguous(CudaNdarray *self, const int nd,
{
{
// allocate an empty ndarray with c_contiguous access
// allocate an empty ndarray with c_contiguous access
// return 0 on success
// return 0 on success
in
t
size
=
1
;
//set up the strides for contiguous tensor
size_
t
size
=
1
;
//set up the strides for contiguous tensor
assert
(
nd
>=
0
);
assert
(
nd
>=
0
);
// Here we modify the host structure to have the desired shape and
// Here we modify the host structure to have the desired shape and
...
@@ -357,6 +362,13 @@ static int CudaNdarray_alloc_contiguous(CudaNdarray *self, const int nd,
...
@@ -357,6 +362,13 @@ static int CudaNdarray_alloc_contiguous(CudaNdarray *self, const int nd,
{
{
CudaNdarray_set_stride
(
self
,
i
,
(
dim
[
i
]
==
1
)
?
0
:
size
);
CudaNdarray_set_stride
(
self
,
i
,
(
dim
[
i
]
==
1
)
?
0
:
size
);
CudaNdarray_set_dim
(
self
,
i
,
dim
[
i
]);
CudaNdarray_set_dim
(
self
,
i
,
dim
[
i
]);
//Detect overflow on unsigned integer
if
(
size
>
(
SIZE_MAX
/
dim
[
i
]))
{
PyErr_Format
(
PyExc_AssertionError
,
"Can't store in size_t the bytes resquested"
,
size
);
return
-
1
;
}
size
=
size
*
dim
[
i
];
size
=
size
*
dim
[
i
];
}
}
}
}
...
@@ -366,6 +378,14 @@ static int CudaNdarray_alloc_contiguous(CudaNdarray *self, const int nd,
...
@@ -366,6 +378,14 @@ static int CudaNdarray_alloc_contiguous(CudaNdarray *self, const int nd,
{
{
CudaNdarray_set_stride
(
self
,
i
,
(
dim
[
i
]
==
1
)
?
0
:
size
);
CudaNdarray_set_stride
(
self
,
i
,
(
dim
[
i
]
==
1
)
?
0
:
size
);
CudaNdarray_set_dim
(
self
,
i
,
dim
[
i
]);
CudaNdarray_set_dim
(
self
,
i
,
dim
[
i
]);
//Detect overflow on unsigned integer
if
(
size
>
(
SIZE_MAX
/
dim
[
i
]))
{
PyErr_Format
(
PyExc_AssertionError
,
"Can't store in size_t the bytes resquested"
,
size
);
return
-
1
;
}
size
=
size
*
dim
[
i
];
size
=
size
*
dim
[
i
];
}
}
}
}
...
@@ -393,14 +413,6 @@ static int CudaNdarray_alloc_contiguous(CudaNdarray *self, const int nd,
...
@@ -393,14 +413,6 @@ static int CudaNdarray_alloc_contiguous(CudaNdarray *self, const int nd,
return
-
1
;
return
-
1
;
}
}
if
(
size
<
0
)
{
PyErr_Format
(
PyExc_AssertionError
,
"size (%i) < 0"
,
size
);
return
-
1
;
}
self
->
devdata
=
(
float
*
)
device_malloc
(
size
*
sizeof
(
real
));
self
->
devdata
=
(
float
*
)
device_malloc
(
size
*
sizeof
(
real
));
if
(
size
&&
!
self
->
devdata
)
if
(
size
&&
!
self
->
devdata
)
{
{
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论