Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f9618a62
提交
f9618a62
authored
3月 01, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
3月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix typing issues in aesara.graph.basic
上级
90c4ce82
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
125 行增加
和
92 行删除
+125
-92
basic.py
aesara/graph/basic.py
+116
-77
utils.py
aesara/graph/utils.py
+5
-7
setup.cfg
setup.cfg
+4
-8
没有找到文件。
aesara/graph/basic.py
浏览文件 @
f9618a62
...
@@ -4,15 +4,17 @@ from collections import deque
...
@@ -4,15 +4,17 @@ from collections import deque
from
copy
import
copy
from
copy
import
copy
from
itertools
import
count
from
itertools
import
count
from
typing
import
(
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Callable
,
Collection
,
Collection
,
Deque
,
Deque
,
Dict
,
Dict
,
Generator
,
Generator
,
Hashable
,
Iterable
,
Iterable
,
Iterator
,
List
,
List
,
Optional
,
Optional
,
Reversible
,
Sequence
,
Sequence
,
Set
,
Set
,
Tuple
,
Tuple
,
...
@@ -36,9 +38,13 @@ from aesara.graph.utils import (
...
@@ -36,9 +38,13 @@ from aesara.graph.utils import (
from
aesara.misc.ordered_set
import
OrderedSet
from
aesara.misc.ordered_set
import
OrderedSet
T
=
TypeVar
(
"T"
)
if
TYPE_CHECKING
:
from
aesara.graph.type
import
Type
T
=
TypeVar
(
"T"
,
bound
=
"Node"
)
NoParams
=
object
()
NoParams
=
object
()
NodeAndChildren
=
Tuple
[
T
,
Optional
[
Iterable
[
T
]]]
class
Node
(
MetaObject
):
class
Node
(
MetaObject
):
...
@@ -49,6 +55,8 @@ class Node(MetaObject):
...
@@ -49,6 +55,8 @@ class Node(MetaObject):
keeps track of its parents via `Variable.owner` / `Apply.inputs`.
keeps track of its parents via `Variable.owner` / `Apply.inputs`.
"""
"""
type
:
"Type"
name
:
Optional
[
str
]
def
get_parents
(
self
):
def
get_parents
(
self
):
"""
"""
...
@@ -377,7 +385,23 @@ class Variable(Node):
...
@@ -377,7 +385,23 @@ class Variable(Node):
# __slots__ = ['type', 'owner', 'index', 'name']
# __slots__ = ['type', 'owner', 'index', 'name']
__count__
=
count
(
0
)
__count__
=
count
(
0
)
def
__init__
(
self
,
type
,
owner
=
None
,
index
=
None
,
name
=
None
):
_owner
:
Optional
[
Apply
]
@property
def
owner
(
self
)
->
Optional
[
Apply
]:
return
self
.
_owner
@owner.setter
def
owner
(
self
,
value
)
->
None
:
self
.
_owner
=
value
def
__init__
(
self
,
type
,
owner
:
Optional
[
Apply
]
=
None
,
index
:
Optional
[
int
]
=
None
,
name
:
Optional
[
str
]
=
None
,
):
super
()
.
__init__
()
super
()
.
__init__
()
self
.
tag
=
ValidatingScratchpad
(
"test_value"
,
type
.
filter
)
self
.
tag
=
ValidatingScratchpad
(
"test_value"
,
type
.
filter
)
...
@@ -386,7 +410,7 @@ class Variable(Node):
...
@@ -386,7 +410,7 @@ class Variable(Node):
if
owner
is
not
None
and
not
isinstance
(
owner
,
Apply
):
if
owner
is
not
None
and
not
isinstance
(
owner
,
Apply
):
raise
TypeError
(
"owner must be an Apply instance"
,
owner
)
raise
TypeError
(
"owner must be an Apply instance"
,
owner
)
self
.
owner
=
owner
self
.
_
owner
=
owner
if
index
is
not
None
and
not
isinstance
(
index
,
int
):
if
index
is
not
None
and
not
isinstance
(
index
,
int
):
raise
TypeError
(
"index must be an int"
,
index
)
raise
TypeError
(
"index must be an int"
,
index
)
...
@@ -607,19 +631,15 @@ class Constant(Variable):
...
@@ -607,19 +631,15 @@ class Constant(Variable):
"""Return `self`, because there's no reason to clone a constant."""
"""Return `self`, because there's no reason to clone a constant."""
return
self
return
self
def
__set_owner
(
self
,
value
):
@property
"""Prevent the :prop:`owner` property from being set.
def
owner
(
self
)
->
None
:
return
None
Raises
------
ValueError
If `value` is not ``None``.
"""
@owner.setter
def
owner
(
self
,
value
)
->
None
:
if
value
is
not
None
:
if
value
is
not
None
:
raise
ValueError
(
"Constant instances cannot have an owner."
)
raise
ValueError
(
"Constant instances cannot have an owner."
)
owner
=
property
(
lambda
self
:
None
,
__set_owner
)
value
=
property
(
lambda
self
:
self
.
data
,
doc
=
"read-only data access method"
)
value
=
property
(
lambda
self
:
self
.
data
,
doc
=
"read-only data access method"
)
# index is not defined, because the `owner` attribute must necessarily be None
# index is not defined, because the `owner` attribute must necessarily be None
...
@@ -627,34 +647,30 @@ class Constant(Variable):
...
@@ -627,34 +647,30 @@ class Constant(Variable):
def
walk
(
def
walk
(
nodes
:
Iterable
[
T
],
nodes
:
Iterable
[
T
],
expand
:
Callable
[[
T
],
Optional
[
Sequenc
e
[
T
]]],
expand
:
Callable
[[
T
],
Optional
[
Iterabl
e
[
T
]]],
bfs
:
bool
=
True
,
bfs
:
bool
=
True
,
return_children
:
bool
=
False
,
return_children
:
bool
=
False
,
hash_fn
:
Callable
[[
T
],
Hashable
]
=
id
,
hash_fn
:
Callable
[[
T
],
int
]
=
id
,
)
->
Generator
[
Union
[
T
,
Tuple
[
T
,
Optional
[
Sequence
[
T
]]]
],
None
,
None
]:
)
->
Generator
[
Union
[
T
,
NodeAndChildren
],
None
,
None
]:
"""Walk through a graph, either breadth- or depth-first.
r
"""Walk through a graph, either breadth- or depth-first.
Parameters
Parameters
----------
----------
nodes
: deque
nodes
The nodes from which to start walking.
The nodes from which to start walking.
expand
: callable
expand
A callable that is applied to each node in `nodes`, the results of
A callable that is applied to each node in `nodes`, the results of
which are either new nodes to visit or ``None``.
which are either new nodes to visit or ``None``.
bfs
: bool
bfs
If ``True``, breath first search is used; otherwise, depth first
If ``True``, breath first search is used; otherwise, depth first
search.
search.
return_children
: bool
return_children
If ``True``, each output node will be accompanied by the output of
If ``True``, each output node will be accompanied by the output of
`expand` (i.e. the corresponding child nodes).
`expand` (i.e. the corresponding child nodes).
hash_fn
: callable
hash_fn
The function used to produce hashes of the elements in `nodes`.
The function used to produce hashes of the elements in `nodes`.
The default is ``id``.
The default is ``id``.
Yields
------
nodes
Notes
Notes
-----
-----
A node will appear at most once in the return value, even if it
A node will appear at most once in the return value, even if it
...
@@ -664,7 +680,7 @@ def walk(
...
@@ -664,7 +680,7 @@ def walk(
nodes
=
deque
(
nodes
)
nodes
=
deque
(
nodes
)
rval_set
:
Set
[
T
]
=
set
()
rval_set
:
Set
[
int
]
=
set
()
nodes_pop
:
Callable
[[],
T
]
nodes_pop
:
Callable
[[],
T
]
if
bfs
:
if
bfs
:
...
@@ -675,13 +691,13 @@ def walk(
...
@@ -675,13 +691,13 @@ def walk(
while
nodes
:
while
nodes
:
node
:
T
=
nodes_pop
()
node
:
T
=
nodes_pop
()
node_hash
:
Hashable
=
hash_fn
(
node
)
node_hash
:
int
=
hash_fn
(
node
)
if
node_hash
not
in
rval_set
:
if
node_hash
not
in
rval_set
:
rval_set
.
add
(
node_hash
)
rval_set
.
add
(
node_hash
)
new_nodes
:
Optional
[
Sequenc
e
[
T
]]
=
expand
(
node
)
new_nodes
:
Optional
[
Iterabl
e
[
T
]]
=
expand
(
node
)
if
return_children
:
if
return_children
:
yield
node
,
new_nodes
yield
node
,
new_nodes
...
@@ -693,7 +709,7 @@ def walk(
...
@@ -693,7 +709,7 @@ def walk(
def
ancestors
(
def
ancestors
(
graphs
:
Iterable
[
Variable
],
blockers
:
Collection
[
Variable
]
=
None
graphs
:
Iterable
[
Variable
],
blockers
:
Optional
[
Collection
[
Variable
]
]
=
None
)
->
Generator
[
Variable
,
None
,
None
]:
)
->
Generator
[
Variable
,
None
,
None
]:
r"""Return the variables that contribute to those in given graphs (inclusive).
r"""Return the variables that contribute to those in given graphs (inclusive).
...
@@ -714,15 +730,15 @@ def ancestors(
...
@@ -714,15 +730,15 @@ def ancestors(
"""
"""
def
expand
(
r
)
:
def
expand
(
r
:
Variable
)
->
Optional
[
Iterator
[
Variable
]]
:
if
r
.
owner
and
(
not
blockers
or
r
not
in
blockers
):
if
r
.
owner
and
(
not
blockers
or
r
not
in
blockers
):
return
reversed
(
r
.
owner
.
inputs
)
return
reversed
(
r
.
owner
.
inputs
)
yield
from
walk
(
graphs
,
expand
,
False
)
yield
from
cast
(
Generator
[
Variable
,
None
,
None
],
walk
(
graphs
,
expand
,
False
)
)
def
graph_inputs
(
def
graph_inputs
(
graphs
:
Iterable
[
Variable
],
blockers
:
Collection
[
Variable
]
=
None
graphs
:
Iterable
[
Variable
],
blockers
:
Optional
[
Collection
[
Variable
]
]
=
None
)
->
Generator
[
Variable
,
None
,
None
]:
)
->
Generator
[
Variable
,
None
,
None
]:
r"""Return the inputs required to compute the given Variables.
r"""Return the inputs required to compute the given Variables.
...
@@ -751,14 +767,13 @@ def vars_between(
...
@@ -751,14 +767,13 @@ def vars_between(
Parameters
Parameters
----------
----------
ins
: list
ins
Input `Variable`\s.
Input `Variable`\s.
outs
: list
outs
Output `Variable`\s.
Output `Variable`\s.
Yields
Yields
------
------
`Variable`\s
The `Variable`\s that are involved in the subgraph that lies
The `Variable`\s that are involved in the subgraph that lies
between `ins` and `outs`. This includes `ins`, `outs`,
between `ins` and `outs`. This includes `ins`, `outs`,
``orphans_between(ins, outs)`` and all values of all intermediary steps from
``orphans_between(ins, outs)`` and all values of all intermediary steps from
...
@@ -766,11 +781,11 @@ def vars_between(
...
@@ -766,11 +781,11 @@ def vars_between(
"""
"""
def
expand
(
r
)
:
def
expand
(
r
:
Variable
)
->
Optional
[
Iterable
[
Variable
]]
:
if
r
.
owner
and
r
not
in
ins
:
if
r
.
owner
and
r
not
in
ins
:
return
reversed
(
r
.
owner
.
inputs
+
r
.
owner
.
outputs
)
return
reversed
(
r
.
owner
.
inputs
+
r
.
owner
.
outputs
)
yield
from
walk
(
outs
,
expand
)
yield
from
cast
(
Generator
[
Variable
,
None
,
None
],
walk
(
outs
,
expand
)
)
def
orphans_between
(
def
orphans_between
(
...
@@ -862,16 +877,18 @@ def clone(
...
@@ -862,16 +877,18 @@ def clone(
if
copy_orphans
is
None
:
if
copy_orphans
is
None
:
copy_orphans
=
copy_inputs
copy_orphans
=
copy_inputs
equiv
=
clone_get_equiv
(
inputs
,
outputs
,
copy_inputs
,
copy_orphans
)
equiv
=
clone_get_equiv
(
inputs
,
outputs
,
copy_inputs
,
copy_orphans
)
return
[
equiv
[
input
]
for
input
in
inputs
],
[
equiv
[
output
]
for
output
in
outputs
]
return
[
cast
(
Variable
,
equiv
[
input
])
for
input
in
inputs
],
[
cast
(
Variable
,
equiv
[
output
])
for
output
in
outputs
]
def
clone_get_equiv
(
def
clone_get_equiv
(
inputs
:
List
[
Variable
],
inputs
:
Sequence
[
Variable
],
outputs
:
List
[
Variable
],
outputs
:
Sequence
[
Variable
],
copy_inputs
:
bool
=
True
,
copy_inputs
:
bool
=
True
,
copy_orphans
:
bool
=
True
,
copy_orphans
:
bool
=
True
,
memo
:
Optional
[
Dict
[
Variable
,
Variabl
e
]]
=
None
,
memo
:
Optional
[
Dict
[
Node
,
Nod
e
]]
=
None
,
)
->
Dict
[
Variable
,
Variabl
e
]:
)
->
Dict
[
Node
,
Nod
e
]:
"""
"""
Return a dictionary that maps from `Variable` and `Apply` nodes in the
Return a dictionary that maps from `Variable` and `Apply` nodes in the
original graph to a new node (a clone) in a new graph.
original graph to a new node (a clone) in a new graph.
...
@@ -934,11 +951,13 @@ def clone_get_equiv(
...
@@ -934,11 +951,13 @@ def clone_get_equiv(
def
clone_replace
(
def
clone_replace
(
output
:
Collection
[
Variable
],
output
:
List
[
Variable
],
replace
:
Optional
[
Dict
[
Variable
,
Variable
]]
=
None
,
replace
:
Optional
[
Union
[
Iterable
[
Tuple
[
Variable
,
Variable
]],
Dict
[
Variable
,
Variable
]]
]
=
None
,
strict
:
bool
=
True
,
strict
:
bool
=
True
,
share_inputs
:
bool
=
True
,
share_inputs
:
bool
=
True
,
)
->
Collection
[
Variable
]:
)
->
List
[
Variable
]:
"""Clone a graph and replace subgraphs within it.
"""Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding
It returns a copy of the initial subgraph with the corresponding
...
@@ -959,6 +978,7 @@ def clone_replace(
...
@@ -959,6 +978,7 @@ def clone_replace(
"""
"""
from
aesara.compile.function.pfunc
import
rebuild_collect_shared
from
aesara.compile.function.pfunc
import
rebuild_collect_shared
items
:
Union
[
List
[
Tuple
[
Variable
,
Variable
]],
Tuple
[
Tuple
[
Variable
,
Variable
],
...
]]
if
isinstance
(
replace
,
dict
):
if
isinstance
(
replace
,
dict
):
items
=
list
(
replace
.
items
())
items
=
list
(
replace
.
items
())
elif
isinstance
(
replace
,
(
list
,
tuple
)):
elif
isinstance
(
replace
,
(
list
,
tuple
)):
...
@@ -984,13 +1004,15 @@ def clone_replace(
...
@@ -984,13 +1004,15 @@ def clone_replace(
_outs
,
[],
new_replace
,
[],
strict
,
share_inputs
_outs
,
[],
new_replace
,
[],
strict
,
share_inputs
)
)
return
outs
return
cast
(
List
[
Variable
],
outs
)
def
general_toposort
(
def
general_toposort
(
outputs
:
Iterable
[
T
],
outputs
:
Iterable
[
T
],
deps
:
Callable
[[
T
],
Union
[
OrderedSet
,
List
[
T
]]],
deps
:
Callable
[[
T
],
Union
[
OrderedSet
,
List
[
T
]]],
compute_deps_cache
:
Optional
[
Callable
[[
T
],
Union
[
OrderedSet
,
List
[
T
]]]]
=
None
,
compute_deps_cache
:
Optional
[
Callable
[[
T
],
Optional
[
Union
[
OrderedSet
,
List
[
T
]]]]
]
=
None
,
deps_cache
:
Optional
[
Dict
[
T
,
List
[
T
]]]
=
None
,
deps_cache
:
Optional
[
Dict
[
T
,
List
[
T
]]]
=
None
,
clients
:
Optional
[
Dict
[
T
,
List
[
T
]]]
=
None
,
clients
:
Optional
[
Dict
[
T
,
List
[
T
]]]
=
None
,
)
->
List
[
T
]:
)
->
List
[
T
]:
...
@@ -1030,7 +1052,7 @@ def general_toposort(
...
@@ -1030,7 +1052,7 @@ def general_toposort(
if
deps_cache
is
None
:
if
deps_cache
is
None
:
deps_cache
=
{}
deps_cache
=
{}
def
compute_deps_cache
(
io
):
def
_
compute_deps_cache
(
io
):
if
io
not
in
deps_cache
:
if
io
not
in
deps_cache
:
d
=
deps
(
io
)
d
=
deps
(
io
)
...
@@ -1048,23 +1070,27 @@ def general_toposort(
...
@@ -1048,23 +1070,27 @@ def general_toposort(
else
:
else
:
return
deps_cache
[
io
]
return
deps_cache
[
io
]
else
:
_compute_deps_cache
=
compute_deps_cache
if
deps_cache
is
None
:
if
deps_cache
is
None
:
raise
ValueError
(
"deps_cache cannot be None"
)
raise
ValueError
(
"deps_cache cannot be None"
)
search_res
:
List
[
Tuple
[
T
,
Optional
[
List
[
T
]]]]
=
list
(
search_res
:
List
[
NodeAndChildren
]
=
cast
(
walk
(
outputs
,
compute_deps_cache
,
bfs
=
False
,
return_children
=
True
)
List
[
NodeAndChildren
],
list
(
walk
(
outputs
,
_compute_deps_cache
,
bfs
=
False
,
return_children
=
True
)),
)
)
_clients
:
Dict
[
T
,
List
[
T
]]
=
{}
_clients
:
Dict
[
T
,
List
[
T
]]
=
{}
sources
:
Deque
[
T
]
=
deque
()
sources
:
Deque
[
T
]
=
deque
()
search_res_len
:
int
=
0
search_res_len
:
int
=
0
for
node
,
children
in
search_res
:
for
s
node
,
children
in
search_res
:
search_res_len
+=
1
search_res_len
+=
1
if
children
:
if
children
:
for
child
in
children
:
for
child
in
children
:
_clients
.
setdefault
(
child
,
[])
.
append
(
node
)
_clients
.
setdefault
(
child
,
[])
.
append
(
s
node
)
if
not
deps_cache
.
get
(
node
):
if
not
deps_cache
.
get
(
s
node
):
sources
.
append
(
node
)
sources
.
append
(
s
node
)
if
clients
is
not
None
:
if
clients
is
not
None
:
clients
.
update
(
_clients
)
clients
.
update
(
_clients
)
...
@@ -1090,7 +1116,7 @@ def general_toposort(
...
@@ -1090,7 +1116,7 @@ def general_toposort(
def
io_toposort
(
def
io_toposort
(
inputs
:
Iterable
[
Variable
],
inputs
:
Iterable
[
Variable
],
outputs
:
Itera
ble
[
Variable
],
outputs
:
Reversi
ble
[
Variable
],
orderings
:
Optional
[
Dict
[
Apply
,
List
[
Apply
]]]
=
None
,
orderings
:
Optional
[
Dict
[
Apply
,
List
[
Apply
]]]
=
None
,
clients
:
Optional
[
Dict
[
Variable
,
List
[
Variable
]]]
=
None
,
clients
:
Optional
[
Dict
[
Variable
,
List
[
Variable
]]]
=
None
,
)
->
List
[
Apply
]:
)
->
List
[
Apply
]:
...
@@ -1301,8 +1327,8 @@ def as_string(
...
@@ -1301,8 +1327,8 @@ def as_string(
orph
=
list
(
orphans_between
(
i
,
outputs
))
orph
=
list
(
orphans_between
(
i
,
outputs
))
multi
:
Set
=
set
()
multi
=
set
()
seen
:
Set
=
set
()
seen
=
set
()
for
output
in
outputs
:
for
output
in
outputs
:
op
=
output
.
owner
op
=
output
.
owner
if
op
in
seen
:
if
op
in
seen
:
...
@@ -1318,11 +1344,11 @@ def as_string(
...
@@ -1318,11 +1344,11 @@ def as_string(
multi
.
add
(
op2
)
multi
.
add
(
op2
)
else
:
else
:
seen
.
add
(
input
.
owner
)
seen
.
add
(
input
.
owner
)
multi
:
Se
t
=
[
x
for
x
in
multi
]
multi
_lis
t
=
[
x
for
x
in
multi
]
done
:
Set
=
set
()
done
:
Set
=
set
()
def
multi_index
(
x
):
def
multi_index
(
x
):
return
multi
.
index
(
x
)
+
1
return
multi
_list
.
index
(
x
)
+
1
def
describe
(
r
):
def
describe
(
r
):
if
r
.
owner
is
not
None
and
r
not
in
i
and
r
not
in
orph
:
if
r
.
owner
is
not
None
and
r
not
in
i
and
r
not
in
orph
:
...
@@ -1337,7 +1363,7 @@ def as_string(
...
@@ -1337,7 +1363,7 @@ def as_string(
else
:
else
:
done
.
add
(
op
)
done
.
add
(
op
)
s
=
node_formatter
(
op
,
[
describe
(
input
)
for
input
in
op
.
inputs
])
s
=
node_formatter
(
op
,
[
describe
(
input
)
for
input
in
op
.
inputs
])
if
op
in
multi
:
if
op
in
multi
_list
:
return
f
"*{multi_index(op)} -> {s}"
return
f
"*{multi_index(op)} -> {s}"
else
:
else
:
return
s
return
s
...
@@ -1380,14 +1406,21 @@ def list_of_nodes(
...
@@ -1380,14 +1406,21 @@ def list_of_nodes(
Output `Variable`\s.
Output `Variable`\s.
"""
"""
return
list
(
walk
(
def
expand
(
o
:
Apply
)
->
List
[
Apply
]:
[
o
.
owner
for
o
in
outputs
],
return
[
lambda
o
:
[
inp
.
owner
inp
.
owner
for
inp
in
o
.
inputs
for
inp
in
o
.
inputs
if
inp
.
owner
and
not
any
(
i
in
inp
.
owner
.
outputs
for
i
in
inputs
)
if
inp
.
owner
and
not
any
(
i
in
inp
.
owner
.
outputs
for
i
in
inputs
)
],
]
return
list
(
cast
(
Iterable
[
Apply
],
walk
(
[
o
.
owner
for
o
in
outputs
if
o
.
owner
],
expand
,
),
)
)
)
)
...
@@ -1423,7 +1456,12 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
...
@@ -1423,7 +1456,12 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
return
False
return
False
def
equal_computations
(
xs
,
ys
,
in_xs
=
None
,
in_ys
=
None
):
def
equal_computations
(
xs
:
List
[
Union
[
np
.
ndarray
,
Variable
]],
ys
:
List
[
Union
[
np
.
ndarray
,
Variable
]],
in_xs
:
Optional
[
List
[
Variable
]]
=
None
,
in_ys
:
Optional
[
List
[
Variable
]]
=
None
,
)
->
bool
:
"""Checks if Aesara graphs represent the same computations.
"""Checks if Aesara graphs represent the same computations.
The two lists `xs`, `ys` should have the same number of entries. The
The two lists `xs`, `ys` should have the same number of entries. The
...
@@ -1458,22 +1496,20 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
...
@@ -1458,22 +1496,20 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
for
x
,
y
in
zip
(
xs
,
ys
):
for
x
,
y
in
zip
(
xs
,
ys
):
if
not
isinstance
(
x
,
Variable
)
and
not
isinstance
(
y
,
Variable
):
if
not
isinstance
(
x
,
Variable
)
and
not
isinstance
(
y
,
Variable
):
return
np
.
array_equal
(
x
,
y
)
return
cast
(
bool
,
np
.
array_equal
(
x
,
y
)
)
if
not
isinstance
(
x
,
Variable
):
if
not
isinstance
(
x
,
Variable
):
if
isinstance
(
y
,
Constant
):
if
isinstance
(
y
,
Constant
):
return
np
.
array_equal
(
y
.
data
,
x
)
return
cast
(
bool
,
np
.
array_equal
(
y
.
data
,
x
)
)
return
False
return
False
if
not
isinstance
(
y
,
Variable
):
if
not
isinstance
(
y
,
Variable
):
if
isinstance
(
x
,
Constant
):
if
isinstance
(
x
,
Constant
):
return
np
.
array_equal
(
x
.
data
,
y
)
return
cast
(
bool
,
np
.
array_equal
(
x
.
data
,
y
))
return
False
if
not
isinstance
(
y
,
Variable
):
return
False
return
False
if
x
.
owner
and
not
y
.
owner
:
if
x
.
owner
and
not
y
.
owner
:
return
False
return
False
if
y
.
owner
and
not
x
.
owner
:
if
y
.
owner
and
not
x
.
owner
:
return
False
return
False
if
x
.
owner
:
# Check above tell that y.owner eval to True too.
if
x
.
owner
and
y
.
owner
:
if
x
.
owner
.
outputs
.
index
(
x
)
!=
y
.
owner
.
outputs
.
index
(
y
):
if
x
.
owner
.
outputs
.
index
(
x
)
!=
y
.
owner
.
outputs
.
index
(
y
):
return
False
return
False
if
x
not
in
in_xs
and
not
(
y
.
type
.
in_same_class
(
x
.
type
)):
if
x
not
in
in_xs
and
not
(
y
.
type
.
in_same_class
(
x
.
type
)):
...
@@ -1487,8 +1523,9 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
...
@@ -1487,8 +1523,9 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
return
False
return
False
common
=
set
(
zip
(
in_xs
,
in_ys
))
common
=
set
(
zip
(
in_xs
,
in_ys
))
different
=
set
()
different
:
Set
[
Tuple
[
Variable
,
Variable
]]
=
set
()
for
dx
,
dy
in
zip
(
xs
,
ys
):
for
dx
,
dy
in
zip
(
xs
,
ys
):
assert
isinstance
(
dx
,
Variable
)
# We checked above that both dx and dy have an owner or not
# We checked above that both dx and dy have an owner or not
if
not
dx
.
owner
:
if
not
dx
.
owner
:
if
isinstance
(
dx
,
Constant
)
and
isinstance
(
dy
,
Constant
):
if
isinstance
(
dx
,
Constant
)
and
isinstance
(
dy
,
Constant
):
...
@@ -1575,10 +1612,12 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
...
@@ -1575,10 +1612,12 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
# Validate that each xs[i], ys[i] pair represents the same computation
# Validate that each xs[i], ys[i] pair represents the same computation
for
i
in
range
(
len
(
xs
)):
for
i
in
range
(
len
(
xs
)):
if
xs
[
i
]
.
owner
:
x_i
:
Variable
=
cast
(
Variable
,
xs
[
i
])
if
x_i
.
owner
:
y_i
:
Variable
=
cast
(
Variable
,
ys
[
i
])
# The case where pairs of x[i]s and y[i]s don't both have an owner
# The case where pairs of x[i]s and y[i]s don't both have an owner
# have already been addressed.
# have already been addressed.
is_equal
=
compare_nodes
(
x
s
[
i
]
.
owner
,
ys
[
i
]
.
owner
,
common
,
different
)
is_equal
=
compare_nodes
(
x
_i
.
owner
,
y_i
.
owner
,
common
,
different
)
if
not
is_equal
:
if
not
is_equal
:
return
False
return
False
...
@@ -1605,7 +1644,7 @@ def get_var_by_name(
...
@@ -1605,7 +1644,7 @@ def get_var_by_name(
"""
"""
from
aesara.graph.op
import
HasInnerGraph
from
aesara.graph.op
import
HasInnerGraph
def
expand
(
r
):
def
expand
(
r
)
->
Optional
[
List
[
Variable
]]
:
if
r
.
owner
:
if
r
.
owner
:
res
=
list
(
r
.
owner
.
inputs
)
res
=
list
(
r
.
owner
.
inputs
)
...
...
aesara/graph/utils.py
浏览文件 @
f9618a62
...
@@ -8,7 +8,7 @@ from typing import List, Optional, Sequence, Tuple
...
@@ -8,7 +8,7 @@ from typing import List, Optional, Sequence, Tuple
def
simple_extract_stack
(
def
simple_extract_stack
(
f
=
None
,
limit
:
Optional
[
int
]
=
None
,
skips
:
Optional
[
Sequence
[
str
]]
=
None
f
=
None
,
limit
:
Optional
[
int
]
=
None
,
skips
:
Optional
[
Sequence
[
str
]]
=
None
)
->
List
[
Tuple
[
Optional
[
str
],
int
,
str
,
str
]]:
)
->
List
[
Tuple
[
Optional
[
str
],
int
,
str
,
Optional
[
str
]
]]:
"""This is traceback.extract_stack from python 2.7 with this change:
"""This is traceback.extract_stack from python 2.7 with this change:
- Comment the update of the cache.
- Comment the update of the cache.
...
@@ -28,14 +28,12 @@ def simple_extract_stack(
...
@@ -28,14 +28,12 @@ def simple_extract_stack(
skips
=
[]
skips
=
[]
if
f
is
None
:
if
f
is
None
:
try
:
f
=
sys
.
_getframe
()
.
f_back
raise
ZeroDivisionError
except
ZeroDivisionError
:
f
=
sys
.
exc_info
()[
2
]
.
tb_frame
.
f_back
if
limit
is
None
:
if
limit
is
None
:
if
hasattr
(
sys
,
"tracebacklimit"
):
if
hasattr
(
sys
,
"tracebacklimit"
):
limit
=
sys
.
tracebacklimit
limit
=
sys
.
tracebacklimit
trace
:
List
[
Tuple
[
Optional
[
str
],
int
,
str
,
str
]]
=
[]
trace
:
List
[
Tuple
[
Optional
[
str
],
int
,
str
,
Optional
[
str
]
]]
=
[]
n
=
0
n
=
0
while
f
is
not
None
and
(
limit
is
None
or
n
<
limit
):
while
f
is
not
None
and
(
limit
is
None
or
n
<
limit
):
lineno
=
f
.
f_lineno
lineno
=
f
.
f_lineno
...
@@ -43,7 +41,7 @@ def simple_extract_stack(
...
@@ -43,7 +41,7 @@ def simple_extract_stack(
filename
=
co
.
co_filename
filename
=
co
.
co_filename
name
=
co
.
co_name
name
=
co
.
co_name
# linecache.checkcache(filename)
# linecache.checkcache(filename)
line
=
linecache
.
getline
(
filename
,
lineno
,
f
.
f_globals
)
line
:
Optional
[
str
]
=
linecache
.
getline
(
filename
,
lineno
,
f
.
f_globals
)
if
line
:
if
line
:
line
=
line
.
strip
()
line
=
line
.
strip
()
else
:
else
:
...
...
setup.cfg
浏览文件 @
f9618a62
...
@@ -115,14 +115,6 @@ check_untyped_defs = False
...
@@ -115,14 +115,6 @@ check_untyped_defs = False
ignore_errors = True
ignore_errors = True
check_untyped_defs = False
check_untyped_defs = False
[mypy-aesara.graph.utils]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.basic]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.op]
[mypy-aesara.graph.op]
ignore_errors = True
ignore_errors = True
check_untyped_defs = False
check_untyped_defs = False
...
@@ -179,6 +171,10 @@ check_untyped_defs = False
...
@@ -179,6 +171,10 @@ check_untyped_defs = False
ignore_errors = True
ignore_errors = True
check_untyped_defs = False
check_untyped_defs = False
[mypy-aesara.compile.builders]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.compile.sharedvalue]
[mypy-aesara.compile.sharedvalue]
ignore_errors = True
ignore_errors = True
check_untyped_defs = False
check_untyped_defs = False
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论