Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f0392db7
提交
f0392db7
authored
3月 12, 2021
作者:
LegrandNico
提交者:
Thomas Wiecki
3月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix type hint errors (272->168)
上级
f26a5086
隐藏空白字符变更
内嵌
并排
正在显示
28 个修改的文件
包含
143 行增加
和
99 行删除
+143
-99
_version.py
aesara/_version.py
+3
-3
types.py
aesara/compile/function/types.py
+2
-1
mode.py
aesara/compile/mode.py
+3
-0
ops.py
aesara/compile/ops.py
+8
-7
profiling.py
aesara/compile/profiling.py
+6
-5
configparser.py
aesara/configparser.py
+3
-2
basic_ops.py
aesara/gpuarray/basic_ops.py
+2
-1
pathparse.py
aesara/gpuarray/pathparse.py
+2
-1
basic.py
aesara/graph/basic.py
+14
-13
op.py
aesara/graph/op.py
+7
-5
optdb.py
aesara/graph/optdb.py
+8
-2
unify.py
aesara/graph/unify.py
+22
-22
utils.py
aesara/graph/utils.py
+2
-1
basic.py
aesara/link/c/basic.py
+2
-1
cmodule.py
aesara/link/c/cmodule.py
+8
-7
interface.py
aesara/link/c/interface.py
+3
-3
utils.py
aesara/link/utils.py
+5
-1
check_duplicate_key.py
aesara/misc/check_duplicate_key.py
+5
-4
multinomial.py
aesara/sandbox/multinomial.py
+2
-1
basic.py
aesara/scalar/basic.py
+4
-4
__init__.py
aesara/tensor/__init__.py
+5
-3
basic.py
aesara/tensor/basic.py
+2
-1
blas.py
aesara/tensor/blas.py
+2
-1
elemwise.py
aesara/tensor/elemwise.py
+10
-2
nlinalg.py
aesara/tensor/nlinalg.py
+2
-1
blocksparse.py
aesara/tensor/nnet/blocksparse.py
+4
-2
shape.py
aesara/tensor/shape.py
+4
-3
utils.py
aesara/utils.py
+3
-2
没有找到文件。
aesara/_version.py
浏览文件 @
f0392db7
...
...
@@ -14,7 +14,7 @@ import os
import
re
import
subprocess
import
sys
from
typing
import
Dict
def
get_keywords
():
"""Get the keywords needed to look up the version information."""
...
...
@@ -51,8 +51,8 @@ class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""
LONG_VERSION_PY
=
{}
HANDLERS
=
{}
LONG_VERSION_PY
:
Dict
=
{}
HANDLERS
:
Dict
=
{}
def
register_vcs_handler
(
vcs
,
method
):
# decorator
...
...
aesara/compile/function/types.py
浏览文件 @
f0392db7
...
...
@@ -11,6 +11,7 @@ import pickle
import
time
import
warnings
from
itertools
import
chain
from
typing
import
List
import
numpy
as
np
...
...
@@ -1877,7 +1878,7 @@ def _constructor_FunctionMaker(kwargs):
return
None
__checkers
=
[]
__checkers
:
List
=
[]
def
check_equal
(
x
,
y
):
...
...
aesara/compile/mode.py
浏览文件 @
f0392db7
...
...
@@ -5,6 +5,7 @@ WRITEME
import
logging
import
warnings
from
typing
import
Tuple
,
Union
import
aesara
from
aesara.compile.function.types
import
Supervisor
...
...
@@ -252,6 +253,8 @@ optdb.register("add_destroy_handler", AddDestroyHandler(), 49.5, "fast_run", "in
# final pass just to make sure
optdb
.
register
(
"merge3"
,
MergeOptimizer
(),
100
,
"fast_run"
,
"merge"
)
_tags
:
Union
[
Tuple
[
str
,
str
],
Tuple
]
if
config
.
check_stack_trace
in
[
"raise"
,
"warn"
,
"log"
]:
_tags
=
(
"fast_run"
,
"fast_compile"
)
...
...
aesara/compile/ops.py
浏览文件 @
f0392db7
...
...
@@ -8,6 +8,7 @@ help make new Ops more rapidly.
import
copy
import
pickle
import
warnings
from
typing
import
Dict
,
Tuple
from
aesara.graph.basic
import
Apply
from
aesara.graph.op
import
COp
,
Op
...
...
@@ -42,9 +43,9 @@ class ViewOp(COp):
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version
=
{}
__props__
=
()
_f16_ok
=
True
c_code_and_version
:
Dict
=
{}
__props__
:
Tuple
=
()
_f16_ok
:
bool
=
True
def
make_node
(
self
,
x
):
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
...
...
@@ -148,11 +149,11 @@ class DeepCopyOp(COp):
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version
=
{}
c_code_and_version
:
Dict
=
{}
check_input
=
False
__props__
=
()
_f16_ok
=
True
check_input
:
bool
=
False
__props__
:
Tuple
=
()
_f16_ok
:
bool
=
True
def
__init__
(
self
):
pass
...
...
aesara/compile/profiling.py
浏览文件 @
f0392db7
...
...
@@ -17,6 +17,7 @@ import sys
import
time
import
warnings
from
collections
import
defaultdict
from
typing
import
Dict
,
List
import
numpy
as
np
...
...
@@ -37,7 +38,7 @@ total_fct_exec_time = 0.0
total_graph_opt_time
=
0.0
total_time_linker
=
0.0
_atexit_print_list
=
[]
_atexit_print_list
:
List
=
[]
_atexit_registered
=
False
...
...
@@ -242,15 +243,15 @@ class ProfileStats:
# pretty string to print in summary, to identify this output
#
variable_shape
=
{}
variable_shape
:
Dict
=
{}
# Variable -> shapes
#
variable_strides
=
{}
variable_strides
:
Dict
=
{}
# Variable -> strides
#
variable_offset
=
{}
variable_offset
:
Dict
=
{}
# Variable -> offset
#
...
...
@@ -270,7 +271,7 @@ class ProfileStats:
linker_node_make_thunks
=
0.0
linker_make_thunk_time
=
{}
linker_make_thunk_time
:
Dict
=
{}
line_width
=
config
.
profiling__output_line_width
...
...
aesara/configparser.py
浏览文件 @
f0392db7
...
...
@@ -13,6 +13,7 @@ from configparser import (
)
from
functools
import
wraps
from
io
import
StringIO
from
typing
import
Optional
from
aesara.utils
import
deprecated
,
hash_from_code
...
...
@@ -94,7 +95,7 @@ class AesaraConfigParser:
self
.
_flags_dict
=
flags_dict
self
.
_aesara_cfg
=
aesara_cfg
self
.
_aesara_raw_cfg
=
aesara_raw_cfg
self
.
_config_var_dict
=
{}
self
.
_config_var_dict
:
typing
.
Dict
=
{}
super
()
.
__init__
()
def
__str__
(
self
,
print_doc
=
True
):
...
...
@@ -325,7 +326,7 @@ class ConfigParam:
return
self
.
_apply
(
value
)
return
value
def
validate
(
self
,
value
)
->
None
:
def
validate
(
self
,
value
)
->
Optional
[
bool
]
:
"""Validates that a parameter values falls into a supported set or range.
Raises
...
...
aesara/gpuarray/basic_ops.py
浏览文件 @
f0392db7
...
...
@@ -2,6 +2,7 @@ import copy
import
os
import
re
from
collections
import
deque
from
typing
import
Union
import
numpy
as
np
...
...
@@ -306,7 +307,7 @@ class GpuKernelBase:
"""
params_type
=
gpu_context_type
params_type
:
Union
[
ParamsType
,
gpu_context_type
]
=
gpu_context_type
def
get_params
(
self
,
node
):
# Default implementation, suitable for most sub-classes.
...
...
aesara/gpuarray/pathparse.py
浏览文件 @
f0392db7
import
os
import
sys
from
typing
import
Set
class
PathParser
:
...
...
@@ -25,7 +26,7 @@ class PathParser:
"""
paths
=
set
()
paths
:
Set
=
set
()
def
_add
(
self
,
path
):
path
=
path
.
strip
()
...
...
aesara/graph/basic.py
浏览文件 @
f0392db7
...
...
@@ -571,7 +571,7 @@ class Variable(Node):
return
d
# refer to doc in nodes_constructed.
construction_observers
=
[]
construction_observers
:
List
=
[]
@classmethod
def
append_construction_observer
(
cls
,
observer
):
...
...
@@ -700,10 +700,11 @@ def walk(
rval_set
:
Set
[
T
]
=
set
()
nodes_pop
:
Callable
[[],
T
]
if
bfs
:
nodes_pop
:
Callable
[[],
T
]
=
nodes
.
popleft
nodes_pop
=
nodes
.
popleft
else
:
nodes_pop
:
Callable
[[],
T
]
=
nodes
.
pop
nodes_pop
=
nodes
.
pop
while
nodes
:
node
:
T
=
nodes_pop
()
...
...
@@ -714,7 +715,7 @@ def walk(
rval_set
.
add
(
node_hash
)
new_nodes
:
Sequence
[
T
]
=
expand
(
node
)
new_nodes
:
Optional
[
Sequence
[
T
]
]
=
expand
(
node
)
if
return_children
:
yield
node
,
new_nodes
...
...
@@ -862,8 +863,8 @@ def applys_between(
def
clone
(
inputs
:
Collection
[
Variable
],
outputs
:
Collection
[
Variable
],
inputs
:
List
[
Variable
],
outputs
:
List
[
Variable
],
copy_inputs
:
bool
=
True
,
copy_orphans
:
Optional
[
bool
]
=
None
,
)
->
Tuple
[
Collection
[
Variable
],
Collection
[
Variable
]]:
...
...
@@ -902,8 +903,8 @@ def clone(
def
clone_get_equiv
(
inputs
:
Collection
[
Variable
],
outputs
:
Collection
[
Variable
],
inputs
:
List
[
Variable
],
outputs
:
List
[
Variable
],
copy_inputs
:
bool
=
True
,
copy_orphans
:
bool
=
True
,
memo
:
Optional
[
Dict
[
Variable
,
Variable
]]
=
None
,
...
...
@@ -1171,7 +1172,7 @@ def io_toposort(
compute_deps
=
None
compute_deps_cache
=
None
iset
=
set
(
inputs
)
deps_cache
=
{}
deps_cache
:
Dict
=
{}
if
not
orderings
:
# ordering can be None or empty dict
# Specialized function that is faster when no ordering.
...
...
@@ -1345,8 +1346,8 @@ def as_string(
orph
=
list
(
orphans_between
(
i
,
outputs
))
multi
=
set
()
seen
=
set
()
multi
:
Set
=
set
()
seen
:
Set
=
set
()
for
output
in
outputs
:
op
=
output
.
owner
if
op
in
seen
:
...
...
@@ -1362,8 +1363,8 @@ def as_string(
multi
.
add
(
op2
)
else
:
seen
.
add
(
input
.
owner
)
multi
=
[
x
for
x
in
multi
]
done
=
set
()
multi
:
Set
=
[
x
for
x
in
multi
]
done
:
Set
=
set
()
def
multi_index
(
x
):
return
multi
.
index
(
x
)
+
1
...
...
aesara/graph/op.py
浏览文件 @
f0392db7
...
...
@@ -157,7 +157,7 @@ class Op(MetaObject):
"""
default_output
=
None
default_output
:
Optional
[
int
]
=
None
"""
An `int` that specifies which output `Op.__call__` should return. If
`None`, then all outputs are returned.
...
...
@@ -852,7 +852,7 @@ def lquote_macro(txt: Text) -> Text:
return
"
\n
"
.
join
(
res
)
def
get_sub_macros
(
sub
:
Dict
[
Text
,
Text
])
->
Tuple
[
Text
]:
def
get_sub_macros
(
sub
:
Dict
[
Text
,
Text
])
->
Union
[
Tuple
[
Text
],
Tuple
[
Text
,
Text
]
]:
define_macros
=
[]
undef_macros
=
[]
define_macros
.
append
(
f
"#define FAIL {lquote_macro(sub['fail'])}"
)
...
...
@@ -864,7 +864,9 @@ def get_sub_macros(sub: Dict[Text, Text]) -> Tuple[Text]:
return
"
\n
"
.
join
(
define_macros
),
"
\n
"
.
join
(
undef_macros
)
def
get_io_macros
(
inputs
:
List
[
Text
],
outputs
:
List
[
Text
])
->
Tuple
[
List
[
Text
]]:
def
get_io_macros
(
inputs
:
List
[
Text
],
outputs
:
List
[
Text
]
)
->
Union
[
Tuple
[
List
[
Text
]],
Tuple
[
str
,
str
]]:
define_macros
=
[]
undef_macros
=
[]
...
...
@@ -1023,7 +1025,7 @@ class ExternalCOp(COp):
f
"No valid section marker was found in file {func_files[i]}"
)
def
__get_op_params
(
self
)
->
List
[
Text
]:
def
__get_op_params
(
self
)
->
Union
[
List
[
Text
],
List
[
Tuple
[
str
,
Any
]]
]:
"""Construct name, value pairs that will be turned into macros for use within the `Op`'s code.
The names must be strings that are not a C keyword and the
...
...
@@ -1130,7 +1132,7 @@ class ExternalCOp(COp):
def
get_c_macros
(
self
,
node
:
Apply
,
name
:
Text
,
check_input
:
Optional
[
bool
]
=
None
)
->
Tuple
[
Text
]:
)
->
Union
[
Tuple
[
str
],
Tuple
[
str
,
str
]
]:
"Construct a pair of C ``#define`` and ``#undef`` code strings."
define_template
=
"#define
%
s
%
s"
undef_template
=
"#undef
%
s"
...
...
aesara/graph/optdb.py
浏览文件 @
f0392db7
...
...
@@ -2,6 +2,7 @@ import copy
import
math
import
sys
from
io
import
StringIO
from
typing
import
Dict
,
Optional
from
aesara.configdefaults
import
config
from
aesara.graph
import
opt
...
...
@@ -192,6 +193,7 @@ class Query:
self
.
exclude
=
exclude
or
OrderedSet
()
self
.
subquery
=
subquery
or
{}
self
.
position_cutoff
=
position_cutoff
self
.
name
:
Optional
[
str
]
=
None
if
extra_optimizations
is
None
:
extra_optimizations
=
[]
self
.
extra_optimizations
=
extra_optimizations
...
...
@@ -438,14 +440,18 @@ class LocalGroupDB(DB):
"""
def
__init__
(
self
,
apply_all_opts
=
False
,
profile
=
False
,
local_opt
=
opt
.
LocalOptGroup
self
,
apply_all_opts
:
bool
=
False
,
profile
:
bool
=
False
,
local_opt
=
opt
.
LocalOptGroup
,
):
super
()
.
__init__
()
self
.
failure_callback
=
None
self
.
apply_all_opts
=
apply_all_opts
self
.
profile
=
profile
self
.
__position__
=
{}
self
.
__position__
:
Dict
=
{}
self
.
local_opt
=
local_opt
self
.
__name__
:
str
=
""
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
super
()
.
register
(
name
,
obj
,
*
tags
)
...
...
aesara/graph/unify.py
浏览文件 @
f0392db7
...
...
@@ -331,7 +331,7 @@ def unify_walk(a, b, U):
return
False
@comm_guard
(
FreeVariable
,
ANY_TYPE
)
# noqa
@comm_guard
(
FreeVariable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
fv
,
o
,
U
):
"""
FreeV is unified to BoundVariable(other_object).
...
...
@@ -341,7 +341,7 @@ def unify_walk(fv, o, U):
return
U
.
merge
(
v
,
fv
)
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
# noqa
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
bv
,
o
,
U
):
"""
The unification succeed iff BV.value == other_object.
...
...
@@ -353,7 +353,7 @@ def unify_walk(bv, o, U):
return
False
@comm_guard
(
OrVariable
,
ANY_TYPE
)
# noqa
@comm_guard
(
OrVariable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
ov
,
o
,
U
):
"""
The unification succeeds iff other_object in OrV.options.
...
...
@@ -366,7 +366,7 @@ def unify_walk(ov, o, U):
return
False
@comm_guard
(
NotVariable
,
ANY_TYPE
)
# noqa
@comm_guard
(
NotVariable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
nv
,
o
,
U
):
"""
The unification succeeds iff other_object not in NV.not_options.
...
...
@@ -379,7 +379,7 @@ def unify_walk(nv, o, U):
return
U
.
merge
(
v
,
nv
)
@comm_guard
(
FreeVariable
,
Variable
)
# noqa
@comm_guard
(
FreeVariable
,
Variable
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
fv
,
v
,
U
):
"""
Both variables are unified.
...
...
@@ -389,7 +389,7 @@ def unify_walk(fv, v, U):
return
U
.
merge
(
v
,
fv
)
@comm_guard
(
BoundVariable
,
Variable
)
# noqa
@comm_guard
(
BoundVariable
,
Variable
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
bv
,
v
,
U
):
"""
V is unified to BV.value.
...
...
@@ -398,7 +398,7 @@ def unify_walk(bv, v, U):
return
unify_walk
(
v
,
bv
.
value
,
U
)
@comm_guard
(
OrVariable
,
OrVariable
)
# noqa
@comm_guard
(
OrVariable
,
OrVariable
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
a
,
b
,
U
):
"""
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
...
...
@@ -414,7 +414,7 @@ def unify_walk(a, b, U):
return
U
.
merge
(
v
,
a
,
b
)
@comm_guard
(
NotVariable
,
NotVariable
)
# noqa
@comm_guard
(
NotVariable
,
NotVariable
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
a
,
b
,
U
):
"""
NV(list1) == NV(list2) == NV(union(list1, list2))
...
...
@@ -425,7 +425,7 @@ def unify_walk(a, b, U):
return
U
.
merge
(
v
,
a
,
b
)
@comm_guard
(
OrVariable
,
NotVariable
)
# noqa
@comm_guard
(
OrVariable
,
NotVariable
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
o
,
n
,
U
):
r"""
OrV(list1) == NV(list2) == OrV(list1 \ list2)
...
...
@@ -441,7 +441,7 @@ def unify_walk(o, n, U):
return
U
.
merge
(
v
,
o
,
n
)
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
# noqa
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
vil
,
l
,
U
):
"""
Unifies VIL's inner Variable to OrV(list).
...
...
@@ -452,7 +452,7 @@ def unify_walk(vil, l, U):
return
unify_walk
(
v
,
ov
,
U
)
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
# noqa
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
l1
,
l2
,
U
):
"""
Tries to unify each corresponding pair of elements from l1 and l2.
...
...
@@ -467,7 +467,7 @@ def unify_walk(l1, l2, U):
return
U
@comm_guard
(
dict
,
dict
)
# noqa
@comm_guard
(
dict
,
dict
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
d1
,
d2
,
U
):
"""
Tries to unify values of corresponding keys.
...
...
@@ -481,7 +481,7 @@ def unify_walk(d1, d2, U):
return
U
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
# noqa
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
a
,
b
,
U
):
"""
Checks for the existence of the __unify_walk__ method for one of
...
...
@@ -498,7 +498,7 @@ def unify_walk(a, b, U):
return
FALL_THROUGH
@comm_guard
(
Variable
,
ANY_TYPE
)
# noqa
@comm_guard
(
Variable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
v
,
o
,
U
):
"""
This simply checks if the Var has an unification in U and uses it
...
...
@@ -528,27 +528,27 @@ def unify_merge(a, b, U):
return
a
@comm_guard
(
Variable
,
ANY_TYPE
)
# noqa
@comm_guard
(
Variable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
v
,
o
,
U
):
return
v
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
# noqa
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
bv
,
o
,
U
):
return
bv
.
value
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
# noqa
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
vil
,
l
,
U
):
return
[
unify_merge
(
x
,
x
,
U
)
for
x
in
l
]
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
# noqa
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
l1
,
l2
,
U
):
return
[
unify_merge
(
x1
,
x2
,
U
)
for
x1
,
x2
in
zip
(
l1
,
l2
)]
@comm_guard
(
dict
,
dict
)
# noqa
@comm_guard
(
dict
,
dict
)
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
d1
,
d2
,
U
):
d
=
d1
.
__class__
()
for
k1
,
v1
in
d1
.
items
():
...
...
@@ -562,12 +562,12 @@ def unify_merge(d1, d2, U):
return
d
@comm_guard
(
FVar
,
ANY_TYPE
)
# noqa
@comm_guard
(
FVar
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
vs
,
o
,
U
):
return
vs
(
U
)
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
# noqa
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
a
,
b
,
U
):
if
(
not
isinstance
(
a
,
Variable
)
...
...
@@ -579,7 +579,7 @@ def unify_merge(a, b, U):
return
FALL_THROUGH
@comm_guard
(
Variable
,
ANY_TYPE
)
# noqa
@comm_guard
(
Variable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
v
,
o
,
U
):
"""
This simply checks if the Var has an unification in U and uses it
...
...
aesara/graph/utils.py
浏览文件 @
f0392db7
...
...
@@ -3,6 +3,7 @@ import sys
import
traceback
from
abc
import
ABCMeta
from
io
import
StringIO
from
typing
import
List
def
simple_extract_stack
(
f
=
None
,
limit
=
None
,
skips
=
None
):
...
...
@@ -225,7 +226,7 @@ class MetaType(ABCMeta):
class
MetaObject
(
metaclass
=
MetaType
):
__slots__
=
[]
__slots__
:
List
=
[]
def
__ne__
(
self
,
other
):
return
not
self
==
other
...
...
aesara/link/c/basic.py
浏览文件 @
f0392db7
...
...
@@ -8,6 +8,7 @@ import sys
from
collections
import
defaultdict
from
copy
import
copy
from
io
import
StringIO
from
typing
import
Dict
import
numpy
as
np
...
...
@@ -1795,7 +1796,7 @@ class OpWiseCLinker(LocalLinker):
"""
__cache__
=
{}
__cache__
:
Dict
=
{}
def
__init__
(
self
,
fallback_on_perform
=
True
,
allow_gc
=
None
,
nice_errors
=
True
,
schedule
=
None
...
...
aesara/link/c/cmodule.py
浏览文件 @
f0392db7
...
...
@@ -19,6 +19,7 @@ import textwrap
import
time
import
warnings
from
io
import
BytesIO
,
StringIO
from
typing
import
Dict
,
List
,
Set
import
numpy.distutils
...
...
@@ -631,38 +632,38 @@ class ModuleCache:
"""
dirname
=
""
dirname
:
str
=
""
"""
The working directory that is managed by this interface.
"""
module_from_name
=
{}
module_from_name
:
Dict
=
{}
"""
Maps a module filename to the loaded module object.
"""
entry_from_key
=
{}
entry_from_key
:
Dict
=
{}
"""
Maps keys to the filename of a .so/.pyd.
"""
similar_keys
=
{}
similar_keys
:
Dict
=
{}
"""
Maps a part-of-key to all keys that share this same part.
"""
module_hash_to_key_data
=
{}
module_hash_to_key_data
:
Dict
=
{}
"""
Maps a module hash to its corresponding KeyData object.
"""
stats
=
[]
stats
:
List
=
[]
"""
A list with counters for the number of hits, loads, compiles issued by
module_from_key().
"""
loaded_key_pkl
=
set
()
loaded_key_pkl
:
Set
=
set
()
"""
Set of all key.pkl files that have been loaded.
...
...
aesara/link/c/interface.py
浏览文件 @
f0392db7
from
abc
import
abstractmethod
from
typing
import
Callable
,
Dict
,
List
,
Text
,
Tuple
from
typing
import
Callable
,
Dict
,
List
,
Text
,
Tuple
,
Union
from
aesara.graph.basic
import
Apply
,
Constant
from
aesara.graph.utils
import
MethodNotDefined
...
...
@@ -129,7 +129,7 @@ class CLinkerObject:
"""Return a list of code snippets to be inserted in module initialization."""
return
[]
def
c_code_cache_version
(
self
)
->
Tuple
[
int
]:
def
c_code_cache_version
(
self
)
->
Union
[
Tuple
[
int
],
Tuple
]:
"""Return a tuple of integers indicating the version of this `Op`.
An empty tuple indicates an 'unversioned' `Op` that will not be cached
...
...
@@ -551,7 +551,7 @@ class CLinkerType(CLinkerObject):
"""
return
""
def
c_code_cache_version
(
self
)
->
Tuple
[
int
]:
def
c_code_cache_version
(
self
)
->
Union
[
Tuple
,
Tuple
[
int
]
]:
"""Return a tuple of integers indicating the version of this type.
An empty tuple indicates an 'unversioned' type that will not
...
...
aesara/link/utils.py
浏览文件 @
f0392db7
...
...
@@ -3,7 +3,7 @@ import sys
import
traceback
import
warnings
from
operator
import
itemgetter
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
NoReturn
,
Optional
,
Tuple
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
NoReturn
,
Optional
,
Tuple
,
Union
import
numpy
as
np
...
...
@@ -315,6 +315,10 @@ def raise_with_op(
types
=
[
getattr
(
ipt
,
"type"
,
"No type"
)
for
ipt
in
node
.
inputs
]
detailed_err_msg
+=
f
"
\n
Inputs types: {types}
\n
"
shapes
:
Union
[
List
,
str
]
strides
:
Union
[
List
,
str
]
scalar_values
:
Union
[
List
,
str
]
if
thunk
is
not
None
:
if
hasattr
(
thunk
,
"inputs"
):
shapes
=
[
getattr
(
ipt
[
0
],
"shape"
,
"No shapes"
)
for
ipt
in
thunk
.
inputs
]
...
...
aesara/misc/check_duplicate_key.py
浏览文件 @
f0392db7
import
os
import
pickle
import
sys
from
typing
import
Dict
from
aesara.configdefaults
import
config
...
...
@@ -15,8 +16,8 @@ if len(sys.argv) > 1:
else
:
dirs
=
os
.
listdir
(
config
.
compiledir
)
dirs
=
[
os
.
path
.
join
(
config
.
compiledir
,
d
)
for
d
in
dirs
]
keys
=
{}
# key -> nb seen
mods
=
{}
keys
:
Dict
=
{}
# key -> nb seen
mods
:
Dict
=
{}
for
dir
in
dirs
:
key
=
None
...
...
@@ -48,12 +49,12 @@ if DISPLAY_DUPLICATE_KEYS:
if
v
>
1
:
print
(
"Duplicate key (
%
i copies):
%
s"
%
(
v
,
pickle
.
loads
(
k
)))
nbs_keys
=
{}
# nb seen -> now many key
nbs_keys
:
Dict
=
{}
# nb seen -> now many key
for
val
in
keys
.
values
():
nbs_keys
.
setdefault
(
val
,
0
)
nbs_keys
[
val
]
+=
1
nbs_mod
=
{}
# nb seen -> how many key
nbs_mod
:
Dict
=
{}
# nb seen -> how many key
nbs_mod_to_key
=
{}
# nb seen -> keys
more_than_one
=
0
for
mod
,
kk
in
mods
.
items
():
...
...
aesara/sandbox/multinomial.py
浏览文件 @
f0392db7
import
copy
import
warnings
from
typing
import
Tuple
,
Union
import
numpy
as
np
...
...
@@ -18,7 +19,7 @@ class MultinomialFromUniform(COp):
TODO : need description for parameter 'odtype'
"""
__props__
=
(
"odtype"
,)
__props__
:
Union
[
Tuple
[
str
],
Tuple
[
str
,
str
]]
=
(
"odtype"
,)
def
__init__
(
self
,
odtype
):
self
.
odtype
=
odtype
...
...
aesara/scalar/basic.py
浏览文件 @
f0392db7
...
...
@@ -15,7 +15,7 @@ from collections.abc import Callable
from
copy
import
copy
from
itertools
import
chain
from
textwrap
import
dedent
from
typing
import
Optional
from
typing
import
Dict
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
import
numpy
as
np
...
...
@@ -898,7 +898,7 @@ def upgrade_to_float(*types):
Upgrade any int types to float32 or float64 to avoid losing precision.
"""
conv
=
{
conv
:
Mapping
[
Type
,
Type
]
=
{
bool
:
float32
,
int8
:
float32
,
int16
:
float32
,
...
...
@@ -3954,7 +3954,7 @@ class Composite(ScalarOp):
"""
init_param
=
(
"inputs"
,
"outputs"
)
init_param
:
Union
[
Tuple
[
str
,
str
],
Tuple
[
str
]]
=
(
"inputs"
,
"outputs"
)
def
__str__
(
self
):
if
self
.
name
is
None
:
...
...
@@ -4331,7 +4331,7 @@ class Composite(ScalarOp):
class
Compositef32
:
# This is a dict of scalar op classes that need special handling
special
=
{}
special
:
Dict
=
{}
def
apply
(
self
,
fgraph
):
mapping
=
{}
...
...
aesara/tensor/__init__.py
浏览文件 @
f0392db7
...
...
@@ -5,12 +5,12 @@ __docformat__ = "restructuredtext en"
import
warnings
from
functools
import
singledispatch
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
NoReturn
,
Optional
def
as_tensor_variable
(
x
,
name
:
Optional
[
str
]
=
None
,
ndim
:
Optional
[
int
]
=
None
,
**
kwargs
):
)
->
Callable
:
"""Convert `x` into the appropriate `TensorType`.
This function is often used by `make_node` methods of `Op` subclasses to
...
...
@@ -39,7 +39,9 @@ def as_tensor_variable(
@singledispatch
def
_as_tensor_variable
(
x
,
name
:
Optional
[
str
],
ndim
:
Optional
[
int
],
**
kwargs
):
def
_as_tensor_variable
(
x
,
name
:
Optional
[
str
],
ndim
:
Optional
[
int
],
**
kwargs
)
->
NoReturn
:
raise
NotImplementedError
(
""
)
...
...
aesara/tensor/basic.py
浏览文件 @
f0392db7
...
...
@@ -11,6 +11,7 @@ import warnings
from
collections
import
OrderedDict
from
collections.abc
import
Sequence
from
numbers
import
Number
from
typing
import
Dict
import
numpy
as
np
...
...
@@ -679,7 +680,7 @@ class Rebroadcast(COp):
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version
=
{}
c_code_and_version
:
Dict
=
{}
check_input
=
False
__props__
=
(
"axis"
,)
...
...
aesara/tensor/blas.py
浏览文件 @
f0392db7
...
...
@@ -140,6 +140,7 @@ except ImportError:
pass
from
functools
import
reduce
from
typing
import
Tuple
,
Union
import
aesara.scalar
from
aesara.compile.mode
import
optdb
...
...
@@ -506,7 +507,7 @@ class GemmRelated(COp):
"""
__props__
=
()
__props__
:
Union
[
Tuple
,
Tuple
[
str
]]
=
()
def
c_support_code
(
self
,
**
kwargs
):
# return cblas_header_text()
...
...
aesara/tensor/elemwise.py
浏览文件 @
f0392db7
from
copy
import
copy
from
typing
import
Tuple
,
Union
import
numpy
as
np
...
...
@@ -1297,7 +1298,9 @@ class CAReduce(COp):
"""
__props__
=
(
"scalar_op"
,
"axis"
)
__props__
:
Union
[
Tuple
[
str
],
Tuple
[
str
,
str
],
Tuple
[
str
,
str
,
str
],
Tuple
[
str
,
str
,
str
,
str
]
]
=
(
"scalar_op"
,
"axis"
)
def
__init__
(
self
,
scalar_op
,
axis
=
None
):
if
scalar_op
.
nin
not
in
[
-
1
,
2
]
or
scalar_op
.
nout
!=
1
:
...
...
@@ -1682,7 +1685,12 @@ class CAReduceDtype(CAReduce):
"""
__props__
=
(
"scalar_op"
,
"axis"
,
"dtype"
,
"acc_dtype"
)
__props__
:
Union
[
Tuple
[
str
,
str
,
str
],
Tuple
[
str
,
str
,
str
,
str
]]
=
(
"scalar_op"
,
"axis"
,
"dtype"
,
"acc_dtype"
,
)
def
__init__
(
self
,
scalar_op
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
):
super
()
.
__init__
(
scalar_op
,
axis
=
axis
)
...
...
aesara/tensor/nlinalg.py
浏览文件 @
f0392db7
import
logging
from
functools
import
partial
from
typing
import
Tuple
,
Union
import
numpy
as
np
...
...
@@ -224,7 +225,7 @@ class Eig(Op):
"""
_numop
=
staticmethod
(
np
.
linalg
.
eig
)
__props__
=
()
__props__
:
Union
[
Tuple
,
Tuple
[
str
]]
=
()
def
make_node
(
self
,
x
):
x
=
as_tensor_variable
(
x
)
...
...
aesara/tensor/nnet/blocksparse.py
浏览文件 @
f0392db7
from
typing
import
List
import
numpy
as
np
import
aesara
...
...
@@ -26,7 +28,7 @@ class SparseBlockGemv(Op):
__props__
=
(
"inplace"
,)
registered_opts
=
[]
registered_opts
:
List
=
[]
def
__init__
(
self
,
inplace
=
False
):
self
.
inplace
=
inplace
...
...
@@ -147,7 +149,7 @@ class SparseBlockOuter(Op):
__props__
=
(
"inplace"
,)
registered_opts
=
[]
registered_opts
:
List
=
[]
def
__init__
(
self
,
inplace
=
False
):
self
.
inplace
=
inplace
...
...
aesara/tensor/shape.py
浏览文件 @
f0392db7
import
warnings
from
typing
import
Dict
import
numpy
as
np
...
...
@@ -49,7 +50,7 @@ class Shape(COp):
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version
=
{}
c_code_and_version
:
Dict
=
{}
check_input
=
False
__props__
=
()
...
...
@@ -143,7 +144,7 @@ class Shape_i(COp):
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version
=
{}
c_code_and_version
:
Dict
=
{}
check_input
=
False
...
...
@@ -363,7 +364,7 @@ class SpecifyShape(COp):
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version
=
{}
c_code_and_version
:
Dict
=
{}
__props__
=
()
_f16_ok
=
True
...
...
aesara/utils.py
浏览文件 @
f0392db7
...
...
@@ -12,6 +12,7 @@ import warnings
from
collections
import
OrderedDict
from
collections.abc
import
Callable
from
functools
import
partial
,
wraps
from
typing
import
List
,
Set
__all__
=
[
...
...
@@ -29,7 +30,7 @@ __all__ = [
]
__excepthooks
=
[]
__excepthooks
:
List
=
[]
LOCAL_BITWIDTH
=
struct
.
calcsize
(
"P"
)
*
8
...
...
@@ -374,7 +375,7 @@ def apply_across_args(*fns):
class
NoDuplicateOptWarningFilter
(
logging
.
Filter
):
"""Filter to avoid duplicating optimization warnings."""
prev_msgs
=
set
()
prev_msgs
:
Set
=
set
()
def
filter
(
self
,
record
):
msg
=
record
.
getMessage
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论