Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
5bb459cf
提交
5bb459cf
authored
12月 17, 2020
作者:
Michael Osthege
提交者:
Brandon T. Willard
12月 21, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Migrate all of theano.gof.link to theano.link sub-package
上级
8bf588af
隐藏空白字符变更
内嵌
并排
正在显示
15 个修改的文件
包含
637 行增加
和
615 行删除
+637
-615
debugging_with_stepmode.txt
doc/sandbox/debugging_with_stepmode.txt
+1
-1
test_cc.py
tests/gof/test_cc.py
+1
-1
test_link.py
tests/gof/test_link.py
+1
-2
test_opt_uncanonicalize.py
tests/tensor/test_opt_uncanonicalize.py
+2
-1
debugmode.py
theano/compile/debugmode.py
+5
-4
types.py
theano/compile/function/types.py
+2
-2
__init__.py
theano/gof/__init__.py
+8
-2
cc.py
theano/gof/cc.py
+2
-2
vm.py
theano/gof/vm.py
+1
-3
__init__.py
theano/link/__init__.py
+17
-1
basic.py
theano/link/basic.py
+539
-7
debugging.py
theano/link/debugging.py
+50
-581
jax_linker.py
theano/link/jax/jax_linker.py
+2
-2
op.py
theano/scan/op.py
+3
-3
scan_perform.pyx
theano/scan/scan_perform.pyx
+3
-3
没有找到文件。
doc/sandbox/debugging_with_stepmode.txt
浏览文件 @
5bb459cf
...
...
@@ -11,7 +11,7 @@ purpose of it is to hack it to investigate what your own particular program is d
.. code-block:: python
from theano.
gof.
link import WrapLinkerMany
from theano.link import WrapLinkerMany
from theano import config
from theano.compile.mode import (Mode, register_mode, predefined_modes, predefined_linkers,
predefined_optimizers)
...
...
tests/gof/test_cc.py
浏览文件 @
5bb459cf
...
...
@@ -5,9 +5,9 @@ import theano
from
theano.gof
import
fg
from
theano.gof.cc
import
CLinker
,
DualLinker
,
OpWiseCLinker
from
theano.gof.graph
import
Apply
,
Constant
,
Variable
from
theano.gof.link
import
PerformLinker
from
theano.gof.op
import
Op
from
theano.gof.type
import
Type
from
theano.link
import
PerformLinker
def
as_variable
(
x
):
...
...
tests/gof/test_link.py
浏览文件 @
5bb459cf
...
...
@@ -5,10 +5,9 @@ import numpy as np
import
theano
from
theano.gof
import
fg
,
graph
from
theano.gof.graph
import
Apply
,
Constant
,
Variable
from
theano.gof.link
import
PerformLinker
,
WrapLinker
from
theano.gof.op
import
Op
from
theano.gof.type
import
Type
from
theano.link
import
Container
from
theano.link
import
Container
,
PerformLinker
,
WrapLinker
from
theano.utils
import
cmp
...
...
tests/tensor/test_opt_uncanonicalize.py
浏览文件 @
5bb459cf
...
...
@@ -6,6 +6,7 @@ from tests import unittest_tools as utt
from
theano
import
config
,
function
,
scalar
from
theano.gof
import
FunctionGraph
from
theano.gof.opt
import
out2in
from
theano.link.basic
import
PerformLinker
# from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg
from
theano.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
...
...
@@ -158,7 +159,7 @@ def test_local_dimshuffle_alloc():
g
=
FunctionGraph
([
x
],
[
out
])
reshape_dimshuffle
(
g
)
l
=
theano
.
gof
.
PerformLinker
()
l
=
PerformLinker
()
l
.
accept
(
g
)
f
=
l
.
make_function
()
...
...
theano/compile/debugmode.py
浏览文件 @
5bb459cf
...
...
@@ -28,7 +28,8 @@ from theano.compile.function.types import (
from
theano.compile.mode
import
Mode
,
register_mode
from
theano.compile.ops
import
OutputGuard
,
_output_guard
from
theano.gof
import
graph
,
ops_with_inner_function
,
utils
from
theano.gof.link
import
raise_with_op
from
theano.link.basic
import
LocalLinker
from
theano.link.debugging
import
raise_with_op
from
theano.utils
import
get_unbound_function
...
...
@@ -1739,14 +1740,14 @@ class _DummyLinker:
return
self
class
_Linker
(
link
.
LocalLinker
):
class
_Linker
(
LocalLinker
):
"""
Special debugging linker.
"""
def
__init__
(
self
,
maker
,
schedule
=
None
):
super
(
gof
.
LocalLinker
,
self
)
.
__init__
()
super
()
.
__init__
()
self
.
fgraph
=
None
self
.
maker
=
maker
super
()
.
__init__
(
scheduler
=
schedule
)
...
...
@@ -1792,7 +1793,7 @@ class _Linker(link.LocalLinker):
# the function's outputs will always be freshly allocated.
no_recycling
=
[]
input_storage
,
output_storage
,
storage_map
=
theano
.
gof
.
link
.
map_storage
(
input_storage
,
output_storage
,
storage_map
=
theano
.
link
.
map_storage
(
fgraph
,
order
,
input_storage_
,
output_storage_
,
storage_map
)
...
...
theano/compile/function/types.py
浏览文件 @
5bb459cf
...
...
@@ -15,7 +15,7 @@ import numpy as np
import
theano
import
theano.compile.profiling
from
theano
import
config
,
gof
from
theano
import
config
,
gof
,
link
from
theano.compile.io
import
In
,
SymbolicInput
,
SymbolicOutput
from
theano.compile.ops
import
deep_copy_op
,
view_op
from
theano.gof
import
graph
...
...
@@ -978,7 +978,7 @@ class Function:
thunk
=
None
if
hasattr
(
self
.
fn
,
"thunks"
):
thunk
=
self
.
fn
.
thunks
[
self
.
fn
.
position_of_error
]
gof
.
link
.
raise_with_op
(
link
.
raise_with_op
(
self
.
maker
.
fgraph
,
node
=
self
.
fn
.
nodes
[
self
.
fn
.
position_of_error
],
thunk
=
thunk
,
...
...
theano/gof/__init__.py
浏览文件 @
5bb459cf
...
...
@@ -5,7 +5,6 @@ from theano.gof.cc import CLinker, DualLinker, HideC, OpWiseCLinker
from
theano.gof.destroyhandler
import
DestroyHandler
from
theano.gof.fg
import
FunctionGraph
,
InconsistencyError
,
MissingInputError
from
theano.gof.graph
import
Apply
,
Constant
,
Variable
,
view_roots
from
theano.gof.link
import
PerformLinker
,
WrapLinker
,
WrapLinkerMany
from
theano.gof.op
import
(
COp
,
Op
,
...
...
@@ -47,7 +46,14 @@ from theano.gof.toolbox import (
)
from
theano.gof.type
import
CEnumType
,
EnumList
,
EnumType
,
Generic
,
Type
,
generic
from
theano.gof.utils
import
MethodNotDefined
,
hashtype
,
object2
from
theano.link
import
Container
,
Linker
,
LocalLinker
from
theano.link
import
(
Container
,
Linker
,
LocalLinker
,
PerformLinker
,
WrapLinker
,
WrapLinkerMany
,
)
if
theano
.
config
.
cmodule__preload_cache
:
...
...
theano/gof/cc.py
浏览文件 @
5bb459cf
...
...
@@ -11,8 +11,8 @@ from io import StringIO
import
numpy
as
np
from
theano
import
config
from
theano.gof
import
cmodule
,
graph
,
link
,
utils
from
theano
import
config
,
link
from
theano.gof
import
cmodule
,
graph
,
utils
from
theano.gof.callcache
import
CallCache
from
theano.gof.compilelock
import
get_lock
,
release_lock
...
...
theano/gof/vm.py
浏览文件 @
5bb459cf
...
...
@@ -13,9 +13,7 @@ import warnings
from
collections
import
defaultdict
import
theano.gof.cmodule
from
theano
import
config
from
.
import
link
from
theano
import
config
,
link
logger
=
logging
.
getLogger
(
__name__
)
...
...
theano/link/__init__.py
浏览文件 @
5bb459cf
from
theano.link.basic
import
Container
,
Linker
,
LocalLinker
import
sys
from
theano.link.basic
import
(
Container
,
Linker
,
LocalLinker
,
PerformLinker
,
WrapLinker
,
WrapLinkerMany
,
gc_helper
,
map_storage
,
streamline
,
)
from
theano.link.debugging
import
raise_with_op
,
set_excepthook
set_excepthook
(
handler
=
sys
.
stdout
)
theano/link/basic.py
浏览文件 @
5bb459cf
import
typing
from
copy
import
copy
,
deepcopy
from
theano
import
config
,
utils
from
theano.gof.fg
import
FunctionGraph
from
theano.gof.graph
import
Apply
from
theano.gof.graph
import
Apply
,
Constant
from
theano.gof.type
import
Type
from
theano.utils
import
deprecated
from
theano.gof.utils
import
to_return_values
from
theano.link.debugging
import
raise_with_op
class
Container
:
...
...
@@ -188,7 +190,7 @@ class Linker:
f
"make_thunk method of {type(self)} is not implemented."
)
@deprecated
(
"Marked for deletion. Only tests use it."
)
@
utils.
deprecated
(
"Marked for deletion. Only tests use it."
)
def
make_function
(
self
,
unpack_single
=
True
,
**
kwargs
):
"""
Returns a function that takes values corresponding to the inputs of the
...
...
@@ -210,8 +212,6 @@ class Linker:
length 1 will be returned.
"""
from
theano.gof
import
utils
thunk
,
inputs
,
outputs
=
self
.
make_thunk
(
**
kwargs
)
def
execute
(
*
args
):
...
...
@@ -225,7 +225,7 @@ class Linker:
variable
.
data
=
arg
thunk
()
if
unpack_single
:
return
utils
.
to_return_values
([
variable
.
data
for
variable
in
outputs
])
return
to_return_values
([
variable
.
data
for
variable
in
outputs
])
else
:
return
[
variable
.
data
for
variable
in
outputs
]
...
...
@@ -249,7 +249,7 @@ class Linker:
The result of the scheduling or toposort operation.
"""
if
callable
(
self
.
_scheduler
):
return
self
.
scheduler
(
fgraph
)
return
self
.
_
scheduler
(
fgraph
)
return
fgraph
.
toposort
()
...
...
@@ -279,3 +279,535 @@ class LocalLinker(Linker):
raise
NotImplementedError
(
f
"make_all method of {type(self)} is not implemented."
)
def
map_storage
(
fgraph
:
FunctionGraph
,
order
:
typing
.
Iterable
[
Apply
],
input_storage
:
typing
.
Optional
[
typing
.
List
],
output_storage
:
typing
.
Optional
[
typing
.
List
],
storage_map
:
typing
.
Dict
=
None
,
)
->
typing
.
Tuple
[
typing
.
List
,
typing
.
List
,
typing
.
Dict
]:
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
:param fgraph: The current fgraph. This function uses the inputs and outputs attributes.
:param order: an iterable over Apply instances (in program running order)
:param input_storage: None or existing input storage (see below)
:param output_storage: None or existing output storage (see below)
:rtype: 3-tuple
:returns: (list of storage for inputs, list of storage for outputs, and the `storage_map`)
Parameters
----------
fgraph
The current fgraph. This function uses the inputs and outputs
attributes.
order
An iterable over Apply instances (in program running order).
input_storage
None or existing input storage (see below).
output_storage
None or existing output storage (see below).
Returns
-------
3-tuple
List of storage for inputs, list of storage for outputs, and
the `storage_map`.
Extended summary
----------------
This function iterates over the nodes in `order` and ensures that for every
input and output `Variable`, there is a unique storage container. This is
returned as a dictionary Variable -> storage called the `storage_map`.
This function also returns `input_storage`, which is a list of storages
corresponding to fgraph.inputs.
This function also returns `output_storage`, which is a list of storages
corresponding to fgraph.outputs.
"""
# each Apply argument's data is stored in a list of length 1 (these lists act like pointers)
if
storage_map
is
None
:
storage_map
=
{}
# input_storage is a list of data-containers for the inputs.
if
input_storage
is
None
:
input_storage
=
[[
None
]
for
input
in
fgraph
.
inputs
]
else
:
assert
len
(
fgraph
.
inputs
)
==
len
(
input_storage
)
# add input storage into storage_map
for
r
,
storage
in
zip
(
fgraph
.
inputs
,
input_storage
):
if
r
in
storage_map
:
assert
storage_map
[
r
]
is
storage
,
(
"Given input_storage conflicts "
"with storage in given storage_"
"map. Given input_storage: "
,
storage
,
"Storage in storage_ma"
"p: "
,
storage_map
[
r
],
)
else
:
storage_map
[
r
]
=
storage
# for orphan in fgraph.orphans:
# if not isinstance(orphan, Constant):
# raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
# storage_map[orphan] = [orphan.data]
# allocate output storage
if
output_storage
is
not
None
:
assert
len
(
fgraph
.
outputs
)
==
len
(
output_storage
)
for
r
,
storage
in
zip
(
fgraph
.
outputs
,
output_storage
):
if
r
in
storage_map
:
assert
storage_map
[
r
]
is
storage
,
(
"Given output_storage confl"
"icts with storage in given"
" storage_map. Given output"
"_storage: "
,
storage
,
"Sto"
"rage in storage_map: "
,
storage_map
[
r
],
)
else
:
storage_map
[
r
]
=
storage
# allocate storage for intermediate computation
for
node
in
order
:
for
r
in
node
.
inputs
:
if
r
not
in
storage_map
:
assert
isinstance
(
r
,
Constant
)
storage_map
[
r
]
=
[
r
.
data
]
for
r
in
node
.
outputs
:
storage_map
.
setdefault
(
r
,
[
None
])
for
r
in
fgraph
.
outputs
:
if
isinstance
(
r
,
Constant
):
storage_map
.
setdefault
(
r
,
[
r
.
data
])
# extract output storage
if
output_storage
is
None
:
output_storage
=
[
storage_map
[
r
]
for
r
in
fgraph
.
outputs
]
return
input_storage
,
output_storage
,
storage_map
def
add_clear_storage
(
f
,
computed
,
storage_map
):
def
clear_storage
():
for
c
in
computed
:
storage_map
[
c
][
0
]
=
None
f
.
clear_storage
=
clear_storage
def
streamline
(
fgraph
:
FunctionGraph
,
thunks
,
order
,
post_thunk_old_storage
=
None
,
no_recycling
=
None
,
nice_errors
=
True
,
)
->
typing
.
Callable
[[],
typing
.
NoReturn
]:
"""
WRITEME
Parameters
----------
fgraph
thunks
The list of program instructions.
order
The list of apply instances that gave rise to the thunks
(same order as thunks).
post_thunk_old_storage
A list (corresponding to thunks, order) whose elements are lists of
storage cells, that should be cleared after running thecorresponding
thunk. A value of None disables this functionality.
no_recycling
Storage elements that cannot be 'recycled' by repeatedly executing the
program. These storage elements are cleared before re-running.
nice_errors
Run in such a way that the double-traceback is printed. This costs a
bit of performance in the inner python loop.
"""
if
no_recycling
is
None
:
no_recycling
=
[]
if
len
(
thunks
)
!=
len
(
order
):
raise
ValueError
(
"Length of thunks and order must match"
,
(
len
(
thunks
),
len
(
order
))
)
if
post_thunk_old_storage
:
if
len
(
thunks
)
!=
len
(
post_thunk_old_storage
):
raise
ValueError
(
"Length of thunks and post_thunk_old_storage must match"
,
(
len
(
thunks
),
len
(
post_thunk_old_storage
)),
)
def
streamline_default_f
():
for
x
in
no_recycling
:
x
[
0
]
=
None
try
:
for
thunk
,
node
,
old_storage
in
zip
(
thunks
,
order
,
post_thunk_old_storage
):
thunk
()
for
old_s
in
old_storage
:
old_s
[
0
]
=
None
except
Exception
:
raise_with_op
(
fgraph
,
node
,
thunk
)
f
=
streamline_default_f
elif
nice_errors
:
def
streamline_nice_errors_f
():
for
x
in
no_recycling
:
x
[
0
]
=
None
try
:
for
thunk
,
node
in
zip
(
thunks
,
order
):
thunk
()
except
Exception
:
raise_with_op
(
fgraph
,
node
,
thunk
)
f
=
streamline_nice_errors_f
else
:
# don't worry about raise_with_op, just go a little faster.
# there is a mix of python and c thunks
def
streamline_fast_f
():
for
x
in
no_recycling
:
x
[
0
]
=
None
for
thunk
in
thunks
:
thunk
()
f
=
streamline_fast_f
return
f
def
gc_helper
(
node_list
:
typing
.
List
[
Apply
]):
"""
Return the set of Variable instances which are computed by node_list.
Parameters
----------
node_list
List of Apply instances in program execution order.
Returns
-------
2-tuple
FIRST, the set of Variable instances which are computed by node_list,
and SECOND a dictionary that maps each Variable instance to a the last
node to use Variable as an input.
Extended Summary
----------------
This is used to allow garbage collection within graphs.
It ignores view_map and destroy_map. This isn't needed as python
have reference count. In Theano gc, we should not take into
account view_map and destroy_map as if the thunk decided to create
a new output, we would delay uselessly its gc by Python.
"""
# for freeing memory
last_user
=
{}
computed
=
set
()
for
node
in
node_list
:
for
input
in
node
.
inputs
:
last_user
[
input
]
=
node
for
output
in
node
.
outputs
:
computed
.
add
(
output
)
return
computed
,
last_user
class
PerformLinker
(
LocalLinker
):
"""
Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{FunctionGraph} in the order given by L{Linker.schedule}.
"""
def
__init__
(
self
,
allow_gc
=
None
,
schedule
=
None
):
if
allow_gc
is
None
:
allow_gc
=
config
.
allow_gc
self
.
fgraph
=
None
super
()
.
__init__
(
allow_gc
=
allow_gc
,
scheduler
=
schedule
)
def
accept
(
self
,
fgraph
,
no_recycling
=
None
,
profile
=
None
):
"""
Parameters
----------
fgraph
A PerformLinker can have accepted one FunctionGraph instance at a time.
no_recycling
WRITEME
Returns
-------
object
self (TODO: WHY? Who calls this function?)
"""
if
no_recycling
is
None
:
no_recycling
=
[]
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
return
type
(
self
)(
allow_gc
=
self
.
allow_gc
)
.
accept
(
fgraph
,
no_recycling
,
profile
)
# raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
self
.
fgraph
=
fgraph
self
.
no_recycling
=
no_recycling
return
self
def
make_all
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
):
"""
Returns Function to run all nodes, list of input containers, list of outputs
Parameters
----------
input_storage
list of storages corresponding to fgraph.inputs
output_storage
list of storages corresponding to fgraph.outputs
Returns
-------
object
Function to run all nodes, list of input containers, list of output
containers, list of thunks (for all programs), list of nodes
(for all programs).
"""
fgraph
=
self
.
fgraph
order
=
self
.
schedule
(
fgraph
)
no_recycling
=
self
.
no_recycling
input_storage
,
output_storage
,
storage_map
=
map_storage
(
fgraph
,
order
,
input_storage
,
output_storage
,
storage_map
)
compute_map
=
{}
for
k
in
storage_map
:
compute_map
[
k
]
=
[
k
.
owner
is
None
]
thunks
=
[]
for
node
in
order
:
# Maker sure we don't use C version of the code, but rather only
# the python version
# Note : ops that implement their own make thunk don't usually
# have this attribute defiend !!
thunks
+=
[
node
.
op
.
make_thunk
(
node
,
storage_map
,
compute_map
,
no_recycling
,
"py"
)
]
thunks
[
-
1
]
.
inputs
=
[
storage_map
[
v
]
for
v
in
node
.
inputs
]
thunks
[
-
1
]
.
outputs
=
[
storage_map
[
v
]
for
v
in
node
.
outputs
]
computed
,
last_user
=
gc_helper
(
order
)
if
self
.
allow_gc
:
post_thunk_old_storage
=
[]
else
:
post_thunk_old_storage
=
None
for
node
in
order
:
if
self
.
allow_gc
:
post_thunk_old_storage
.
append
(
[
storage_map
[
input
]
for
input
in
node
.
inputs
if
(
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
(
node
==
last_user
[
input
])
]
)
if
no_recycling
is
True
:
# True seems like some special code for *everything*?? -JB
# FunctionMaker always passes a list I think -JB
no_recycling
=
list
(
storage_map
.
values
())
no_recycling
=
utils
.
difference
(
no_recycling
,
input_storage
)
else
:
no_recycling
=
[
storage_map
[
r
]
for
r
in
no_recycling
if
r
not
in
fgraph
.
inputs
]
# The function that actually runs your program is one of the f's in streamline.
f
=
streamline
(
fgraph
,
thunks
,
order
,
post_thunk_old_storage
,
no_recycling
=
no_recycling
)
f
.
allow_gc
=
(
self
.
allow_gc
)
# HACK: this is a way of passing an arg to Function.__call__
add_clear_storage
(
f
,
computed
,
storage_map
)
f
.
storage_map
=
storage_map
return
(
f
,
[
Container
(
input
,
storage
)
for
input
,
storage
in
zip
(
fgraph
.
inputs
,
input_storage
)
],
[
Container
(
output
,
storage
,
readonly
=
True
)
for
output
,
storage
in
zip
(
fgraph
.
outputs
,
output_storage
)
],
thunks
,
order
,
)
class
WrapLinker
(
Linker
):
"""
This class makes it easier to run several L{LocalLinker}s in parallel, and
offers some control over how each thunk is run.
A wrapper function must be provided, and it can be used to execute the
thunks, inspect the nodes, print stuff out, etc.
The constructor initializes a WrapLinker.
Parameters
----------
linkers : list of L{LocalLinker} subclasses, whose make_all() method returns
thunks in the same order.
For each node in the graph, each linker will provide a
thunk. This class makes it possible to iterate over each linker's
program in parallel.
wrapper : lambda (fgraph, i, i_node, i_thunk1, i_thunk2, ...) : None
Does some user-defined action for the i'th element of the program.
i_thunk<n> is the thunk returned by the n'th linker. (If you want
to run the program, make sure to call the necessary thunks in this
function.)
Notes
-----
The outputs of the first linker will be returned.
This linker ensures that each linker has its own storage for inputs and
outputs and intermediate variables. There is no interference between
linkers.
"""
def
__init__
(
self
,
linkers
,
wrapper
):
self
.
fgraph
=
None
self
.
linkers
=
linkers
self
.
wrapper
=
wrapper
def
__copy__
(
self
):
"""
Shallow copy of a WrapLinker.
Returns
-------
object
A copy of self, where each of the linkers in self.linkers
have been shallow-copied.
It is useful because in FunctionMaker, copy.copy is called on the
Mode's linker, so that it is not modified inplace when linker.accept()
is called. In this case, we want the wrapped linkers to be copied too.
"""
other
=
self
.
__class__
(
linkers
=
[
copy
(
x
)
for
x
in
self
.
linkers
],
wrapper
=
self
.
wrapper
)
return
other
def
clone
(
self
,
allow_gc
=
None
):
return
self
.
__class__
(
linkers
=
[
x
.
clone
(
allow_gc
=
allow_gc
)
for
x
in
self
.
linkers
],
wrapper
=
self
.
wrapper
,
)
def
accept
(
self
,
fgraph
,
no_recycling
=
None
,
profile
=
None
):
"""
Parameters
----------
fgraph : gof.FunctionGraph
The fgraph which we will link.
no_recycling : a list of Variables that belong to fgraph.
If a Variable is in no_recycling, L{WrapLinker} will clear
the output storage associated to it (for each linker in linkers)
during the computation to avoid reusing it.
"""
if
no_recycling
is
None
:
no_recycling
=
[]
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
return
type
(
self
)(
self
.
linkers
,
self
.
wrapper
)
.
accept
(
fgraph
,
no_recycling
)
self
.
fgraph
=
fgraph
self
.
no_recycling
=
no_recycling
self
.
linkers
=
[
linker
.
accept
(
fgraph
,
no_recycling
)
for
linker
in
self
.
linkers
]
return
self
def
pre
(
self
,
f
,
inputs
,
order
,
thunk_groups
):
pass
def
make_thunk
(
self
,
**
kwargs
):
no_recycling
=
self
.
no_recycling
make_all
=
[
self
.
linkers
[
0
]
.
make_all
(
**
kwargs
)]
kwargs
.
pop
(
"input_storage"
,
None
)
make_all
+=
[
x
.
make_all
(
**
kwargs
)
for
x
in
self
.
linkers
[
1
:]]
fns
,
input_lists
,
output_lists
,
thunk_lists
,
order_lists
=
zip
(
*
make_all
)
order_list0
=
order_lists
[
0
]
for
order_list
in
order_lists
[
1
:]:
if
not
order_list0
==
order_list
:
raise
Exception
(
"All linkers to WrapLinker should execute operations in the same order."
)
inputs0
=
input_lists
[
0
]
outputs0
=
output_lists
[
0
]
thunk_groups
=
list
(
zip
(
*
thunk_lists
))
order
=
[
x
[
0
]
for
x
in
zip
(
*
order_lists
)]
to_reset
=
[]
for
thunks
,
node
in
zip
(
thunk_groups
,
order
):
for
j
,
output
in
enumerate
(
node
.
outputs
):
if
output
in
no_recycling
:
for
thunk
in
thunks
:
to_reset
.
append
(
thunk
.
outputs
[
j
])
wrapper
=
self
.
wrapper
pre
=
self
.
pre
def
f
():
for
inputs
in
input_lists
[
1
:]:
for
input1
,
input2
in
zip
(
inputs0
,
inputs
):
input2
.
storage
[
0
]
=
copy
(
input1
.
storage
[
0
])
for
x
in
to_reset
:
x
[
0
]
=
None
pre
(
self
,
[
input
.
data
for
input
in
input_lists
[
0
]],
order
,
thunk_groups
)
for
i
,
(
thunks
,
node
)
in
enumerate
(
zip
(
thunk_groups
,
order
)):
try
:
wrapper
(
self
.
fgraph
,
i
,
node
,
*
thunks
)
except
Exception
:
raise_with_op
(
self
.
fgraph
,
node
,
*
thunks
)
f
.
thunk_groups
=
thunk_groups
return
f
,
inputs0
,
outputs0
def
WrapLinkerMany
(
linkers
,
wrappers
):
"""
Variant on WrapLinker that runs a series of wrapper functions instead of
just one.
"""
def
wrapper
(
*
args
):
for
f
in
wrappers
:
f
(
*
args
)
return
WrapLinker
(
linkers
,
wrapper
)
theano/
gof/link
.py
→
theano/
link/debugging
.py
浏览文件 @
5bb459cf
import
io
import
sys
import
traceback
from
copy
import
copy
from
io
import
StringIO
from
sys
import
getsizeof
from
warnings
import
warn
import
warnings
from
operator
import
itemgetter
import
numpy
as
np
import
theano
from
theano.gof
import
graph
,
utils
from
theano.link.basic
import
Container
,
Linker
,
LocalLinker
from
theano
import
config
from
theano.gof.fg
import
FunctionGraph
from
.utils
import
undef
__excepthook
=
sys
.
excepthook
def
log_thunk_trace
(
value
,
f
=
sys
.
stderr
):
def
__log_thunk_trace
(
value
,
handler
):
"""
Log Theano's diagnostic stack trace for an exception
raised by raise_with_op.
"""
# in future, consider accepting `write` as arg rather than file
# to support writing to a logger
def
write
(
msg
):
print
(
f
"log_thunk_trace: {msg.strip()}"
,
file
=
f
)
print
(
f
"log_thunk_trace: {msg.strip()}"
,
file
=
handler
)
if
hasattr
(
value
,
"__thunk_trace__"
):
trace2
=
value
.
__thunk_trace__
...
...
@@ -51,40 +42,41 @@ def log_thunk_trace(value, f=sys.stderr):
)
def
thunk_hook
(
type
,
value
,
trace
):
"""
This function is meant to replace excepthook and do some
special work if the exception value has a __thunk_trace__
field.
In that case, it retrieves the field, which should
contain a trace as returned by L{traceback.extract_stack},
and prints it out on L{stderr}.
def
set_excepthook
(
handler
:
io
.
TextIOWrapper
):
def
thunk_hook
(
type
,
value
,
trace
):
"""
This function is meant to replace excepthook and do some
special work if the exception value has a __thunk_trace__
field.
In that case, it retrieves the field, which should
contain a trace as returned by L{traceback.extract_stack},
and prints it out on L{stderr}.
The normal excepthook is then called.
The normal excepthook is then called.
Parameters:
----------
type
Exception class
value
Exception instance
trace
Traceback object
Parameters:
----------
type
Exception class
value
Exception instance
trace
Traceback object
Notes
-----
This hook replaced in testing, so it does not run.
"""
log_thunk_trace
(
value
)
__excepthook
(
type
,
value
,
trace
)
Notes
-----
This hook replaced in testing, so it does not run.
"""
__log_thunk_trace
(
value
,
handler
=
handler
)
sys
.
__excepthook__
(
type
,
value
,
trace
)
sys
.
excepthook
=
thunk_hook
sys
.
excepthook
=
thunk_hook
# TODO: Make this work with linker defined schedule
def
raise_with_op
(
fgraph
,
node
,
thunk
=
None
,
exc_info
=
None
,
storage_map
=
None
):
def
raise_with_op
(
fgraph
:
FunctionGraph
,
node
,
thunk
=
None
,
exc_info
=
None
,
storage_map
=
None
):
"""
Re-raise an exception while annotating the exception object with
debug info.
...
...
@@ -117,7 +109,10 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
The exception is not annotated if it is of type `KeyboardInterrupt`.
TODO: Make this work with linker defined schedule
"""
verbosity
=
config
.
exception_verbosity
if
exc_info
is
None
:
exc_info
=
sys
.
exc_info
()
exc_type
,
exc_value
,
exc_trace
=
exc_info
...
...
@@ -168,7 +163,7 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
+
f
"
\n
Inputs strides: {strides}"
+
f
"
\n
Inputs values: {scalar_values}"
)
if
theano
.
config
.
exception_
verbosity
==
"high"
:
if
verbosity
==
"high"
:
detailed_err_msg
+=
"
\n
Inputs type_num:
%
s"
%
str
(
[
getattr
(
getattr
(
i
[
0
],
"dtype"
,
""
),
"num"
,
""
)
for
i
in
thunk
.
inputs
]
)
...
...
@@ -188,7 +183,7 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
detailed_err_msg
+=
"
\n
Backtrace when the node is created(use Theano flag traceback__limit=N to make it longer):
\n
"
# Print separate message for each element in the list of batcktraces
sio
=
StringIO
()
sio
=
io
.
StringIO
()
for
subtr
in
tr
:
traceback
.
print_list
(
subtr
,
sio
)
detailed_err_msg
+=
str
(
sio
.
getvalue
())
...
...
@@ -201,15 +196,17 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
" Theano optimizations can be disabled with 'optimizer=None'."
)
if
theano
.
config
.
exception_verbosity
==
"high"
:
if
verbosity
==
"high"
:
import
theano.printing
f
=
StringIO
()
f
=
io
.
StringIO
()
theano
.
printing
.
debugprint
(
node
,
file
=
f
,
stop_on_name
=
True
,
print_type
=
True
)
detailed_err_msg
+=
"
\n
Debugprint of the apply node:
\n
"
detailed_err_msg
+=
f
.
getvalue
()
# Prints output_map
if
theano
.
config
.
exception_
verbosity
==
"high"
and
storage_map
is
not
None
:
if
verbosity
==
"high"
and
storage_map
is
not
None
:
detailed_err_msg
+=
"
\n
Storage map footprint:
\n
"
shared_input_list
=
[
item
...
...
@@ -283,7 +280,7 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
if
k
.
type
.
may_share_memory
(
data
,
input_data
):
total_size
-=
sz
else
:
bytes
=
getsizeof
(
storage_map
[
k
][
0
])
bytes
=
sys
.
getsizeof
(
storage_map
[
k
][
0
])
storage_map_item
.
append
(
bytes
)
storage_map_item
.
append
(
-
1
)
...
...
@@ -297,8 +294,6 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
storage_map_item
.
append
(
None
)
storage_map_list
.
append
(
storage_map_item
)
from
operator
import
itemgetter
storage_map_list
.
sort
(
key
=
itemgetter
(
3
),
reverse
=
True
)
for
item
in
storage_map_list
:
if
item
[
3
]
==
-
1
:
...
...
@@ -317,11 +312,11 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
detailed_err_msg
+=
"
\n
"
detailed_err_msg
+=
" TotalSize: {} Byte(s) {:.3f} GB
\n
"
.
format
(
total_size
,
total_size
/
1024
.0
/
1024
/
1024
,
total_size
/
1024
/
1024
/
1024
,
)
detailed_err_msg
+=
" TotalSize inputs: {} Byte(s) {:.3f} GB
\n
"
.
format
(
total_size_inputs
,
total_size_inputs
/
1024
.0
/
1024
/
1024
,
total_size_inputs
/
1024
/
1024
/
1024
,
)
else
:
...
...
@@ -335,533 +330,7 @@ def raise_with_op(fgraph, node, thunk=None, exc_info=None, storage_map=None):
str
(
exc_value
)
+
detailed_err_msg
+
"
\n
"
+
"
\n
"
.
join
(
hints
)
)
except
TypeError
:
warn
(
f
"{exc_type} error does not allow us to add extra error message"
)
warn
ings
.
warn
(
f
"{exc_type} error does not allow us to add extra error message"
)
# Some exception need extra parameter in inputs. So forget the
# extra long error message in that case.
raise
exc_value
.
with_traceback
(
exc_trace
)
def
map_storage
(
fgraph
,
order
,
input_storage
,
output_storage
,
storage_map
=
None
):
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
:param fgraph: The current fgraph. This function uses the inputs and outputs attributes.
:param order: an iterable over Apply instances (in program running order)
:param input_storage: None or existing input storage (see below)
:param output_storage: None or existing output storage (see below)
:rtype: 3-tuple
:returns: (list of storage for inputs, list of storage for outputs, and the `storage_map`)
Parameters
----------
fgraph
The current fgraph. This function uses the inputs and outputs
attributes.
order
An iterable over Apply instances (in program running order).
input_storage
None or existing input storage (see below).
output_storage
None or existing output storage (see below).
Returns
-------
3-tuple
List of storage for inputs, list of storage for outputs, and
the `storage_map`.
Extended summary
----------------
This function iterates over the nodes in `order` and ensures that for every
input and output `Variable`, there is a unique storage container. This is
returned as a dictionary Variable -> storage called the `storage_map`.
This function also returns `input_storage`, which is a list of storages
corresponding to fgraph.inputs.
This function also returns `output_storage`, which is a list of storages
corresponding to fgraph.outputs.
"""
# each Apply argument's data is stored in a list of length 1 (these lists act like pointers)
if
storage_map
is
None
:
storage_map
=
{}
# input_storage is a list of data-containers for the inputs.
if
input_storage
is
None
:
input_storage
=
[[
None
]
for
input
in
fgraph
.
inputs
]
else
:
assert
len
(
fgraph
.
inputs
)
==
len
(
input_storage
)
# add input storage into storage_map
for
r
,
storage
in
zip
(
fgraph
.
inputs
,
input_storage
):
if
r
in
storage_map
:
assert
storage_map
[
r
]
is
storage
,
(
"Given input_storage conflicts "
"with storage in given storage_"
"map. Given input_storage: "
,
storage
,
"Storage in storage_ma"
"p: "
,
storage_map
[
r
],
)
else
:
storage_map
[
r
]
=
storage
# for orphan in fgraph.orphans:
# if not isinstance(orphan, Constant):
# raise TypeError("Cannot link a graph with non-constant orphans.", orphan)
# storage_map[orphan] = [orphan.data]
# allocate output storage
if
output_storage
is
not
None
:
assert
len
(
fgraph
.
outputs
)
==
len
(
output_storage
)
for
r
,
storage
in
zip
(
fgraph
.
outputs
,
output_storage
):
if
r
in
storage_map
:
assert
storage_map
[
r
]
is
storage
,
(
"Given output_storage confl"
"icts with storage in given"
" storage_map. Given output"
"_storage: "
,
storage
,
"Sto"
"rage in storage_map: "
,
storage_map
[
r
],
)
else
:
storage_map
[
r
]
=
storage
# allocate storage for intermediate computation
for
node
in
order
:
for
r
in
node
.
inputs
:
if
r
not
in
storage_map
:
assert
isinstance
(
r
,
graph
.
Constant
)
storage_map
[
r
]
=
[
r
.
data
]
for
r
in
node
.
outputs
:
storage_map
.
setdefault
(
r
,
[
None
])
for
r
in
fgraph
.
outputs
:
if
isinstance
(
r
,
graph
.
Constant
):
storage_map
.
setdefault
(
r
,
[
r
.
data
])
# extract output storage
if
output_storage
is
None
:
output_storage
=
[
storage_map
[
r
]
for
r
in
fgraph
.
outputs
]
return
input_storage
,
output_storage
,
storage_map
def
add_clear_storage
(
f
,
computed
,
storage_map
):
def
clear_storage
():
for
c
in
computed
:
storage_map
[
c
][
0
]
=
None
f
.
clear_storage
=
clear_storage
def
streamline
(
fgraph
,
thunks
,
order
,
post_thunk_old_storage
=
None
,
no_recycling
=
None
,
nice_errors
=
True
,
):
"""
WRITEME
Parameters
----------
fgraph
thunks
The list of program instructions.
order
The list of apply instances that gave rise to the thunks
(same order as thunks).
post_thunk_old_storage
A list (corresponding to thunks, order) whose elements are lists of
storage cells, that should be cleared after running thecorresponding
thunk. A value of None disables this functionality.
no_recycling
Storage elements that cannot be 'recycled' by repeatedly executing the
program. These storage elements are cleared before re-running.
nice_errors
Run in such a way that the double-traceback is printed. This costs a
bit of performance in the inner python loop.
"""
if
no_recycling
is
None
:
no_recycling
=
[]
if
len
(
thunks
)
!=
len
(
order
):
raise
ValueError
(
"Length of thunks and order must match"
,
(
len
(
thunks
),
len
(
order
))
)
if
post_thunk_old_storage
:
if
len
(
thunks
)
!=
len
(
post_thunk_old_storage
):
raise
ValueError
(
"Length of thunks and post_thunk_old_storage must match"
,
(
len
(
thunks
),
len
(
post_thunk_old_storage
)),
)
def
streamline_default_f
():
for
x
in
no_recycling
:
x
[
0
]
=
None
try
:
for
thunk
,
node
,
old_storage
in
zip
(
thunks
,
order
,
post_thunk_old_storage
):
thunk
()
for
old_s
in
old_storage
:
old_s
[
0
]
=
None
except
Exception
:
raise_with_op
(
fgraph
,
node
,
thunk
)
f
=
streamline_default_f
elif
nice_errors
:
def
streamline_nice_errors_f
():
for
x
in
no_recycling
:
x
[
0
]
=
None
try
:
for
thunk
,
node
in
zip
(
thunks
,
order
):
thunk
()
except
Exception
:
raise_with_op
(
fgraph
,
node
,
thunk
)
f
=
streamline_nice_errors_f
else
:
# don't worry about raise_with_op, just go a little faster.
# there is a mix of python and c thunks
def
streamline_fast_f
():
for
x
in
no_recycling
:
x
[
0
]
=
None
for
thunk
in
thunks
:
thunk
()
f
=
streamline_fast_f
return
f
def
gc_helper
(
node_list
):
"""
Return the set of Variable instances which are computed by node_list.
Parameters
----------
node_list
List of Apply instances in program execution order.
Returns
-------
2-tuple
FIRST, the set of Variable instances which are computed by node_list,
and SECOND a dictionary that maps each Variable instance to a the last
node to use Variable as an input.
Extended Summary
----------------
This is used to allow garbage collection within graphs.
It ignores view_map and destroy_map. This isn't needed as python
have reference count. In Theano gc, we should not take into
account view_map and destroy_map as if the thunk decided to create
a new output, we would delay uselessly its gc by Python.
"""
# for freeing memory
last_user
=
{}
computed
=
set
()
for
node
in
node_list
:
for
input
in
node
.
inputs
:
last_user
[
input
]
=
node
for
output
in
node
.
outputs
:
computed
.
add
(
output
)
return
computed
,
last_user
class
PerformLinker
(
LocalLinker
):
"""
Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{FunctionGraph} in the order given by L{Linker.schedule}.
"""
def
__init__
(
self
,
allow_gc
=
None
,
schedule
=
None
):
if
allow_gc
is
None
:
allow_gc
=
theano
.
config
.
allow_gc
self
.
fgraph
=
None
super
()
.
__init__
(
allow_gc
=
allow_gc
,
scheduler
=
schedule
)
def
accept
(
self
,
fgraph
,
no_recycling
=
None
,
profile
=
None
):
"""
Parameters
----------
fgraph
A PerformLinker can have accepted one FunctionGraph instance at a time.
no_recycling
WRITEME
Returns
-------
object
self (TODO: WHY? Who calls this function?)
"""
if
no_recycling
is
None
:
no_recycling
=
[]
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
return
type
(
self
)(
allow_gc
=
self
.
allow_gc
)
.
accept
(
fgraph
,
no_recycling
,
profile
)
# raise Exception("Cannot accept from a Linker that is already tied to another FunctionGraph.")
self
.
fgraph
=
fgraph
self
.
no_recycling
=
no_recycling
return
self
def
make_all
(
self
,
input_storage
=
None
,
output_storage
=
None
,
storage_map
=
None
):
"""
Returns Function to run all nodes, list of input containers, list of outputs
Parameters
----------
input_storage
list of storages corresponding to fgraph.inputs
output_storage
list of storages corresponding to fgraph.outputs
Returns
-------
object
Function to run all nodes, list of input containers, list of output
containers, list of thunks (for all programs), list of nodes
(for all programs).
"""
fgraph
=
self
.
fgraph
order
=
self
.
schedule
(
fgraph
)
no_recycling
=
self
.
no_recycling
input_storage
,
output_storage
,
storage_map
=
map_storage
(
fgraph
,
order
,
input_storage
,
output_storage
,
storage_map
)
compute_map
=
{}
for
k
in
storage_map
:
compute_map
[
k
]
=
[
k
.
owner
is
None
]
thunks
=
[]
for
node
in
order
:
# Maker sure we don't use C version of the code, but rather only
# the python version
# Note : ops that implement their own make thunk don't usually
# have this attribute defiend !!
thunks
+=
[
node
.
op
.
make_thunk
(
node
,
storage_map
,
compute_map
,
no_recycling
,
"py"
)
]
thunks
[
-
1
]
.
inputs
=
[
storage_map
[
v
]
for
v
in
node
.
inputs
]
thunks
[
-
1
]
.
outputs
=
[
storage_map
[
v
]
for
v
in
node
.
outputs
]
computed
,
last_user
=
gc_helper
(
order
)
if
self
.
allow_gc
:
post_thunk_old_storage
=
[]
else
:
post_thunk_old_storage
=
None
for
node
in
order
:
if
self
.
allow_gc
:
post_thunk_old_storage
.
append
(
[
storage_map
[
input
]
for
input
in
node
.
inputs
if
(
input
in
computed
)
and
(
input
not
in
fgraph
.
outputs
)
and
(
node
==
last_user
[
input
])
]
)
if
no_recycling
is
True
:
# True seems like some special code for *everything*?? -JB
# FunctionMaker always passes a list I think -JB
no_recycling
=
list
(
storage_map
.
values
())
no_recycling
=
utils
.
difference
(
no_recycling
,
input_storage
)
else
:
no_recycling
=
[
storage_map
[
r
]
for
r
in
no_recycling
if
r
not
in
fgraph
.
inputs
]
# The function that actually runs your program is one of the f's in streamline.
f
=
streamline
(
fgraph
,
thunks
,
order
,
post_thunk_old_storage
,
no_recycling
=
no_recycling
)
f
.
allow_gc
=
(
self
.
allow_gc
)
# HACK: this is a way of passing an arg to Function.__call__
add_clear_storage
(
f
,
computed
,
storage_map
)
f
.
storage_map
=
storage_map
return
(
f
,
[
Container
(
input
,
storage
)
for
input
,
storage
in
zip
(
fgraph
.
inputs
,
input_storage
)
],
[
Container
(
output
,
storage
,
readonly
=
True
)
for
output
,
storage
in
zip
(
fgraph
.
outputs
,
output_storage
)
],
thunks
,
order
,
)
class
WrapLinker
(
Linker
):
"""
This class makes it easier to run several L{LocalLinker}s in parallel, and
offers some control over how each thunk is run.
A wrapper function must be provided, and it can be used to execute the
thunks, inspect the nodes, print stuff out, etc.
The constructor initializes a WrapLinker.
Parameters
----------
linkers : list of L{LocalLinker} subclasses, whose make_all() method returns
thunks in the same order.
For each node in the graph, each linker will provide a
thunk. This class makes it possible to iterate over each linker's
program in parallel.
wrapper : lambda (fgraph, i, i_node, i_thunk1, i_thunk2, ...) : None
Does some user-defined action for the i'th element of the program.
i_thunk<n> is the thunk returned by the n'th linker. (If you want
to run the program, make sure to call the necessary thunks in this
function.)
Notes
-----
The outputs of the first linker will be returned.
This linker ensures that each linker has its own storage for inputs and
outputs and intermediate variables. There is no interference between
linkers.
"""
def
__init__
(
self
,
linkers
,
wrapper
):
self
.
fgraph
=
None
self
.
linkers
=
linkers
self
.
wrapper
=
wrapper
def
__copy__
(
self
):
"""
Shallow copy of a WrapLinker.
Returns
-------
object
A copy of self, where each of the linkers in self.linkers
have been shallow-copied.
It is useful because in FunctionMaker, copy.copy is called on the
Mode's linker, so that it is not modified inplace when linker.accept()
is called. In this case, we want the wrapped linkers to be copied too.
"""
other
=
self
.
__class__
(
linkers
=
[
copy
(
x
)
for
x
in
self
.
linkers
],
wrapper
=
self
.
wrapper
)
return
other
def
clone
(
self
,
allow_gc
=
undef
):
return
self
.
__class__
(
linkers
=
[
x
.
clone
(
allow_gc
=
allow_gc
)
for
x
in
self
.
linkers
],
wrapper
=
self
.
wrapper
,
)
def
accept
(
self
,
fgraph
,
no_recycling
=
None
,
profile
=
None
):
"""
Parameters
----------
fgraph : gof.FunctionGraph
The fgraph which we will link.
no_recycling : a list of Variables that belong to fgraph.
If a Variable is in no_recycling, L{WrapLinker} will clear
the output storage associated to it (for each linker in linkers)
during the computation to avoid reusing it.
"""
if
no_recycling
is
None
:
no_recycling
=
[]
if
self
.
fgraph
is
not
None
and
self
.
fgraph
is
not
fgraph
:
return
type
(
self
)(
self
.
linkers
,
self
.
wrapper
)
.
accept
(
fgraph
,
no_recycling
)
self
.
fgraph
=
fgraph
self
.
no_recycling
=
no_recycling
self
.
linkers
=
[
linker
.
accept
(
fgraph
,
no_recycling
)
for
linker
in
self
.
linkers
]
return
self
def
pre
(
self
,
f
,
inputs
,
order
,
thunk_groups
):
pass
def
make_thunk
(
self
,
**
kwargs
):
no_recycling
=
self
.
no_recycling
make_all
=
[
self
.
linkers
[
0
]
.
make_all
(
**
kwargs
)]
kwargs
.
pop
(
"input_storage"
,
None
)
make_all
+=
[
x
.
make_all
(
**
kwargs
)
for
x
in
self
.
linkers
[
1
:]]
fns
,
input_lists
,
output_lists
,
thunk_lists
,
order_lists
=
zip
(
*
make_all
)
order_list0
=
order_lists
[
0
]
for
order_list
in
order_lists
[
1
:]:
if
not
order_list0
==
order_list
:
raise
Exception
(
"All linkers to WrapLinker should execute operations in the same order."
)
inputs0
=
input_lists
[
0
]
outputs0
=
output_lists
[
0
]
thunk_groups
=
list
(
zip
(
*
thunk_lists
))
order
=
[
x
[
0
]
for
x
in
zip
(
*
order_lists
)]
to_reset
=
[]
for
thunks
,
node
in
zip
(
thunk_groups
,
order
):
for
j
,
output
in
enumerate
(
node
.
outputs
):
if
output
in
no_recycling
:
for
thunk
in
thunks
:
to_reset
.
append
(
thunk
.
outputs
[
j
])
wrapper
=
self
.
wrapper
pre
=
self
.
pre
def
f
():
for
inputs
in
input_lists
[
1
:]:
for
input1
,
input2
in
zip
(
inputs0
,
inputs
):
input2
.
storage
[
0
]
=
copy
(
input1
.
storage
[
0
])
for
x
in
to_reset
:
x
[
0
]
=
None
pre
(
self
,
[
input
.
data
for
input
in
input_lists
[
0
]],
order
,
thunk_groups
)
for
i
,
(
thunks
,
node
)
in
enumerate
(
zip
(
thunk_groups
,
order
)):
try
:
wrapper
(
self
.
fgraph
,
i
,
node
,
*
thunks
)
except
Exception
:
raise_with_op
(
self
.
fgraph
,
node
,
*
thunks
)
f
.
thunk_groups
=
thunk_groups
return
f
,
inputs0
,
outputs0
def
WrapLinkerMany
(
linkers
,
wrappers
):
"""
Variant on WrapLinker that runs a series of wrapper functions instead of
just one.
"""
def
wrapper
(
*
args
):
for
f
in
wrappers
:
f
(
*
args
)
return
WrapLinker
(
linkers
,
wrapper
)
theano/link/jax/jax_linker.py
浏览文件 @
5bb459cf
...
...
@@ -3,14 +3,14 @@ from warnings import warn
from
theano.gof
import
utils
from
theano.gof.graph
import
Constant
from
theano.gof.link
import
(
from
theano.link.basic
import
(
Container
,
PerformLinker
,
add_clear_storage
,
gc_helper
,
map_storage
,
streamline
,
)
from
theano.link
import
Container
class
JAXLinker
(
PerformLinker
):
...
...
theano/scan/op.py
浏览文件 @
5bb459cf
...
...
@@ -53,7 +53,7 @@ from collections import OrderedDict
import
numpy
as
np
import
theano
from
theano
import
compile
,
config
,
gof
,
gradient
,
tensor
from
theano
import
compile
,
config
,
gof
,
gradient
,
link
,
tensor
from
theano.compile.builders
import
infer_shape
from
theano.compile.function
import
function
from
theano.compile.io
import
In
,
Out
...
...
@@ -1465,7 +1465,7 @@ class Scan(PureOp):
# done by raise_with_op is not implemented in C.
if
hasattr
(
fn
,
"thunks"
):
# For the CVM
gof
.
link
.
raise_with_op
(
link
.
raise_with_op
(
self
.
fn
.
maker
.
fgraph
,
fn
.
nodes
[
fn
.
position_of_error
],
fn
.
thunks
[
fn
.
position_of_error
],
...
...
@@ -1475,7 +1475,7 @@ class Scan(PureOp):
# We don't have access from python to all the
# temps values So for now, we just don't print
# the extra shapes/strides info
gof
.
link
.
raise_with_op
(
link
.
raise_with_op
(
self
.
fn
.
maker
.
fgraph
,
fn
.
nodes
[
fn
.
position_of_error
]
)
else
:
...
...
theano/scan/scan_perform.pyx
浏览文件 @
5bb459cf
...
...
@@ -61,7 +61,7 @@ cimport numpy
import copy
import time
from theano import gof
from theano import gof
, link
def get_version():
...
...
@@ -405,7 +405,7 @@ def perform(
# done by raise_with_op is not implemented in C.
if hasattr(fn, 'thunks'):
# For the CVM
gof.
link.raise_with_op(fn.maker.fgraph,
link.raise_with_op(fn.maker.fgraph,
fn.nodes[fn.position_of_error],
fn.thunks[fn.position_of_error])
else:
...
...
@@ -413,7 +413,7 @@ def perform(
# We don't have access from python to all the
# temps values So for now, we just don't print
# the extra shapes/strides info
gof.
link.raise_with_op(fn.maker.fgraph, fn.nodes[fn.position_of_error])
link.raise_with_op(fn.maker.fgraph, fn.nodes[fn.position_of_error])
else:
# old-style linkers raise their own exceptions
raise
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论