testgroup / pytensor · Commits

Commit 627bd58e, authored May 19, 2011 by Frederic Bastien

Removed old function CudaNdarray_new_null() as it is deprecated. Now use CudaNdarray_New()

Parent: b5bdebb8
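For context: the deleted function was a thin wrapper whose entire body was "return CudaNdarray_New(-1);" (its definition is removed in the cuda_ndarray.cu hunk at line 2450 below), and cuda_ndarray.cuh now declares CudaNdarray_New with a default argument of -1, so the zero-argument calls introduced here behave exactly like the old wrapper. The c_code_cache_version bumps in each touched Python Op force Theano to recompile cached C modules that still embed the old call. A minimal sketch of the call-site migration, using only names that appear in this diff:

    // Before this commit: allocate an empty (nd == -1) array through the
    // deprecated wrapper.
    rval = CudaNdarray_new_null();

    // After this commit: call CudaNdarray_New() directly. Its new default
    // argument (int nd=-1) makes the zero-argument call equivalent to
    // CudaNdarray_New(-1), which is exactly what the wrapper forwarded to.
    rval = CudaNdarray_New();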
Showing 6 changed files with 32 additions and 43 deletions (+32 -43)

theano/sandbox/cuda/basic_ops.py      +2  -2
theano/sandbox/cuda/blas.py           +8  -8
theano/sandbox/cuda/cuda_ndarray.cu   +12 -18
theano/sandbox/cuda/cuda_ndarray.cuh  +2  -7
theano/sandbox/cuda/elemwise.py       +2  -2
theano/sandbox/cuda/nnet.py           +6  -6
theano/sandbox/cuda/basic_ops.py

@@ -1934,7 +1934,7 @@ class GpuAlloc(Op):
             str += "||CudaNdarray_HOST_DIMS(%(out)s)[%(idx)s]!=dims[%(idx)s]" % locals()
         str += """){
             Py_XDECREF(%(out)s);
-            %(out)s= (CudaNdarray*)CudaNdarray_new_null();
+            %(out)s= (CudaNdarray*)CudaNdarray_New();
             CudaNdarray_alloc_contiguous(%(out)s, %(nd)s, dims);
         }
         if (CudaNdarray_CopyFromCudaNdarray(%(out)s, %(value)s, true))
@@ -1952,7 +1952,7 @@ class GpuAlloc(Op):
         return [None for i in inputs]

     def c_code_cache_version(self):
-        return (2,)
+        return (3,)

 gpu_alloc = GpuAlloc()
theano/sandbox/cuda/blas.py

@@ -22,7 +22,7 @@ class GpuDot22(Op):
         return Apply(self, [x, y], [x.type()])

     def c_code_cache_version(self):
-        return (1, 0)
+        return (1, 1)

     def c_code(self, node, nodename, inputs, outputs, sub):
         x, y = inputs
@@ -48,7 +48,7 @@ class GpuDot22(Op):
         npy_intp dims[2];
         dims[0] = CudaNdarray_HOST_DIMS(%(x)s)[0];
         dims[1] = CudaNdarray_HOST_DIMS(%(y)s)[1];
-        %(z)s = (CudaNdarray*)CudaNdarray_new_null();
+        %(z)s = (CudaNdarray*)CudaNdarray_New();
         if ((NULL == %(z)s) || CudaNdarray_alloc_contiguous(%(z)s, 2, dims))
         {
             if (%(z)s)
@@ -90,7 +90,7 @@ class GpuDot22Scalar(Op):
         return Apply(self, [x, y, a], [x.type()])

     def c_code_cache_version(self):
-        return (1, 0)
+        return (1, 1)

     def c_code(self, node, name, inputs, outputs, sub):
         x, y, a = inputs
@@ -122,7 +122,7 @@ class GpuDot22Scalar(Op):
         npy_intp dims[2];
         dims[0] = CudaNdarray_HOST_DIMS(%(x)s)[0];
         dims[1] = CudaNdarray_HOST_DIMS(%(y)s)[1];
-        %(z)s = (CudaNdarray*)CudaNdarray_new_null();
+        %(z)s = (CudaNdarray*)CudaNdarray_New();
         if ((NULL == %(z)s) || CudaNdarray_alloc_contiguous(%(z)s, 2, dims))
         {
             if (%(z)s)
@@ -436,7 +436,7 @@ class GpuDownsampleFactorMax(Op):
     #def perform(self, node, input_storage, output_storage):
     #    raise NotImplementedError('only C is implemented')

     def c_code_cache_version(self):
-        return (2)
+        return (3)

     def c_code(self, node, nodename, inp, out, sub):
         x, = inp
         z, = out
@@ -473,7 +473,7 @@ class GpuDownsampleFactorMax(Op):
             || (CudaNdarray_HOST_DIMS(%(z)s)[3] != dims[3]))
         {
             Py_XDECREF(%(z)s);
-            %(z)s = (CudaNdarray*)CudaNdarray_new_null();
+            %(z)s = (CudaNdarray*)CudaNdarray_New();
             if ((NULL == %(z)s)
                 || CudaNdarray_alloc_contiguous(%(z)s, 4, dims))
             {
@@ -588,7 +588,7 @@ class GpuDownsampleFactorMaxGrad(Op):
         return Apply(self, [x, z, gz], [x.type()])

     def c_code_cache_version(self):
         #return ()
-        return (4,)
+        return (5,)

     def c_code(self, node, nodename, inp, out, sub):
         x, z, gz = inp
@@ -611,7 +611,7 @@ class GpuDownsampleFactorMaxGrad(Op):
             || (CudaNdarray_HOST_DIMS(%(gx)s)[3] != CudaNdarray_HOST_DIMS(%(x)s)[3]))
         {
             Py_XDECREF(%(gx)s);
-            %(gx)s = (CudaNdarray*)CudaNdarray_new_null();
+            %(gx)s = (CudaNdarray*)CudaNdarray_New();
             if ((NULL == %(gx)s)
                 || CudaNdarray_alloc_contiguous(%(gx)s, 4, CudaNdarray_HOST_DIMS(%(x)s)))
             {
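All of the C fragments in this file share one output-reallocation pattern, which this commit merely retargets from new_null() to New(). Distilled into a single sketch (the variable names z and dims are illustrative, not from any one hunk; the API calls are exactly those used above):

    // If the existing output is missing or has the wrong shape, drop it
    // and allocate a fresh contiguous array.
    if ((NULL == z)
        || (CudaNdarray_HOST_DIMS(z)[0] != dims[0])
        || (CudaNdarray_HOST_DIMS(z)[1] != dims[1]))
    {
        Py_XDECREF(z);                        // safe no-op when z is NULL
        z = (CudaNdarray*)CudaNdarray_New();  // empty array, nd == -1
        if ((NULL == z) || CudaNdarray_alloc_contiguous(z, 2, dims))
        {
            // Allocation failed: release the half-built array, then the
            // surrounding Op code sets a Python error and bails out.
            Py_XDECREF(z);
            z = NULL;
        }
    }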
theano/sandbox/cuda/cuda_ndarray.cu

@@ -350,7 +350,7 @@ PyObject* CudaNdarray_ZEROS(int n, int * dims)
     // total_elements now contains the size of the array, in reals
     int total_size = total_elements * sizeof(real);
-    CudaNdarray* rval = (CudaNdarray*)CudaNdarray_new_null();
+    CudaNdarray* rval = (CudaNdarray*)CudaNdarray_New();
     if (!rval)
     {
         PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_ZEROS: call to new_null failed");
@@ -448,7 +448,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
 PyObject * CudaNdarray_Copy(CudaNdarray * self)
 {
-    PyObject * rval = CudaNdarray_new_null();
+    PyObject * rval = CudaNdarray_New();
     if ((!rval) || (-1 == self->nd))
     {
         return rval;
@@ -509,7 +509,7 @@ PyObject * CudaNdarray_ReduceSum(CudaNdarray * self, PyObject * py_reduce_mask)
         PyErr_SetString(PyExc_TypeError, "length of reduce_mask must match self->nd");
         return NULL;
     }
-    CudaNdarray * self_sum = (CudaNdarray*)CudaNdarray_new_null();
+    CudaNdarray * self_sum = (CudaNdarray*)CudaNdarray_New();
     if (!self_sum)
     {
         return NULL;
@@ -666,9 +666,8 @@ PyObject * CudaNdarray_Reshape(CudaNdarray * self, PyObject * shape)
     }

     // allocate new space (TODO: test to see if we can re-use old one)
-    CudaNdarray * rval = (CudaNdarray * )CudaNdarray_new_null();
-    if (!rval || CudaNdarray_alloc_contiguous(rval, rval_nd, rval_dims))
-    {
+    CudaNdarray * rval = (CudaNdarray * )CudaNdarray_New();
+    if (!rval || CudaNdarray_alloc_contiguous(rval, rval_nd, rval_dims)){
         Py_XDECREF(rval);
         free(rval_dims);
         return NULL;
@@ -754,7 +753,7 @@ PyObject * CudaNdarray_SetShapeI(CudaNdarray * self, PyObject *args)
 static PyObject *
 CudaNdarray_exp(CudaNdarray* self)
 {
-    CudaNdarray * rval = (CudaNdarray *)CudaNdarray_new_null();
+    CudaNdarray * rval = (CudaNdarray *)CudaNdarray_New();
     if ((NULL == rval) || CudaNdarray_alloc_contiguous(rval, self->nd, CudaNdarray_HOST_DIMS(self)))
     {
         Py_XDECREF(rval);
@@ -872,7 +871,7 @@ CudaNdarray_add(PyObject* py_self, PyObject * py_other)
         }
         size *= (unsigned int) CudaNdarray_HOST_DIMS(self)[i];
     }
-    CudaNdarray * rval = (CudaNdarray *)CudaNdarray_new_null();
+    CudaNdarray * rval = (CudaNdarray *)CudaNdarray_New();
     if (!rval || CudaNdarray_alloc_contiguous(rval, self->nd, CudaNdarray_HOST_DIMS(self)))
     {
         Py_XDECREF(rval);
@@ -2061,7 +2060,7 @@ CudaNdarray_from_gpu_pointer(PyObject* _unused, PyObject* args)
         return NULL;
     }
-    rval = CudaNdarray_new_null();
+    rval = CudaNdarray_New();
     if (CudaNdarray_set_nd((CudaNdarray *)rval, nd))
     {
@@ -2136,7 +2135,7 @@ CudaNdarray_Dot(PyObject* _unused, PyObject* args)
         PyErr_SetString(PyExc_TypeError, "need 2d CudaNdarray arg for now");
         goto CudaNdarray_dot_fail;
     }
-    rval = CudaNdarray_new_null();
+    rval = CudaNdarray_New();
     if (!rval)
     {
         goto CudaNdarray_dot_fail;
@@ -2246,7 +2245,7 @@ filter(PyObject* __unsed_self, PyObject *args) // args = (data, broadcastable, s
     }
     else
     {
-        rval = (CudaNdarray*) CudaNdarray_new_null();
+        rval = (CudaNdarray*) CudaNdarray_New();
     }
     if (rval)
     {
@@ -2450,16 +2449,11 @@ CudaNdarray_is_c_contiguous(const CudaNdarray * self)
     }
     return c_contiguous;
 }

-PyObject *
-CudaNdarray_new_null()
-{
-    //TODO: this function is deprecated... do not use. Consider removing.
-    return CudaNdarray_New(-1);
-}
-
 PyObject *
 CudaNdarray_new_nd(int nd)
 {
-    CudaNdarray * rval = (CudaNdarray*) CudaNdarray_new_null();
+    CudaNdarray * rval = (CudaNdarray*) CudaNdarray_New();
     if (!rval || CudaNdarray_set_nd(rval, nd))
     {
         Py_XDECREF(rval);
theano/sandbox/cuda/cuda_ndarray.cuh

@@ -81,7 +81,7 @@ struct CudaNdarray
  * Return a CudaNdarray whose 'nd' dimensions are all 0.
  */
 PyObject *
-CudaNdarray_New(int nd);
+CudaNdarray_New(int nd=-1);

 /**
  * Return 1 for a CudaNdarray otw 0
@@ -296,11 +296,6 @@ CudaNdarray_SIZE_Object(const CudaNdarray *self, void *closure)
 }

-/**
- * Allocate a new CudaNdarray with nd==-1
- */
-PyObject * CudaNdarray_new_null();
-
 /**
  * Allocate a new CudaNdarray with room for given number of dimensions
  *
@@ -424,7 +419,7 @@ template<typename inttype>
 PyObject *
 CudaNdarray_NewDims(int nd, const inttype * dims)
 {
-    CudaNdarray * rval = (CudaNdarray*)CudaNdarray_new_null();
+    CudaNdarray * rval = (CudaNdarray*)CudaNdarray_New();
     if (rval)
     {
         if (CudaNdarray_alloc_contiguous(rval, nd, dims))
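The header change above is what makes the wrapper removable: a C++ default argument is substituted at each call site, so CudaNdarray_New() now compiles to exactly CudaNdarray_New(-1). A self-contained illustration of the mechanism (the function "make" is a hypothetical stand-in, not part of the codebase):

    #include <cstdio>

    // Stand-in for: PyObject * CudaNdarray_New(int nd=-1);
    static int make(int nd = -1)
    {
        return nd;
    }

    int main()
    {
        // The zero-argument call picks up the default, so both calls
        // return -1 -- the same value the deleted new_null() forwarded.
        std::printf("%d %d\n", make(), make(-1));   // prints: -1 -1
        return 0;
    }

Because the default lives in the declaration, every translation unit that includes cuda_ndarray.cuh sees it, which is why none of the .cu or generated-code call sites need an explicit argument.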
theano/sandbox/cuda/elemwise.py

@@ -37,7 +37,7 @@ def get_str_list_logical_scalar(node, value_str='ii_i%i_value', data_str='ii_i%i
 class NaiveAlgo(object):
     verbose = 0 # 1, 2 or 3 for more verbose output.
     cache_version = ()
-    cache_version = ('debug', 13, verbose)
+    cache_version = ('debug', 14, verbose)

     def __init__(self, scalar_op, sync=True, inplace_pattern={}):
         """
@@ -888,7 +888,7 @@ nd_collapse_[i]=0;
     }
     if (NULL == %(oname)s)
     {
-        %(oname)s = (CudaNdarray*)CudaNdarray_new_null();
+        %(oname)s = (CudaNdarray*)CudaNdarray_New();
         if (!%(oname)s)
         {
             //error string already set
theano/sandbox/cuda/nnet.py

@@ -191,7 +191,7 @@ class GpuCrossentropySoftmax1HotWithBiasDx (Op):
     def make_node(self, dy, sm, y_idx):
         return Apply(self, [dy, sm, y_idx],[sm.type()])

     def c_code_cache_version(self):
-        return (3,)
+        return (4,)
         #return ()

     def c_code(self, node, nodename, inp, out, sub):
         dnll, sm, y_idx = inp
@@ -221,7 +221,7 @@ class GpuCrossentropySoftmax1HotWithBiasDx (Op):
         || (CudaNdarray_HOST_DIMS(%(dx)s)[1] != CudaNdarray_HOST_DIMS(%(sm)s)[1]))
     {
         Py_XDECREF(%(dx)s);
-        %(dx)s = (CudaNdarray*)CudaNdarray_new_null();
+        %(dx)s = (CudaNdarray*)CudaNdarray_New();
         if ((NULL == %(dx)s)
             || CudaNdarray_alloc_contiguous(%(dx)s, 2, CudaNdarray_HOST_DIMS(%(sm)s)))
         {
@@ -309,7 +309,7 @@ class GpuSoftmax (Op):
         return shape

     def c_code_cache_version(self):
         #return ()
-        return (2,) + inline_softmax.code_version
+        return (3,) + inline_softmax.code_version

     def c_code(self, node, nodename, inp, out, sub):
         x, = inp
         z, = out
@@ -325,7 +325,7 @@ class GpuSoftmax (Op):
         || (CudaNdarray_HOST_DIMS(%(z)s)[1] != CudaNdarray_HOST_DIMS(%(x)s)[1]))
     {
         Py_XDECREF(%(z)s);
-        %(z)s = (CudaNdarray*)CudaNdarray_new_null();
+        %(z)s = (CudaNdarray*)CudaNdarray_New();
         if ((NULL == %(z)s)
             || CudaNdarray_alloc_contiguous(%(z)s, 2, CudaNdarray_HOST_DIMS(%(x)s)))
         {
@@ -398,7 +398,7 @@ class GpuSoftmaxWithBias (Op):
         return [shape[0]]

     def c_code_cache_version(self):
         #return ()
-        return (2,) + inline_softmax.code_version
+        return (3,) + inline_softmax.code_version

     def c_code(self, node, nodename, inp, out, sub):
         x, b = inp
@@ -426,7 +426,7 @@ class GpuSoftmaxWithBias (Op):
         || (CudaNdarray_HOST_DIMS(%(z)s)[1] != CudaNdarray_HOST_DIMS(%(x)s)[1]))
     {
         Py_XDECREF(%(z)s);
-        %(z)s = (CudaNdarray*)CudaNdarray_new_null();
+        %(z)s = (CudaNdarray*)CudaNdarray_New();
         if ((NULL == %(z)s)
             || CudaNdarray_alloc_contiguous(%(z)s, 2, CudaNdarray_HOST_DIMS(%(x)s)))
         {