Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8215457d
提交
8215457d
authored
2月 17, 2016
作者:
Frédéric Bastien
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4067 from abergeron/debugmode_empty
Make DebugMode handle a special version of perform.
上级
189069be
d3530b07
隐藏空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
50 行增加
和
14 行删除
+50
-14
op.txt
doc/extending/op.txt
+13
-0
debugmode.py
theano/compile/debugmode.py
+9
-7
cc.py
theano/gof/cc.py
+7
-5
op.py
theano/gof/op.py
+6
-2
basic_ops.py
theano/sandbox/cuda/basic_ops.py
+7
-0
basic_ops.py
theano/sandbox/gpuarray/basic_ops.py
+4
-0
basic.py
theano/tensor/basic.py
+4
-0
__init__.py
theano/tensor/deprecated/__init__.py
+0
-0
没有找到文件。
doc/extending/op.txt
浏览文件 @
8215457d
...
@@ -266,6 +266,19 @@ Optional methods or attributes
...
@@ -266,6 +266,19 @@ Optional methods or attributes
As done in the Alloc op, you can return False only in some cases by
As done in the Alloc op, you can return False only in some cases by
analyzing the graph from the node parameter.
analyzing the graph from the node parameter.
.. function:: debug_perform(node, inputs, output_storage)
Undefined by default.
If you define this function then it will be used instead of C code
or perform() to do the computation while debugging (currently
DebugMode, but others may also use it in the future). It has the
same signature and contract as :func:`perform`.
This enables ops that cause trouble with DebugMode with their
normal behaviour to adopt a different one when run under that
mode. If your op doesn't have any problems, don't implement this.
If you want your op to work with gradient.grad() you also need to
If you want your op to work with gradient.grad() you also need to
implement the functions described below.
implement the functions described below.
...
...
theano/compile/debugmode.py
浏览文件 @
8215457d
...
@@ -1849,8 +1849,10 @@ class _Linker(gof.link.LocalLinker):
...
@@ -1849,8 +1849,10 @@ class _Linker(gof.link.LocalLinker):
if
new_node
is
not
None
:
if
new_node
is
not
None
:
node
=
new_node
node
=
new_node
debug
=
hasattr
(
node
.
op
,
'debug_perform'
)
try
:
try
:
if
not
self
.
maker
.
mode
.
check_c_code
:
if
not
self
.
maker
.
mode
.
check_c_code
or
debug
:
raise
utils
.
MethodNotDefined
()
raise
utils
.
MethodNotDefined
()
# Ops that do not inherit from gof.op.Op don't have certain
# Ops that do not inherit from gof.op.Op don't have certain
# methods defined that the CLinker expects (Scan is an
# methods defined that the CLinker expects (Scan is an
...
@@ -1868,18 +1870,18 @@ class _Linker(gof.link.LocalLinker):
...
@@ -1868,18 +1870,18 @@ class _Linker(gof.link.LocalLinker):
# Pure ops don't really have a perform ( or their perform just
# Pure ops don't really have a perform ( or their perform just
# raises an not implemented exception), so in those cases we
# raises an not implemented exception), so in those cases we
# consider that we don't have a python implementation
# consider that we don't have a python implementation
if
(
self
.
maker
.
mode
.
check_py_code
or
thunks_c
[
-
1
]
is
None
)
and
\
if
(((
self
.
maker
.
mode
.
check_py_code
or
thunks_c
[
-
1
]
is
None
)
and
node
.
op
.
perform
.
__code__
!=
gof
.
op
.
PureOp
.
perform
.
__code__
:
node
.
op
.
perform
.
__code__
!=
gof
.
op
.
PureOp
.
perform
.
__code__
)
or
debug
):
thunk
=
node
.
op
.
make_py_thunk
(
node
,
storage_map
,
compute_map
,
thunk
=
node
.
op
.
make_py_thunk
(
node
,
storage_map
,
compute_map
,
no_recycling
)
no_recycling
,
debug
=
debug
)
thunks_py
.
append
(
thunk
)
thunks_py
.
append
(
thunk
)
else
:
else
:
thunks_py
.
append
(
None
)
thunks_py
.
append
(
None
)
if
not
self
.
maker
.
mode
.
check_c_code
and
thunks_py
[
-
1
]
is
None
:
if
not
self
.
maker
.
mode
.
check_c_code
and
thunks_py
[
-
1
]
is
None
:
_logger
.
warn
(
_logger
.
warn
(
"Op
%
s doesn't have a perform, "
"Op
%
s don't have a perform, forcing check of the c code"
%
"forcing check of the C code"
%
node
.
op
)
node
.
op
)
thunk
=
node
.
op
.
make_c_thunk
(
node
,
storage_map
,
compute_map
,
thunk
=
node
.
op
.
make_c_thunk
(
node
,
storage_map
,
compute_map
,
no_recycling
)
no_recycling
)
thunks_c
[
-
1
]
=
thunk
thunks_c
[
-
1
]
=
thunk
...
...
theano/gof/cc.py
浏览文件 @
8215457d
...
@@ -1177,11 +1177,13 @@ class CLinker(link.Linker):
...
@@ -1177,11 +1177,13 @@ class CLinker(link.Linker):
List of lists of length 1. In order to use
List of lists of length 1. In order to use
the thunk returned by __compile__, the inputs must be put in
the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated.
that storage. If None, storage will be allocated.
@param output_storage: list of lists of length 1. The thunk returned
output_storage: list of lists of length 1.
by __compile__ will put the variables of the computation in these
The thunk returned by __compile__ will put the variables
lists. If None, storage will be allocated.
of the computation in these lists. If None, storage will
@param storage_map: dict that map variables to storages. This is used
be allocated.
when you need to customize the storage of this thunk.
storage_map: dict that map variables to storages.
This is used when you need to customize the storage of
this thunk.
Returns: thunk, input_storage, output_storage
Returns: thunk, input_storage, output_storage
...
...
theano/gof/op.py
浏览文件 @
8215457d
...
@@ -890,7 +890,8 @@ class Op(utils.object2, PureOp, CLinkerOp):
...
@@ -890,7 +890,8 @@ class Op(utils.object2, PureOp, CLinkerOp):
rval
.
lazy
=
False
rval
.
lazy
=
False
return
rval
return
rval
def
make_py_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
):
def
make_py_thunk
(
self
,
node
,
storage_map
,
compute_map
,
no_recycling
,
debug
=
False
):
"""
"""
Like make_thunk() but only makes python thunks.
Like make_thunk() but only makes python thunks.
...
@@ -898,7 +899,10 @@ class Op(utils.object2, PureOp, CLinkerOp):
...
@@ -898,7 +899,10 @@ class Op(utils.object2, PureOp, CLinkerOp):
node_input_storage
=
[
storage_map
[
r
]
for
r
in
node
.
inputs
]
node_input_storage
=
[
storage_map
[
r
]
for
r
in
node
.
inputs
]
node_output_storage
=
[
storage_map
[
r
]
for
r
in
node
.
outputs
]
node_output_storage
=
[
storage_map
[
r
]
for
r
in
node
.
outputs
]
p
=
node
.
op
.
perform
if
debug
:
p
=
node
.
op
.
debug_perform
else
:
p
=
node
.
op
.
perform
params
=
node
.
run_params
()
params
=
node
.
run_params
()
...
...
theano/sandbox/cuda/basic_ops.py
浏览文件 @
8215457d
...
@@ -3682,6 +3682,13 @@ class GpuAllocEmpty(GpuOp):
...
@@ -3682,6 +3682,13 @@ class GpuAllocEmpty(GpuOp):
output
.
type
.
filter_checks_isfinite
=
False
output
.
type
.
filter_checks_isfinite
=
False
return
Apply
(
self
,
shape
,
[
output
])
return
Apply
(
self
,
shape
,
[
output
])
def
debug_perform
(
self
,
node
,
inputs
,
out_
):
self
.
perform
(
self
,
node
,
inputs
,
out_
)
# __setitem__ is limited on CudaNdarray
tmp
=
numpy
.
empty
(
out_
[
0
][
0
]
.
shape
,
dtype
=
'float32'
)
tmp
.
fill
(
-
123456789
)
out_
[
0
][
0
][:]
=
tmp
def
perform
(
self
,
node
,
inputs
,
out_
):
def
perform
(
self
,
node
,
inputs
,
out_
):
out
,
=
out_
out
,
=
out_
sh
=
tuple
([
int
(
i
)
for
i
in
inputs
])
sh
=
tuple
([
int
(
i
)
for
i
in
inputs
])
...
...
theano/sandbox/gpuarray/basic_ops.py
浏览文件 @
8215457d
...
@@ -723,6 +723,10 @@ class GpuAllocEmpty(HideC, Alloc):
...
@@ -723,6 +723,10 @@ class GpuAllocEmpty(HideC, Alloc):
output
.
type
.
filter_checks_isfinite
=
False
output
.
type
.
filter_checks_isfinite
=
False
return
Apply
(
self
,
sh
,
[
output
])
return
Apply
(
self
,
sh
,
[
output
])
def
debug_perform
(
self
,
node
,
inputs
,
out_
,
ctx
):
self
.
perform
(
node
,
inputs
,
out_
,
ctx
)
out_
[
0
][
0
][:]
=
-
123456789
def
perform
(
self
,
node
,
inputs
,
out_
,
ctx
):
def
perform
(
self
,
node
,
inputs
,
out_
,
ctx
):
out
=
out_
[
0
]
out
=
out_
[
0
]
sh
=
[
int
(
i
)
for
i
in
inputs
]
sh
=
[
int
(
i
)
for
i
in
inputs
]
...
...
theano/tensor/basic.py
浏览文件 @
8215457d
...
@@ -6240,6 +6240,10 @@ class AllocEmpty(gof.Op):
...
@@ -6240,6 +6240,10 @@ class AllocEmpty(gof.Op):
output
.
type
.
filter_checks_isfinite
=
False
output
.
type
.
filter_checks_isfinite
=
False
return
Apply
(
self
,
shape
,
[
output
])
return
Apply
(
self
,
shape
,
[
output
])
def
debug_perform
(
self
,
node
,
inputs
,
out_
):
self
.
perform
(
node
,
inputs
,
out_
)
out_
[
0
][
0
]
.
fill
(
-
123456789
)
def
perform
(
self
,
node
,
inputs
,
out_
):
def
perform
(
self
,
node
,
inputs
,
out_
):
out
,
=
out_
out
,
=
out_
sh
=
tuple
([
int
(
i
)
for
i
in
inputs
])
sh
=
tuple
([
int
(
i
)
for
i
in
inputs
])
...
...
theano/tensor/deprecated/__init__.py
deleted
100644 → 0
浏览文件 @
189069be
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论