Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9b441ceb
提交
9b441ceb
authored
4月 12, 2010
作者:
fsavard
浏览文件
操作
浏览文件
下载
差异文件
Merge
上级
c4018c57
333b0887
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
113 行增加
和
45 行删除
+113
-45
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+4
-5
cuda_ndarray.cu
theano/sandbox/cuda/cuda_ndarray.cu
+54
-38
cuda_ndarray.cuh
theano/sandbox/cuda/cuda_ndarray.cuh
+55
-0
test_cuda_ndarray.py
theano/sandbox/cuda/tests/test_cuda_ndarray.py
+0
-2
没有找到文件。
theano/sandbox/cuda/basic_ops.py
浏览文件 @
9b441ceb
...
@@ -1465,13 +1465,12 @@ class GpuJoin(tensor.Join):
...
@@ -1465,13 +1465,12 @@ class GpuJoin(tensor.Join):
final_shape
=
list
(
cndas
[
0
]
.
shape
)
final_shape
=
list
(
cndas
[
0
]
.
shape
)
final_shape
[
axis
]
=
width_sum
final_shape
[
axis
]
=
width_sum
# just to be explicit,
set -
1 for broadcastable
# just to be explicit,
check that dim=
1 for broadcastable
# dimensions
# dimensions
for
i
,
val
in
enumerate
(
node
.
outputs
[
0
]
.
type
.
broadcastable
):
for
i
,
bcastable
in
enumerate
(
node
.
outputs
[
0
]
.
type
.
broadcastable
):
if
val
:
assert
not
bcastable
or
final_shape
[
i
]
==
1
,
"Broadcastable dimension but dim != 1, this is invalid"
final_shape
[
i
]
=
-
1
rval
=
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
.
zeros
_with_pattern
(
final_shape
)
rval
=
cuda_ndarray
.
cuda_ndarray
.
CudaNdarray
.
zeros
(
final_shape
)
curpos
=
0
curpos
=
0
...
...
theano/sandbox/cuda/cuda_ndarray.cu
浏览文件 @
9b441ceb
...
@@ -249,65 +249,65 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self)
...
@@ -249,65 +249,65 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self)
}
}
// declared as a static method
// declared as a static method
(hence "dummy" is not used)
// Based on _Copy and _dimshuffle
// Based on _Copy and _dimshuffle
PyObject* CudaNdarray_Zeros
WithPattern(PyObject* dummy, PyObject* pattern
)
PyObject* CudaNdarray_Zeros
(PyObject* dummy, PyObject* shape
)
{
{
if(!PySequence_Check(
pattern
))
if(!PySequence_Check(
shape
))
{
{
PyErr_SetString(PyExc_TypeError, "
pattern
argument must be a sequence");
PyErr_SetString(PyExc_TypeError, "
shape
argument must be a sequence");
return NULL;
return NULL;
}
}
int
patlen = PySequence_Length(pattern
);
int
shplen = PySequence_Length(shape
);
if (
pat
len == 0)
if (
shp
len == 0)
{
{
PyErr_SetString(PyExc_ValueError,
PyErr_SetString(PyExc_ValueError,
"CudaNdarray_
NewWithPattern: empty pattern
");
"CudaNdarray_
Zeros: empty shape not allowed
");
return NULL;
return NULL;
}
}
//fprintf(stdout, "Pattern length: %d\n",
pat
len);
//fprintf(stdout, "Pattern length: %d\n",
shp
len);
int* newdims = (int *)malloc(sizeof(int) * 2 *
pat
len);
int* newdims = (int *)malloc(sizeof(int) * 2 *
shp
len);
if (!newdims)
if (!newdims)
{
{
PyErr_SetString(PyExc_MemoryError,
PyErr_SetString(PyExc_MemoryError,
"CudaNdarray_
NewWithPattern
: Failed to allocate temporary space");
"CudaNdarray_
Zeros
: Failed to allocate temporary space");
return NULL;
return NULL;
}
}
int* newstrides = newdims +
pat
len;
int* newstrides = newdims +
shp
len;
// strides are in number of floats, not bytes
// strides are in number of floats, not bytes
int cur_stride = 1;
int cur_stride = 1;
// start from the end to compute strides
// start from the end to compute strides
for (int i =
pat
len-1; i >= 0; --i)
for (int i =
shp
len-1; i >= 0; --i)
{
{
PyObject*
pat_el_obj = PySequence_GetItem(pattern
, i);
PyObject*
shp_el_obj = PySequence_GetItem(shape
, i);
if(
pat
_el_obj == NULL)
if(
shp
_el_obj == NULL)
{
{
// shouldn't happen since we checked length before...
// shouldn't happen since we checked length before...
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_
NewWithPattern
: Index out of bound in sequence");
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_
Zeros
: Index out of bound in sequence");
free(newdims);
free(newdims);
return NULL;
return NULL;
}
}
int
pat_el = PyInt_AsLong(pat
_el_obj);
int
shp_el = PyInt_AsLong(shp
_el_obj);
if (
pat_el =
= 0)
if (
shp_el <
= 0)
{
{
PyErr_SetString(PyExc_ValueError, "CudaNdarray_
NewWithPattern: pattern must not contain 0
for size of a dimension");
PyErr_SetString(PyExc_ValueError, "CudaNdarray_
Zeros: shape must not contain 0 (or negative value)
for size of a dimension");
free(newdims);
free(newdims);
return NULL;
return NULL;
}
}
//
apparently, from looking at
alloc_contiguous, we set
//
based on
alloc_contiguous, we set
// stride=0 if the dim == 1
// stride=0 if the dim == 1
if (
pat_el < 0 || pat
_el == 1)
if (
shp
_el == 1)
{
{
// broadcast
// broadcast
newdims[i] = 1;
newdims[i] = 1;
...
@@ -315,7 +315,7 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern)
...
@@ -315,7 +315,7 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern)
}
}
else
else
{
{
newdims[i] =
pat
_el;
newdims[i] =
shp
_el;
newstrides[i] = cur_stride;
newstrides[i] = cur_stride;
}
}
...
@@ -328,14 +328,14 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern)
...
@@ -328,14 +328,14 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern)
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_new_null();
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_new_null();
if (!rval)
if (!rval)
{
{
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_
NewWithPattern
: call to new_null failed");
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_
Zeros
: call to new_null failed");
free(newdims);
free(newdims);
return NULL;
return NULL;
}
}
if (CudaNdarray_alloc_contiguous(rval,
pat
len, newdims))
if (CudaNdarray_alloc_contiguous(rval,
shp
len, newdims))
{
{
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_
NewWithPattern
: allocation failed.");
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_
Zeros
: allocation failed.");
free(newdims);
free(newdims);
Py_DECREF(rval);
Py_DECREF(rval);
return NULL;
return NULL;
...
@@ -352,16 +352,9 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern)
...
@@ -352,16 +352,9 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern)
return NULL;
return NULL;
}
}
// change the strides to account for broadcastability
// (not necessary as alloc_contiguous sets stride=0 for dim=1)
//for (int i = 0; i < patlen; ++i)
//{
// CudaNdarray_set_stride(rval, i, newstrides[i]);
//}
if (cnda_copy_structure_to_device(rval))
if (cnda_copy_structure_to_device(rval))
{
{
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_
NewWithPattern
: syncing structure to device failed");
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_
Zeros
: syncing structure to device failed");
free(newdims);
free(newdims);
Py_DECREF(rval);
Py_DECREF(rval);
return NULL;
return NULL;
...
@@ -707,9 +700,9 @@ static PyMethodDef CudaNdarray_methods[] =
...
@@ -707,9 +700,9 @@ static PyMethodDef CudaNdarray_methods[] =
{"__deepcopy__",
{"__deepcopy__",
(PyCFunction)CudaNdarray_DeepCopy, METH_O,
(PyCFunction)CudaNdarray_DeepCopy, METH_O,
"Create a copy of this object"},
"Create a copy of this object"},
{"zeros
_with_pattern
",
{"zeros",
(PyCFunction)CudaNdarray_Zeros
WithPattern
, METH_STATIC,
(PyCFunction)CudaNdarray_Zeros, METH_STATIC,
"Create a new CudaNdarray with specified shape
and broadcastability
, filled with zeros."},
"Create a new CudaNdarray with specified shape, filled with zeros."},
{"copy",
{"copy",
(PyCFunction)CudaNdarray_Copy, METH_NOARGS,
(PyCFunction)CudaNdarray_Copy, METH_NOARGS,
"Create a copy of this object"},
"Create a copy of this object"},
...
@@ -1331,7 +1324,6 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *v)
...
@@ -1331,7 +1324,6 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *v)
return -1;
return -1;
}
}
// Check that 'v' is compatible?
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key);
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key);
if(rval == NULL)
if(rval == NULL)
...
@@ -1349,14 +1341,38 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *v)
...
@@ -1349,14 +1341,38 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *v)
Py_DECREF(rval);
Py_DECREF(rval);
return -1;
return -1;
}
}
if (cnda_copy_structure_to_device(rval))
{
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_setitem: syncing structure to device failed");
Py_DECREF(rval);
return NULL;
}
CudaNdarray *viewCopyForComparison =
(CudaNdarray*)CudaNdarray_View(rval);
PyObject *baseSavedForComparison = rval->base;
if(!viewCopyForComparison)
{
PyErr_SetString(PyExc_RuntimeError, "__setitem__ could not allocate a view to verify copy results.");
Py_DECREF((PyObject*)rval);
return -1;
}
if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)v))
if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)v))
{
{
Py_DECREF(rval);
Py_DECREF(viewCopyForComparison);
Py_DECREF((PyObject*)rval);
return -1;
return -1;
}
}
// If it fails, deallocate memory (DECREF?)
// Check that copy didn't modify shape or strides
assert (CudaNdarray_EqualAndIgnore(viewCopyForComparison, rval, 1, 1));
assert (rval->base == baseSavedForComparison);
assert (rval->dev_structure_fresh);
Py_DECREF((PyObject*)viewCopyForComparison);
return 0;
return 0;
}
}
...
...
theano/sandbox/cuda/cuda_ndarray.cuh
浏览文件 @
9b441ceb
...
@@ -125,6 +125,61 @@ cnda_mark_dev_structure_dirty(CudaNdarray * self)
...
@@ -125,6 +125,61 @@ cnda_mark_dev_structure_dirty(CudaNdarray * self)
{
{
self->dev_structure_fresh = 0;
self->dev_structure_fresh = 0;
}
}
int
CudaNdarray_EqualAndIgnore(CudaNdarray *cnda1, CudaNdarray *cnda2, int ignoreSync, int ignoreBase)
{
int verbose = 1;
if (!ignoreSync && cnda1->dev_structure_fresh != cnda2->dev_structure_fresh)
{
if(verbose) fprintf(stdout, "CUDANDARRAY_EQUAL FAILED : 1\n");
return 0;
}
if (cnda1->nd != cnda2->nd)
{
if(verbose) fprintf(stdout, "CUDANDARRAY_EQUAL FAILED : 2\n");
return 0;
}
for (int i=0; i < 2*cnda1->nd; i++)
{
if (cnda1->host_structure[i] != cnda2->host_structure[i])
{
if(verbose)
fprintf(stdout, "CUDANDARRAY_EQUAL : host_structure : %d, %d, %d\n", i, cnda1->host_structure[i], cnda2->host_structure[i]);
return 0;
}
}
if (!ignoreBase && cnda1->base != cnda2->base)
{
if(verbose) fprintf(stdout, "CUDANDARRAY_EQUAL FAILED : 4");
return 0;
}
else if (cnda1->data_allocated != cnda2->data_allocated)
{
if(verbose) fprintf(stdout, "CUDANDARRAY_EQUAL FAILED : 5");
return 0;
}
else if (cnda1->data_allocated && cnda1->devdata != cnda2->devdata)
{
if(verbose) fprintf(stdout, "CUDANDARRAY_EQUAL FAILED : 6");
// no need to check devdata if data is not allocated
return 0;
}
return 1;
}
// Default: do not ignore sync of dev and host structures in comparing, and do not ignore difference in base pointers
int
CudaNdarray_Equal(CudaNdarray *cnda1, CudaNdarray *cnda2)
{
return CudaNdarray_EqualAndIgnore(cnda1, cnda2, 0, 0);
}
/****
/****
* Set the idx'th dimension to value d.
* Set the idx'th dimension to value d.
*
*
...
...
theano/sandbox/cuda/tests/test_cuda_ndarray.py
浏览文件 @
9b441ceb
...
@@ -382,12 +382,10 @@ def test_setitem_rightvalue_ndarray_fails():
...
@@ -382,12 +382,10 @@ def test_setitem_rightvalue_ndarray_fails():
assert
True
assert
True
'''
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
test_setitem_matrixvector1
()
test_setitem_matrixvector1
()
test_setitem_matrix_tensor3
()
test_setitem_matrix_tensor3
()
test_setitem_broadcast_must_fail
()
test_setitem_broadcast_must_fail
()
test_setitem_assign_to_slice
()
test_setitem_assign_to_slice
()
test_setitem_rightvalue_ndarray_fails
()
test_setitem_rightvalue_ndarray_fails
()
'''
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论