testgroup / pytensor · Commits

Commit 2345e188, authored Sep 25, 2013 by Pascal Lamblin

    Merge pull request #1536 from nouiz/doc

    Mixed small stuff

Parents: aa7ef7fd 383edfbc

Showing 7 changed files with 116 additions and 24 deletions.
.travis.yml                               +1   -1
doc/library/sparse/index.txt              +12  -4
doc/library/tensor/nnet/nnet.txt          +14  -0
theano/sandbox/cuda/cuda_ndarray.cu       +70  -12
theano/sandbox/linalg/ops.py              +2   -0
theano/tensor/nnet/nnet.py                +2   -5
theano/tensor/nnet/tests/test_nnet.py     +15  -2
.travis.yml
@@ -35,7 +35,7 @@ script:
   - df -h
   - ulimit -a
   - echo $PART
-  - theano-nose $PART
+  - theano-nose -v $PART
 #after_script:
doc/library/sparse/index.txt
@@ -197,18 +197,21 @@ List of Implemented Operations
 - :class:`Dot <theano.sparse.basic.Dot>` and
   :func:`dot <theano.sparse.basic.dot>`.
+    - One of the inputs must be sparse, the other sparse or dense.
     - The grad implemented is regular.
     - No C code for perform and no C code for grad.
     - Return a dense for perform and a dense for grad.
 - :class:`StructuredDot <theano.sparse.basic.StructuredDot>`
   and :func:`structured_dot <theano.sparse.basic.structured_dot>`.
+    - The first input is sparse, the second can be sparse or dense.
     - The grad implemented is structured.
     - C code for perform and grad.
     - Return a dense for perform and a sparse for grad.
 - :class:`TrueDot <theano.sparse.basic.TrueDot>` and
   :func:`true_dot <theano.sparse.basic.true_dot>`.
+    - The first input is sparse, the second can be sparse or dense.
     - The grad implemented is regular.
     - No C code for perform and no C code for grad.
     - Return a Sparse for perform and a Sparse for grad.
@@ -217,17 +220,22 @@ List of Implemented Operations
 - :class:`SamplingDot <theano.sparse.basic.SamplingDot>` and
   ``sampling_dot``.
+    - Both inputs must be dense.
     - The grad implemented is structured for `p`.
     - Sample of the dot and sample of the gradient.
     - C code for perform but not for grad.
     - Return sparse for perform and grad.
 - :class:`Usmm <theano.sparse.basic.Usmm>` and ``usmm``.
+    - You *shouldn't* insert this op yourself!
+      There is an optimization that transforms a
+      :class:`Dot <theano.sparse.basic.Dot>` to ``Usmm`` when possible.
     - This op is the equivalent of gemm for sparse dot.
-    - There is no grad implemented for this op.
-    - There is an optimization that transforms a
-      :class:`Dot <theano.sparse.basic.Dot>` to ``Usmm`` when possible.
-      You shouldn't need to insert it yourself.
+    - There is no grad implemented for this op, and none is needed as
+      you don't insert it yourself.
+    - One of the inputs must be sparse, the other sparse or dense.
+    - Return a dense for perform.
 - Slice Operations
     - sparse_variable[N, N], return a tensor scalar.
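As a point of reference for the dot variants documented above, a minimal usage sketch follows. It assumes a Theano installation with SciPy support; the variable names are illustrative, and the functions are the ones the list refers to (``theano.sparse.dot``, ``theano.sparse.structured_dot``, ``theano.sparse.basic.true_dot``).

    # Minimal sketch of the sparse dot variants listed above.
    import theano
    import theano.sparse as sparse
    import theano.tensor as T
    from theano.sparse.basic import true_dot

    x = sparse.csr_matrix('x')       # sparse input
    y = T.matrix('y')                # dense input

    d = sparse.dot(x, y)             # Dot: dense result, regular grad
    s = sparse.structured_dot(x, y)  # StructuredDot: dense result, structured grad
    t = true_dot(x, y)               # TrueDot: sparse result

    f = theano.function([x, y], [d, s, t])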
doc/library/tensor/nnet/nnet.txt
@@ -116,6 +116,20 @@
 The softmax function will, when applied to a matrix, compute the softmax values row-wise.
+
+:note: this inserts a particular op. But this op doesn't yet
+    implement the Rop needed for Hessian-free optimization. If you
+    want that, write the equivalent code that has the Rop implemented:
+    ``exp(x)/exp(x).sum(1, keepdims=True)``. Theano should
+    optimize this by inserting the softmax op itself. The softmax op
+    is more numerically stable, as it uses this code:
+
+    .. code-block:: python
+
+        e_x = exp(x - x.max(axis=1, keepdims=True))
+        out = e_x / e_x.sum(axis=1, keepdims=True)
+
 Example of use:

 .. code-block:: python

     x, y, b = T.dvectors('x', 'y', 'b')
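Putting the note above into a complete graph, here is a minimal sketch of the Rop-friendly, numerically stable formulation (it assumes a Theano version whose reductions accept ``keepdims``, which the tests in this same commit rely on):

    # Row-wise softmax written out explicitly; Theano's optimizer
    # should replace this subgraph with the softmax op.
    import theano
    import theano.tensor as T

    x = T.matrix('x')
    e_x = T.exp(x - x.max(axis=1, keepdims=True))  # subtract row max for stability
    out = e_x / e_x.sum(axis=1, keepdims=True)
    softmax_fn = theano.function([x], out)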
theano/sandbox/cuda/cuda_ndarray.cu
@@ -64,6 +64,14 @@ void * device_malloc(size_t size)
 void * device_malloc(size_t size, int verbose)
 {
+    #if PRECHECK_ERROR
+        cudaThreadSynchronize();
+        cudaError_t prevError = cudaGetLastError();
+        if (cudaSuccess != prevError)
+        {
+            fprintf(stderr, "Error existed before calling device_malloc.\n");
+        }
+    #endif
     void * rval = NULL;
     cudaError_t err = cudaMalloc(&rval, size);
     if (cudaSuccess != err)
@@ -81,7 +89,7 @@ void * device_malloc(size_t size, int verbose)
         cudaGetLastError();
         fprintf(stderr,
                 "Error when trying to find the memory information"
-                " on the GPU\n");
+                " on the GPU: %s\n", cudaGetErrorString(err2));
     }
 #if COMPUTE_GPU_MEM_USED
     fprintf(stderr,
@@ -98,7 +106,8 @@ void * device_malloc(size_t size, int verbose)
 #endif
     }
     PyErr_Format(PyExc_MemoryError,
                  "Error allocating %li bytes of device memory (%s).",
                  (long)size, cudaGetErrorString(err));
     return NULL;
 }
 if (rval != NULL){
@@ -109,14 +118,19 @@ void * device_malloc(size_t size, int verbose)
 #if COMPUTE_GPU_MEM_USED
     _allocated_size += size;
     _max_allocated_size = std::max(_max_allocated_size, _allocated_size);
-    for(int i=0; i<TABLE_SIZE; i++){
+    int i = 0;
+    for(; i<TABLE_SIZE; i++){
         if(NULL == _alloc_size_table[i].ptr){
             _alloc_size_table[i].ptr = rval;
             _alloc_size_table[i].size = size;
             break;
         }
     }
+    if (i == TABLE_SIZE){
+        fprintf(stderr,
+                "When tracking GPU malloc, our table size wasn't big enough."
+                " So we lose some tracking. Raise the value of TABLE_SIZE in the file cuda_ndarray.cu");
+    }
 #endif
 }
 //fprintf(stderr,
@@ -129,23 +143,48 @@ void * device_malloc(size_t size, int verbose)
         //printf("MEMSET\n");
     }
     #if PRINT_FREE_MALLOC
-        fprintf(stderr, "device malloc %p\n", rval);
+        fprintf(stderr, "device malloc %p of size %d\n", rval, size);
     #endif
     return rval;
 }

 int device_free(void *ptr)
 {
-    #if PRINT_FREE_MALLOC
-        fprintf(stderr, "device_free %p\n", ptr);
-    #endif
     #if PRECHECK_ERROR
         cudaThreadSynchronize();
         cudaError_t prevError = cudaGetLastError();
         if (cudaSuccess != prevError)
         {
             fprintf(stderr, "Error existed before calling device_free.\n");
         }
     #endif
+    #if PRINT_FREE_MALLOC
+        size_t free = 0, total = 0;
+        cudaError_t err2 = cudaMemGetInfo(&free, &total);
+        if(err2 != cudaSuccess){
+            cudaGetLastError();
+            fprintf(stderr,
+                    "Error when trying to find the memory information"
+                    " on the GPU: %s\n", cudaGetErrorString(err2));
+        }
+    #if COMPUTE_GPU_MEM_USED
+        {
+            int i = 0;
+            for(; i<TABLE_SIZE; i++)
+                if(_alloc_size_table[i].ptr == ptr){
+                    break;
+                }
+            assert(i < TABLE_SIZE);
+            fprintf(stderr, "device_free %p of size %d."
+                    " Driver report %d bytes free and %d bytes total\n",
+                    ptr, _alloc_size_table[i].size, free, total);
+        }
+    #else
+        fprintf(stderr, "device_free %p."
+                " Driver report %d bytes free and %d bytes total\n",
+                ptr, free, total);
+    #endif
+    #endif
     // if there is no gpu context, the call to cudaFree will fail; skip it entirely
     if(!g_gpu_context_active) {
@@ -164,15 +203,34 @@ int device_free(void *ptr)
     // it returns something else I still don't see why we should ignore
     // it. All we want to do here is reset the flag.
     cudaGetLastError();
+    size_t free = 0, total = 0;
+    cudaError_t err2 = cudaMemGetInfo(&free, &total);
+    if(err2 != cudaSuccess){
+        cudaGetLastError();
+        fprintf(stderr,
+                "Error when trying to find the memory information"
+                " on the GPU: %s\n", cudaGetErrorString(err2));
+    }
 #if COMPUTE_GPU_MEM_USED
+    {
+        int i = 0;
+        for(; i<TABLE_SIZE; i++)
+            if(_alloc_size_table[i].ptr == ptr){
+                break;
+            }
+        assert(i < TABLE_SIZE);
         fprintf(stderr,
-                "Error freeing device pointer %p (%s). %d bytes already allocated\n",
-                ptr, cudaGetErrorString(err), _allocated_size);
+                "Error freeing device pointer %p (%s) of size %d. %d bytes already allocated."
+                " Driver report %d bytes free and %d bytes total\n",
+                ptr, cudaGetErrorString(err),
+                _alloc_size_table[i].size, _allocated_size, free, total);
+    }
 #else
     fprintf(stderr,
-            "Error freeing device pointer %p (%s).\n",
-            ptr, cudaGetErrorString(err));
+            "Error freeing device pointer %p (%s)."
+            " Driver report %d bytes free and %d bytes total\n",
+            ptr, cudaGetErrorString(err), free, total);
 #endif
     PyErr_Format(PyExc_MemoryError,
                  "error freeing device pointer %p (%s)",
theano/sandbox/linalg/ops.py
@@ -963,6 +963,7 @@ class Eigh(Eig):
     _numop = staticmethod(numpy.linalg.eigh)

     def __init__(self, UPLO='L'):
+        assert UPLO in ['L', 'U']
         self.UPLO = UPLO

     def __str__(self):
@@ -1031,6 +1032,7 @@ class EighGrad(Op):
     """

     def __init__(self, UPLO='L'):
+        assert UPLO in ['L', 'U']
         self.UPLO = UPLO
         if UPLO == 'L':
             self.tri0 = numpy.tril
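A quick usage sketch of the new check (the values are illustrative; per the diff, `Eigh` lives in `theano.sandbox.linalg.ops`):

    from theano.sandbox.linalg.ops import Eigh

    Eigh('L')  # ok: use the lower triangle (the default)
    Eigh('U')  # ok: use the upper triangle
    try:
        Eigh('X')  # now fails fast instead of propagating to numpy
    except AssertionError:
        pass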
theano/tensor/nnet/nnet.py
@@ -360,11 +360,8 @@ class Softmax(gof.Op):

     def perform(self, node, input_storage, output_storage):
         x, = input_storage
-        sm = numpy.zeros_like(x)
-        for i in xrange(sm.shape[0]):
-            row = x[i]
-            sm[i] = numpy.exp(row - numpy.max(row))
-            sm[i] /= numpy.sum(sm[i])
+        e_x = numpy.exp(x - x.max(axis=1)[:, None])
+        sm = e_x / e_x.sum(axis=1)[:, None]
         output_storage[0][0] = sm

     def grad(self, inp, grads):
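The rewrite replaces a Python-level loop over rows with one vectorized expression. A minimal numpy sketch (illustrative, not part of the commit) checking that the two computations agree:

    import numpy

    x = numpy.random.rand(3, 4)

    # old: row-by-row loop
    sm = numpy.zeros_like(x)
    for i in range(x.shape[0]):
        row = x[i]
        sm[i] = numpy.exp(row - numpy.max(row))
        sm[i] /= numpy.sum(sm[i])

    # new: vectorized over the whole matrix
    e_x = numpy.exp(x - x.max(axis=1)[:, None])
    sm_vec = e_x / e_x.sum(axis=1)[:, None]

    assert numpy.allclose(sm, sm_vec)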
theano/tensor/nnet/tests/test_nnet.py
@@ -8,9 +8,8 @@ from theano import config
 from theano import tensor as T
 from theano import tensor
 from theano import gof
-from theano.gof.python25 import all
 from theano.tests import unittest_tools as utt
-from theano import printing, pprint
+from theano import printing
 from theano.tensor.nnet import (categorical_crossentropy,
                                 crossentropy_categorical_1hot,
                                 crossentropy_softmax_1hot,
@@ -1270,6 +1269,20 @@ class Test_softmax_opt:
         assert softmax in f_ops
         f(self.rng.rand(3, 4).astype(config.floatX))

+    def test_basic_keepdims(self):
+        c = T.matrix()
+        p_y = T.exp(c) / T.exp(c).sum(axis=1, keepdims=True)
+
+        # test that function contains softmax and no div.
+        f = theano.function([c], p_y, mode=self.mode)
+        f_ops = [n.op for n in f.maker.fgraph.toposort()]
+        #print '--- f ='
+        #printing.debugprint(f)
+        #print '==='
+        assert len(f_ops) == 1
+        assert softmax in f_ops
+        f(self.rng.rand(3, 4).astype(config.floatX))
+
     def test_grad(self):
         c = T.matrix()
         p_y = T.exp(c) / T.exp(c).sum(axis=1).dimshuffle(0, 'x')
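The new test mirrors the existing one but spells the row-wise normalization with ``keepdims=True`` instead of ``dimshuffle(0, 'x')``; the optimizer should rewrite both subgraphs into a single softmax op. A minimal sketch of the two equivalent formulations (illustrative):

    import theano.tensor as T

    c = T.matrix()
    p_keepdims = T.exp(c) / T.exp(c).sum(axis=1, keepdims=True)
    p_dimshuffle = T.exp(c) / T.exp(c).sum(axis=1).dimshuffle(0, 'x')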