Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f9618a62
提交
f9618a62
authored
3月 01, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
3月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix typing issues in aesara.graph.basic
上级
90c4ce82
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
131 行增加
和
98 行删除
+131
-98
basic.py
aesara/graph/basic.py
+122
-83
utils.py
aesara/graph/utils.py
+5
-7
setup.cfg
setup.cfg
+4
-8
没有找到文件。
aesara/graph/basic.py
浏览文件 @
f9618a62
...
...
@@ -4,15 +4,17 @@ from collections import deque
from
copy
import
copy
from
itertools
import
count
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Collection
,
Deque
,
Dict
,
Generator
,
Hashable
,
Iterable
,
Iterator
,
List
,
Optional
,
Reversible
,
Sequence
,
Set
,
Tuple
,
...
...
@@ -36,9 +38,13 @@ from aesara.graph.utils import (
from
aesara.misc.ordered_set
import
OrderedSet
T
=
TypeVar
(
"T"
)
if
TYPE_CHECKING
:
from
aesara.graph.type
import
Type
T
=
TypeVar
(
"T"
,
bound
=
"Node"
)
NoParams
=
object
()
NodeAndChildren
=
Tuple
[
T
,
Optional
[
Iterable
[
T
]]]
class
Node
(
MetaObject
):
...
...
@@ -49,6 +55,8 @@ class Node(MetaObject):
keeps track of its parents via `Variable.owner` / `Apply.inputs`.
"""
type
:
"Type"
name
:
Optional
[
str
]
def
get_parents
(
self
):
"""
...
...
@@ -377,7 +385,23 @@ class Variable(Node):
# __slots__ = ['type', 'owner', 'index', 'name']
__count__
=
count
(
0
)
def
__init__
(
self
,
type
,
owner
=
None
,
index
=
None
,
name
=
None
):
_owner
:
Optional
[
Apply
]
@property
def
owner
(
self
)
->
Optional
[
Apply
]:
return
self
.
_owner
@owner.setter
def
owner
(
self
,
value
)
->
None
:
self
.
_owner
=
value
def
__init__
(
self
,
type
,
owner
:
Optional
[
Apply
]
=
None
,
index
:
Optional
[
int
]
=
None
,
name
:
Optional
[
str
]
=
None
,
):
super
()
.
__init__
()
self
.
tag
=
ValidatingScratchpad
(
"test_value"
,
type
.
filter
)
...
...
@@ -386,7 +410,7 @@ class Variable(Node):
if
owner
is
not
None
and
not
isinstance
(
owner
,
Apply
):
raise
TypeError
(
"owner must be an Apply instance"
,
owner
)
self
.
owner
=
owner
self
.
_
owner
=
owner
if
index
is
not
None
and
not
isinstance
(
index
,
int
):
raise
TypeError
(
"index must be an int"
,
index
)
...
...
@@ -607,19 +631,15 @@ class Constant(Variable):
"""Return `self`, because there's no reason to clone a constant."""
return
self
def
__set_owner
(
self
,
value
):
"""Prevent the :prop:`owner` property from being set.
Raises
------
ValueError
If `value` is not ``None``.
@property
def
owner
(
self
)
->
None
:
return
None
"""
@owner.setter
def
owner
(
self
,
value
)
->
None
:
if
value
is
not
None
:
raise
ValueError
(
"Constant instances cannot have an owner."
)
owner
=
property
(
lambda
self
:
None
,
__set_owner
)
value
=
property
(
lambda
self
:
self
.
data
,
doc
=
"read-only data access method"
)
# index is not defined, because the `owner` attribute must necessarily be None
...
...
@@ -627,34 +647,30 @@ class Constant(Variable):
def
walk
(
nodes
:
Iterable
[
T
],
expand
:
Callable
[[
T
],
Optional
[
Sequenc
e
[
T
]]],
expand
:
Callable
[[
T
],
Optional
[
Iterabl
e
[
T
]]],
bfs
:
bool
=
True
,
return_children
:
bool
=
False
,
hash_fn
:
Callable
[[
T
],
Hashable
]
=
id
,
)
->
Generator
[
Union
[
T
,
Tuple
[
T
,
Optional
[
Sequence
[
T
]]]
],
None
,
None
]:
"""Walk through a graph, either breadth- or depth-first.
hash_fn
:
Callable
[[
T
],
int
]
=
id
,
)
->
Generator
[
Union
[
T
,
NodeAndChildren
],
None
,
None
]:
r
"""Walk through a graph, either breadth- or depth-first.
Parameters
----------
nodes
: deque
nodes
The nodes from which to start walking.
expand
: callable
expand
A callable that is applied to each node in `nodes`, the results of
which are either new nodes to visit or ``None``.
bfs
: bool
bfs
If ``True``, breath first search is used; otherwise, depth first
search.
return_children
: bool
return_children
If ``True``, each output node will be accompanied by the output of
`expand` (i.e. the corresponding child nodes).
hash_fn
: callable
hash_fn
The function used to produce hashes of the elements in `nodes`.
The default is ``id``.
Yields
------
nodes
Notes
-----
A node will appear at most once in the return value, even if it
...
...
@@ -664,7 +680,7 @@ def walk(
nodes
=
deque
(
nodes
)
rval_set
:
Set
[
T
]
=
set
()
rval_set
:
Set
[
int
]
=
set
()
nodes_pop
:
Callable
[[],
T
]
if
bfs
:
...
...
@@ -675,13 +691,13 @@ def walk(
while
nodes
:
node
:
T
=
nodes_pop
()
node_hash
:
Hashable
=
hash_fn
(
node
)
node_hash
:
int
=
hash_fn
(
node
)
if
node_hash
not
in
rval_set
:
rval_set
.
add
(
node_hash
)
new_nodes
:
Optional
[
Sequenc
e
[
T
]]
=
expand
(
node
)
new_nodes
:
Optional
[
Iterabl
e
[
T
]]
=
expand
(
node
)
if
return_children
:
yield
node
,
new_nodes
...
...
@@ -693,7 +709,7 @@ def walk(
def
ancestors
(
graphs
:
Iterable
[
Variable
],
blockers
:
Collection
[
Variable
]
=
None
graphs
:
Iterable
[
Variable
],
blockers
:
Optional
[
Collection
[
Variable
]
]
=
None
)
->
Generator
[
Variable
,
None
,
None
]:
r"""Return the variables that contribute to those in given graphs (inclusive).
...
...
@@ -714,15 +730,15 @@ def ancestors(
"""
def
expand
(
r
)
:
def
expand
(
r
:
Variable
)
->
Optional
[
Iterator
[
Variable
]]
:
if
r
.
owner
and
(
not
blockers
or
r
not
in
blockers
):
return
reversed
(
r
.
owner
.
inputs
)
yield
from
walk
(
graphs
,
expand
,
False
)
yield
from
cast
(
Generator
[
Variable
,
None
,
None
],
walk
(
graphs
,
expand
,
False
)
)
def
graph_inputs
(
graphs
:
Iterable
[
Variable
],
blockers
:
Collection
[
Variable
]
=
None
graphs
:
Iterable
[
Variable
],
blockers
:
Optional
[
Collection
[
Variable
]
]
=
None
)
->
Generator
[
Variable
,
None
,
None
]:
r"""Return the inputs required to compute the given Variables.
...
...
@@ -751,26 +767,25 @@ def vars_between(
Parameters
----------
ins
: list
ins
Input `Variable`\s.
outs
: list
outs
Output `Variable`\s.
Yields
------
`Variable`\s
The `Variable`\s that are involved in the subgraph that lies
between `ins` and `outs`. This includes `ins`, `outs`,
``orphans_between(ins, outs)`` and all values of all intermediary steps from
`ins` to `outs`.
The `Variable`\s that are involved in the subgraph that lies
between `ins` and `outs`. This includes `ins`, `outs`,
``orphans_between(ins, outs)`` and all values of all intermediary steps from
`ins` to `outs`.
"""
def
expand
(
r
)
:
def
expand
(
r
:
Variable
)
->
Optional
[
Iterable
[
Variable
]]
:
if
r
.
owner
and
r
not
in
ins
:
return
reversed
(
r
.
owner
.
inputs
+
r
.
owner
.
outputs
)
yield
from
walk
(
outs
,
expand
)
yield
from
cast
(
Generator
[
Variable
,
None
,
None
],
walk
(
outs
,
expand
)
)
def
orphans_between
(
...
...
@@ -862,16 +877,18 @@ def clone(
if
copy_orphans
is
None
:
copy_orphans
=
copy_inputs
equiv
=
clone_get_equiv
(
inputs
,
outputs
,
copy_inputs
,
copy_orphans
)
return
[
equiv
[
input
]
for
input
in
inputs
],
[
equiv
[
output
]
for
output
in
outputs
]
return
[
cast
(
Variable
,
equiv
[
input
])
for
input
in
inputs
],
[
cast
(
Variable
,
equiv
[
output
])
for
output
in
outputs
]
def
clone_get_equiv
(
inputs
:
List
[
Variable
],
outputs
:
List
[
Variable
],
inputs
:
Sequence
[
Variable
],
outputs
:
Sequence
[
Variable
],
copy_inputs
:
bool
=
True
,
copy_orphans
:
bool
=
True
,
memo
:
Optional
[
Dict
[
Variable
,
Variabl
e
]]
=
None
,
)
->
Dict
[
Variable
,
Variabl
e
]:
memo
:
Optional
[
Dict
[
Node
,
Nod
e
]]
=
None
,
)
->
Dict
[
Node
,
Nod
e
]:
"""
Return a dictionary that maps from `Variable` and `Apply` nodes in the
original graph to a new node (a clone) in a new graph.
...
...
@@ -934,11 +951,13 @@ def clone_get_equiv(
def
clone_replace
(
output
:
Collection
[
Variable
],
replace
:
Optional
[
Dict
[
Variable
,
Variable
]]
=
None
,
output
:
List
[
Variable
],
replace
:
Optional
[
Union
[
Iterable
[
Tuple
[
Variable
,
Variable
]],
Dict
[
Variable
,
Variable
]]
]
=
None
,
strict
:
bool
=
True
,
share_inputs
:
bool
=
True
,
)
->
Collection
[
Variable
]:
)
->
List
[
Variable
]:
"""Clone a graph and replace subgraphs within it.
It returns a copy of the initial subgraph with the corresponding
...
...
@@ -959,6 +978,7 @@ def clone_replace(
"""
from
aesara.compile.function.pfunc
import
rebuild_collect_shared
items
:
Union
[
List
[
Tuple
[
Variable
,
Variable
]],
Tuple
[
Tuple
[
Variable
,
Variable
],
...
]]
if
isinstance
(
replace
,
dict
):
items
=
list
(
replace
.
items
())
elif
isinstance
(
replace
,
(
list
,
tuple
)):
...
...
@@ -984,13 +1004,15 @@ def clone_replace(
_outs
,
[],
new_replace
,
[],
strict
,
share_inputs
)
return
outs
return
cast
(
List
[
Variable
],
outs
)
def
general_toposort
(
outputs
:
Iterable
[
T
],
deps
:
Callable
[[
T
],
Union
[
OrderedSet
,
List
[
T
]]],
compute_deps_cache
:
Optional
[
Callable
[[
T
],
Union
[
OrderedSet
,
List
[
T
]]]]
=
None
,
compute_deps_cache
:
Optional
[
Callable
[[
T
],
Optional
[
Union
[
OrderedSet
,
List
[
T
]]]]
]
=
None
,
deps_cache
:
Optional
[
Dict
[
T
,
List
[
T
]]]
=
None
,
clients
:
Optional
[
Dict
[
T
,
List
[
T
]]]
=
None
,
)
->
List
[
T
]:
...
...
@@ -1030,7 +1052,7 @@ def general_toposort(
if
deps_cache
is
None
:
deps_cache
=
{}
def
compute_deps_cache
(
io
):
def
_
compute_deps_cache
(
io
):
if
io
not
in
deps_cache
:
d
=
deps
(
io
)
...
...
@@ -1048,23 +1070,27 @@ def general_toposort(
else
:
return
deps_cache
[
io
]
else
:
_compute_deps_cache
=
compute_deps_cache
if
deps_cache
is
None
:
raise
ValueError
(
"deps_cache cannot be None"
)
search_res
:
List
[
Tuple
[
T
,
Optional
[
List
[
T
]]]]
=
list
(
walk
(
outputs
,
compute_deps_cache
,
bfs
=
False
,
return_children
=
True
)
search_res
:
List
[
NodeAndChildren
]
=
cast
(
List
[
NodeAndChildren
],
list
(
walk
(
outputs
,
_compute_deps_cache
,
bfs
=
False
,
return_children
=
True
)),
)
_clients
:
Dict
[
T
,
List
[
T
]]
=
{}
sources
:
Deque
[
T
]
=
deque
()
search_res_len
:
int
=
0
for
node
,
children
in
search_res
:
for
s
node
,
children
in
search_res
:
search_res_len
+=
1
if
children
:
for
child
in
children
:
_clients
.
setdefault
(
child
,
[])
.
append
(
node
)
if
not
deps_cache
.
get
(
node
):
sources
.
append
(
node
)
_clients
.
setdefault
(
child
,
[])
.
append
(
s
node
)
if
not
deps_cache
.
get
(
s
node
):
sources
.
append
(
s
node
)
if
clients
is
not
None
:
clients
.
update
(
_clients
)
...
...
@@ -1090,7 +1116,7 @@ def general_toposort(
def
io_toposort
(
inputs
:
Iterable
[
Variable
],
outputs
:
Itera
ble
[
Variable
],
outputs
:
Reversi
ble
[
Variable
],
orderings
:
Optional
[
Dict
[
Apply
,
List
[
Apply
]]]
=
None
,
clients
:
Optional
[
Dict
[
Variable
,
List
[
Variable
]]]
=
None
,
)
->
List
[
Apply
]:
...
...
@@ -1301,8 +1327,8 @@ def as_string(
orph
=
list
(
orphans_between
(
i
,
outputs
))
multi
:
Set
=
set
()
seen
:
Set
=
set
()
multi
=
set
()
seen
=
set
()
for
output
in
outputs
:
op
=
output
.
owner
if
op
in
seen
:
...
...
@@ -1318,11 +1344,11 @@ def as_string(
multi
.
add
(
op2
)
else
:
seen
.
add
(
input
.
owner
)
multi
:
Se
t
=
[
x
for
x
in
multi
]
multi
_lis
t
=
[
x
for
x
in
multi
]
done
:
Set
=
set
()
def
multi_index
(
x
):
return
multi
.
index
(
x
)
+
1
return
multi
_list
.
index
(
x
)
+
1
def
describe
(
r
):
if
r
.
owner
is
not
None
and
r
not
in
i
and
r
not
in
orph
:
...
...
@@ -1337,7 +1363,7 @@ def as_string(
else
:
done
.
add
(
op
)
s
=
node_formatter
(
op
,
[
describe
(
input
)
for
input
in
op
.
inputs
])
if
op
in
multi
:
if
op
in
multi
_list
:
return
f
"*{multi_index(op)} -> {s}"
else
:
return
s
...
...
@@ -1380,14 +1406,21 @@ def list_of_nodes(
Output `Variable`\s.
"""
def
expand
(
o
:
Apply
)
->
List
[
Apply
]:
return
[
inp
.
owner
for
inp
in
o
.
inputs
if
inp
.
owner
and
not
any
(
i
in
inp
.
owner
.
outputs
for
i
in
inputs
)
]
return
list
(
walk
(
[
o
.
owner
for
o
in
outputs
],
lambda
o
:
[
inp
.
owner
for
inp
in
o
.
inputs
if
inp
.
owner
and
not
any
(
i
in
inp
.
owner
.
outputs
for
i
in
inputs
)
],
cast
(
Iterable
[
Apply
],
walk
(
[
o
.
owner
for
o
in
outputs
if
o
.
owner
],
expand
,
),
)
)
...
...
@@ -1423,7 +1456,12 @@ def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
return
False
def
equal_computations
(
xs
,
ys
,
in_xs
=
None
,
in_ys
=
None
):
def
equal_computations
(
xs
:
List
[
Union
[
np
.
ndarray
,
Variable
]],
ys
:
List
[
Union
[
np
.
ndarray
,
Variable
]],
in_xs
:
Optional
[
List
[
Variable
]]
=
None
,
in_ys
:
Optional
[
List
[
Variable
]]
=
None
,
)
->
bool
:
"""Checks if Aesara graphs represent the same computations.
The two lists `xs`, `ys` should have the same number of entries. The
...
...
@@ -1458,22 +1496,20 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
for
x
,
y
in
zip
(
xs
,
ys
):
if
not
isinstance
(
x
,
Variable
)
and
not
isinstance
(
y
,
Variable
):
return
np
.
array_equal
(
x
,
y
)
return
cast
(
bool
,
np
.
array_equal
(
x
,
y
)
)
if
not
isinstance
(
x
,
Variable
):
if
isinstance
(
y
,
Constant
):
return
np
.
array_equal
(
y
.
data
,
x
)
return
cast
(
bool
,
np
.
array_equal
(
y
.
data
,
x
)
)
return
False
if
not
isinstance
(
y
,
Variable
):
if
isinstance
(
x
,
Constant
):
return
np
.
array_equal
(
x
.
data
,
y
)
return
False
if
not
isinstance
(
y
,
Variable
):
return
cast
(
bool
,
np
.
array_equal
(
x
.
data
,
y
))
return
False
if
x
.
owner
and
not
y
.
owner
:
return
False
if
y
.
owner
and
not
x
.
owner
:
return
False
if
x
.
owner
:
# Check above tell that y.owner eval to True too.
if
x
.
owner
and
y
.
owner
:
if
x
.
owner
.
outputs
.
index
(
x
)
!=
y
.
owner
.
outputs
.
index
(
y
):
return
False
if
x
not
in
in_xs
and
not
(
y
.
type
.
in_same_class
(
x
.
type
)):
...
...
@@ -1487,8 +1523,9 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
return
False
common
=
set
(
zip
(
in_xs
,
in_ys
))
different
=
set
()
different
:
Set
[
Tuple
[
Variable
,
Variable
]]
=
set
()
for
dx
,
dy
in
zip
(
xs
,
ys
):
assert
isinstance
(
dx
,
Variable
)
# We checked above that both dx and dy have an owner or not
if
not
dx
.
owner
:
if
isinstance
(
dx
,
Constant
)
and
isinstance
(
dy
,
Constant
):
...
...
@@ -1575,10 +1612,12 @@ def equal_computations(xs, ys, in_xs=None, in_ys=None):
# Validate that each xs[i], ys[i] pair represents the same computation
for
i
in
range
(
len
(
xs
)):
if
xs
[
i
]
.
owner
:
x_i
:
Variable
=
cast
(
Variable
,
xs
[
i
])
if
x_i
.
owner
:
y_i
:
Variable
=
cast
(
Variable
,
ys
[
i
])
# The case where pairs of x[i]s and y[i]s don't both have an owner
# have already been addressed.
is_equal
=
compare_nodes
(
x
s
[
i
]
.
owner
,
ys
[
i
]
.
owner
,
common
,
different
)
is_equal
=
compare_nodes
(
x
_i
.
owner
,
y_i
.
owner
,
common
,
different
)
if
not
is_equal
:
return
False
...
...
@@ -1605,7 +1644,7 @@ def get_var_by_name(
"""
from
aesara.graph.op
import
HasInnerGraph
def
expand
(
r
):
def
expand
(
r
)
->
Optional
[
List
[
Variable
]]
:
if
r
.
owner
:
res
=
list
(
r
.
owner
.
inputs
)
...
...
aesara/graph/utils.py
浏览文件 @
f9618a62
...
...
@@ -8,7 +8,7 @@ from typing import List, Optional, Sequence, Tuple
def
simple_extract_stack
(
f
=
None
,
limit
:
Optional
[
int
]
=
None
,
skips
:
Optional
[
Sequence
[
str
]]
=
None
)
->
List
[
Tuple
[
Optional
[
str
],
int
,
str
,
str
]]:
)
->
List
[
Tuple
[
Optional
[
str
],
int
,
str
,
Optional
[
str
]
]]:
"""This is traceback.extract_stack from python 2.7 with this change:
- Comment the update of the cache.
...
...
@@ -28,14 +28,12 @@ def simple_extract_stack(
skips
=
[]
if
f
is
None
:
try
:
raise
ZeroDivisionError
except
ZeroDivisionError
:
f
=
sys
.
exc_info
()[
2
]
.
tb_frame
.
f_back
f
=
sys
.
_getframe
()
.
f_back
if
limit
is
None
:
if
hasattr
(
sys
,
"tracebacklimit"
):
limit
=
sys
.
tracebacklimit
trace
:
List
[
Tuple
[
Optional
[
str
],
int
,
str
,
str
]]
=
[]
trace
:
List
[
Tuple
[
Optional
[
str
],
int
,
str
,
Optional
[
str
]
]]
=
[]
n
=
0
while
f
is
not
None
and
(
limit
is
None
or
n
<
limit
):
lineno
=
f
.
f_lineno
...
...
@@ -43,7 +41,7 @@ def simple_extract_stack(
filename
=
co
.
co_filename
name
=
co
.
co_name
# linecache.checkcache(filename)
line
=
linecache
.
getline
(
filename
,
lineno
,
f
.
f_globals
)
line
:
Optional
[
str
]
=
linecache
.
getline
(
filename
,
lineno
,
f
.
f_globals
)
if
line
:
line
=
line
.
strip
()
else
:
...
...
setup.cfg
浏览文件 @
f9618a62
...
...
@@ -115,14 +115,6 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.utils]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.basic]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.graph.op]
ignore_errors = True
check_untyped_defs = False
...
...
@@ -179,6 +171,10 @@ check_untyped_defs = False
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.compile.builders]
ignore_errors = True
check_untyped_defs = False
[mypy-aesara.compile.sharedvalue]
ignore_errors = True
check_untyped_defs = False
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论