Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f0392db7
提交
f0392db7
authored
3月 12, 2021
作者:
LegrandNico
提交者:
Thomas Wiecki
3月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix type hint errors (272->168)
上级
f26a5086
显示空白字符变更
内嵌
并排
正在显示
28 个修改的文件
包含
143 行增加
和
99 行删除
+143
-99
_version.py
aesara/_version.py
+3
-3
types.py
aesara/compile/function/types.py
+2
-1
mode.py
aesara/compile/mode.py
+3
-0
ops.py
aesara/compile/ops.py
+8
-7
profiling.py
aesara/compile/profiling.py
+6
-5
configparser.py
aesara/configparser.py
+3
-2
basic_ops.py
aesara/gpuarray/basic_ops.py
+2
-1
pathparse.py
aesara/gpuarray/pathparse.py
+2
-1
basic.py
aesara/graph/basic.py
+14
-13
op.py
aesara/graph/op.py
+7
-5
optdb.py
aesara/graph/optdb.py
+8
-2
unify.py
aesara/graph/unify.py
+22
-22
utils.py
aesara/graph/utils.py
+2
-1
basic.py
aesara/link/c/basic.py
+2
-1
cmodule.py
aesara/link/c/cmodule.py
+8
-7
interface.py
aesara/link/c/interface.py
+3
-3
utils.py
aesara/link/utils.py
+5
-1
check_duplicate_key.py
aesara/misc/check_duplicate_key.py
+5
-4
multinomial.py
aesara/sandbox/multinomial.py
+2
-1
basic.py
aesara/scalar/basic.py
+4
-4
__init__.py
aesara/tensor/__init__.py
+5
-3
basic.py
aesara/tensor/basic.py
+2
-1
blas.py
aesara/tensor/blas.py
+2
-1
elemwise.py
aesara/tensor/elemwise.py
+10
-2
nlinalg.py
aesara/tensor/nlinalg.py
+2
-1
blocksparse.py
aesara/tensor/nnet/blocksparse.py
+4
-2
shape.py
aesara/tensor/shape.py
+4
-3
utils.py
aesara/utils.py
+3
-2
没有找到文件。
aesara/_version.py
浏览文件 @
f0392db7
...
@@ -14,7 +14,7 @@ import os
...
@@ -14,7 +14,7 @@ import os
import
re
import
re
import
subprocess
import
subprocess
import
sys
import
sys
from
typing
import
Dict
def
get_keywords
():
def
get_keywords
():
"""Get the keywords needed to look up the version information."""
"""Get the keywords needed to look up the version information."""
...
@@ -51,8 +51,8 @@ class NotThisMethod(Exception):
...
@@ -51,8 +51,8 @@ class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""
"""Exception raised if a method is not valid for the current scenario."""
LONG_VERSION_PY
=
{}
LONG_VERSION_PY
:
Dict
=
{}
HANDLERS
=
{}
HANDLERS
:
Dict
=
{}
def
register_vcs_handler
(
vcs
,
method
):
# decorator
def
register_vcs_handler
(
vcs
,
method
):
# decorator
...
...
aesara/compile/function/types.py
浏览文件 @
f0392db7
...
@@ -11,6 +11,7 @@ import pickle
...
@@ -11,6 +11,7 @@ import pickle
import
time
import
time
import
warnings
import
warnings
from
itertools
import
chain
from
itertools
import
chain
from
typing
import
List
import
numpy
as
np
import
numpy
as
np
...
@@ -1877,7 +1878,7 @@ def _constructor_FunctionMaker(kwargs):
...
@@ -1877,7 +1878,7 @@ def _constructor_FunctionMaker(kwargs):
return
None
return
None
__checkers
=
[]
__checkers
:
List
=
[]
def
check_equal
(
x
,
y
):
def
check_equal
(
x
,
y
):
...
...
aesara/compile/mode.py
浏览文件 @
f0392db7
...
@@ -5,6 +5,7 @@ WRITEME
...
@@ -5,6 +5,7 @@ WRITEME
import
logging
import
logging
import
warnings
import
warnings
from
typing
import
Tuple
,
Union
import
aesara
import
aesara
from
aesara.compile.function.types
import
Supervisor
from
aesara.compile.function.types
import
Supervisor
...
@@ -252,6 +253,8 @@ optdb.register("add_destroy_handler", AddDestroyHandler(), 49.5, "fast_run", "in
...
@@ -252,6 +253,8 @@ optdb.register("add_destroy_handler", AddDestroyHandler(), 49.5, "fast_run", "in
# final pass just to make sure
# final pass just to make sure
optdb
.
register
(
"merge3"
,
MergeOptimizer
(),
100
,
"fast_run"
,
"merge"
)
optdb
.
register
(
"merge3"
,
MergeOptimizer
(),
100
,
"fast_run"
,
"merge"
)
_tags
:
Union
[
Tuple
[
str
,
str
],
Tuple
]
if
config
.
check_stack_trace
in
[
"raise"
,
"warn"
,
"log"
]:
if
config
.
check_stack_trace
in
[
"raise"
,
"warn"
,
"log"
]:
_tags
=
(
"fast_run"
,
"fast_compile"
)
_tags
=
(
"fast_run"
,
"fast_compile"
)
...
...
aesara/compile/ops.py
浏览文件 @
f0392db7
...
@@ -8,6 +8,7 @@ help make new Ops more rapidly.
...
@@ -8,6 +8,7 @@ help make new Ops more rapidly.
import
copy
import
copy
import
pickle
import
pickle
import
warnings
import
warnings
from
typing
import
Dict
,
Tuple
from
aesara.graph.basic
import
Apply
from
aesara.graph.basic
import
Apply
from
aesara.graph.op
import
COp
,
Op
from
aesara.graph.op
import
COp
,
Op
...
@@ -42,9 +43,9 @@ class ViewOp(COp):
...
@@ -42,9 +43,9 @@ class ViewOp(COp):
# Mapping from Type to C code (and version) to use.
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
# the output variable is %(oname)s.
c_code_and_version
=
{}
c_code_and_version
:
Dict
=
{}
__props__
=
()
__props__
:
Tuple
=
()
_f16_ok
=
True
_f16_ok
:
bool
=
True
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
...
@@ -148,11 +149,11 @@ class DeepCopyOp(COp):
...
@@ -148,11 +149,11 @@ class DeepCopyOp(COp):
# Mapping from Type to C code (and version) to use.
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
# the output variable is %(oname)s.
c_code_and_version
=
{}
c_code_and_version
:
Dict
=
{}
check_input
=
False
check_input
:
bool
=
False
__props__
=
()
__props__
:
Tuple
=
()
_f16_ok
=
True
_f16_ok
:
bool
=
True
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
...
...
aesara/compile/profiling.py
浏览文件 @
f0392db7
...
@@ -17,6 +17,7 @@ import sys
...
@@ -17,6 +17,7 @@ import sys
import
time
import
time
import
warnings
import
warnings
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Dict
,
List
import
numpy
as
np
import
numpy
as
np
...
@@ -37,7 +38,7 @@ total_fct_exec_time = 0.0
...
@@ -37,7 +38,7 @@ total_fct_exec_time = 0.0
total_graph_opt_time
=
0.0
total_graph_opt_time
=
0.0
total_time_linker
=
0.0
total_time_linker
=
0.0
_atexit_print_list
=
[]
_atexit_print_list
:
List
=
[]
_atexit_registered
=
False
_atexit_registered
=
False
...
@@ -242,15 +243,15 @@ class ProfileStats:
...
@@ -242,15 +243,15 @@ class ProfileStats:
# pretty string to print in summary, to identify this output
# pretty string to print in summary, to identify this output
#
#
variable_shape
=
{}
variable_shape
:
Dict
=
{}
# Variable -> shapes
# Variable -> shapes
#
#
variable_strides
=
{}
variable_strides
:
Dict
=
{}
# Variable -> strides
# Variable -> strides
#
#
variable_offset
=
{}
variable_offset
:
Dict
=
{}
# Variable -> offset
# Variable -> offset
#
#
...
@@ -270,7 +271,7 @@ class ProfileStats:
...
@@ -270,7 +271,7 @@ class ProfileStats:
linker_node_make_thunks
=
0.0
linker_node_make_thunks
=
0.0
linker_make_thunk_time
=
{}
linker_make_thunk_time
:
Dict
=
{}
line_width
=
config
.
profiling__output_line_width
line_width
=
config
.
profiling__output_line_width
...
...
aesara/configparser.py
浏览文件 @
f0392db7
...
@@ -13,6 +13,7 @@ from configparser import (
...
@@ -13,6 +13,7 @@ from configparser import (
)
)
from
functools
import
wraps
from
functools
import
wraps
from
io
import
StringIO
from
io
import
StringIO
from
typing
import
Optional
from
aesara.utils
import
deprecated
,
hash_from_code
from
aesara.utils
import
deprecated
,
hash_from_code
...
@@ -94,7 +95,7 @@ class AesaraConfigParser:
...
@@ -94,7 +95,7 @@ class AesaraConfigParser:
self
.
_flags_dict
=
flags_dict
self
.
_flags_dict
=
flags_dict
self
.
_aesara_cfg
=
aesara_cfg
self
.
_aesara_cfg
=
aesara_cfg
self
.
_aesara_raw_cfg
=
aesara_raw_cfg
self
.
_aesara_raw_cfg
=
aesara_raw_cfg
self
.
_config_var_dict
=
{}
self
.
_config_var_dict
:
typing
.
Dict
=
{}
super
()
.
__init__
()
super
()
.
__init__
()
def
__str__
(
self
,
print_doc
=
True
):
def
__str__
(
self
,
print_doc
=
True
):
...
@@ -325,7 +326,7 @@ class ConfigParam:
...
@@ -325,7 +326,7 @@ class ConfigParam:
return
self
.
_apply
(
value
)
return
self
.
_apply
(
value
)
return
value
return
value
def
validate
(
self
,
value
)
->
None
:
def
validate
(
self
,
value
)
->
Optional
[
bool
]
:
"""Validates that a parameter values falls into a supported set or range.
"""Validates that a parameter values falls into a supported set or range.
Raises
Raises
...
...
aesara/gpuarray/basic_ops.py
浏览文件 @
f0392db7
...
@@ -2,6 +2,7 @@ import copy
...
@@ -2,6 +2,7 @@ import copy
import
os
import
os
import
re
import
re
from
collections
import
deque
from
collections
import
deque
from
typing
import
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -306,7 +307,7 @@ class GpuKernelBase:
...
@@ -306,7 +307,7 @@ class GpuKernelBase:
"""
"""
params_type
=
gpu_context_type
params_type
:
Union
[
ParamsType
,
gpu_context_type
]
=
gpu_context_type
def
get_params
(
self
,
node
):
def
get_params
(
self
,
node
):
# Default implementation, suitable for most sub-classes.
# Default implementation, suitable for most sub-classes.
...
...
aesara/gpuarray/pathparse.py
浏览文件 @
f0392db7
import
os
import
os
import
sys
import
sys
from
typing
import
Set
class
PathParser
:
class
PathParser
:
...
@@ -25,7 +26,7 @@ class PathParser:
...
@@ -25,7 +26,7 @@ class PathParser:
"""
"""
paths
=
set
()
paths
:
Set
=
set
()
def
_add
(
self
,
path
):
def
_add
(
self
,
path
):
path
=
path
.
strip
()
path
=
path
.
strip
()
...
...
aesara/graph/basic.py
浏览文件 @
f0392db7
...
@@ -571,7 +571,7 @@ class Variable(Node):
...
@@ -571,7 +571,7 @@ class Variable(Node):
return
d
return
d
# refer to doc in nodes_constructed.
# refer to doc in nodes_constructed.
construction_observers
=
[]
construction_observers
:
List
=
[]
@classmethod
@classmethod
def
append_construction_observer
(
cls
,
observer
):
def
append_construction_observer
(
cls
,
observer
):
...
@@ -700,10 +700,11 @@ def walk(
...
@@ -700,10 +700,11 @@ def walk(
rval_set
:
Set
[
T
]
=
set
()
rval_set
:
Set
[
T
]
=
set
()
nodes_pop
:
Callable
[[],
T
]
if
bfs
:
if
bfs
:
nodes_pop
:
Callable
[[],
T
]
=
nodes
.
popleft
nodes_pop
=
nodes
.
popleft
else
:
else
:
nodes_pop
:
Callable
[[],
T
]
=
nodes
.
pop
nodes_pop
=
nodes
.
pop
while
nodes
:
while
nodes
:
node
:
T
=
nodes_pop
()
node
:
T
=
nodes_pop
()
...
@@ -714,7 +715,7 @@ def walk(
...
@@ -714,7 +715,7 @@ def walk(
rval_set
.
add
(
node_hash
)
rval_set
.
add
(
node_hash
)
new_nodes
:
Sequence
[
T
]
=
expand
(
node
)
new_nodes
:
Optional
[
Sequence
[
T
]
]
=
expand
(
node
)
if
return_children
:
if
return_children
:
yield
node
,
new_nodes
yield
node
,
new_nodes
...
@@ -862,8 +863,8 @@ def applys_between(
...
@@ -862,8 +863,8 @@ def applys_between(
def
clone
(
def
clone
(
inputs
:
Collection
[
Variable
],
inputs
:
List
[
Variable
],
outputs
:
Collection
[
Variable
],
outputs
:
List
[
Variable
],
copy_inputs
:
bool
=
True
,
copy_inputs
:
bool
=
True
,
copy_orphans
:
Optional
[
bool
]
=
None
,
copy_orphans
:
Optional
[
bool
]
=
None
,
)
->
Tuple
[
Collection
[
Variable
],
Collection
[
Variable
]]:
)
->
Tuple
[
Collection
[
Variable
],
Collection
[
Variable
]]:
...
@@ -902,8 +903,8 @@ def clone(
...
@@ -902,8 +903,8 @@ def clone(
def
clone_get_equiv
(
def
clone_get_equiv
(
inputs
:
Collection
[
Variable
],
inputs
:
List
[
Variable
],
outputs
:
Collection
[
Variable
],
outputs
:
List
[
Variable
],
copy_inputs
:
bool
=
True
,
copy_inputs
:
bool
=
True
,
copy_orphans
:
bool
=
True
,
copy_orphans
:
bool
=
True
,
memo
:
Optional
[
Dict
[
Variable
,
Variable
]]
=
None
,
memo
:
Optional
[
Dict
[
Variable
,
Variable
]]
=
None
,
...
@@ -1171,7 +1172,7 @@ def io_toposort(
...
@@ -1171,7 +1172,7 @@ def io_toposort(
compute_deps
=
None
compute_deps
=
None
compute_deps_cache
=
None
compute_deps_cache
=
None
iset
=
set
(
inputs
)
iset
=
set
(
inputs
)
deps_cache
=
{}
deps_cache
:
Dict
=
{}
if
not
orderings
:
# ordering can be None or empty dict
if
not
orderings
:
# ordering can be None or empty dict
# Specialized function that is faster when no ordering.
# Specialized function that is faster when no ordering.
...
@@ -1345,8 +1346,8 @@ def as_string(
...
@@ -1345,8 +1346,8 @@ def as_string(
orph
=
list
(
orphans_between
(
i
,
outputs
))
orph
=
list
(
orphans_between
(
i
,
outputs
))
multi
=
set
()
multi
:
Set
=
set
()
seen
=
set
()
seen
:
Set
=
set
()
for
output
in
outputs
:
for
output
in
outputs
:
op
=
output
.
owner
op
=
output
.
owner
if
op
in
seen
:
if
op
in
seen
:
...
@@ -1362,8 +1363,8 @@ def as_string(
...
@@ -1362,8 +1363,8 @@ def as_string(
multi
.
add
(
op2
)
multi
.
add
(
op2
)
else
:
else
:
seen
.
add
(
input
.
owner
)
seen
.
add
(
input
.
owner
)
multi
=
[
x
for
x
in
multi
]
multi
:
Set
=
[
x
for
x
in
multi
]
done
=
set
()
done
:
Set
=
set
()
def
multi_index
(
x
):
def
multi_index
(
x
):
return
multi
.
index
(
x
)
+
1
return
multi
.
index
(
x
)
+
1
...
...
aesara/graph/op.py
浏览文件 @
f0392db7
...
@@ -157,7 +157,7 @@ class Op(MetaObject):
...
@@ -157,7 +157,7 @@ class Op(MetaObject):
"""
"""
default_output
=
None
default_output
:
Optional
[
int
]
=
None
"""
"""
An `int` that specifies which output `Op.__call__` should return. If
An `int` that specifies which output `Op.__call__` should return. If
`None`, then all outputs are returned.
`None`, then all outputs are returned.
...
@@ -852,7 +852,7 @@ def lquote_macro(txt: Text) -> Text:
...
@@ -852,7 +852,7 @@ def lquote_macro(txt: Text) -> Text:
return
"
\n
"
.
join
(
res
)
return
"
\n
"
.
join
(
res
)
def
get_sub_macros
(
sub
:
Dict
[
Text
,
Text
])
->
Tuple
[
Text
]:
def
get_sub_macros
(
sub
:
Dict
[
Text
,
Text
])
->
Union
[
Tuple
[
Text
],
Tuple
[
Text
,
Text
]
]:
define_macros
=
[]
define_macros
=
[]
undef_macros
=
[]
undef_macros
=
[]
define_macros
.
append
(
f
"#define FAIL {lquote_macro(sub['fail'])}"
)
define_macros
.
append
(
f
"#define FAIL {lquote_macro(sub['fail'])}"
)
...
@@ -864,7 +864,9 @@ def get_sub_macros(sub: Dict[Text, Text]) -> Tuple[Text]:
...
@@ -864,7 +864,9 @@ def get_sub_macros(sub: Dict[Text, Text]) -> Tuple[Text]:
return
"
\n
"
.
join
(
define_macros
),
"
\n
"
.
join
(
undef_macros
)
return
"
\n
"
.
join
(
define_macros
),
"
\n
"
.
join
(
undef_macros
)
def
get_io_macros
(
inputs
:
List
[
Text
],
outputs
:
List
[
Text
])
->
Tuple
[
List
[
Text
]]:
def
get_io_macros
(
inputs
:
List
[
Text
],
outputs
:
List
[
Text
]
)
->
Union
[
Tuple
[
List
[
Text
]],
Tuple
[
str
,
str
]]:
define_macros
=
[]
define_macros
=
[]
undef_macros
=
[]
undef_macros
=
[]
...
@@ -1023,7 +1025,7 @@ class ExternalCOp(COp):
...
@@ -1023,7 +1025,7 @@ class ExternalCOp(COp):
f
"No valid section marker was found in file {func_files[i]}"
f
"No valid section marker was found in file {func_files[i]}"
)
)
def
__get_op_params
(
self
)
->
List
[
Text
]:
def
__get_op_params
(
self
)
->
Union
[
List
[
Text
],
List
[
Tuple
[
str
,
Any
]]
]:
"""Construct name, value pairs that will be turned into macros for use within the `Op`'s code.
"""Construct name, value pairs that will be turned into macros for use within the `Op`'s code.
The names must be strings that are not a C keyword and the
The names must be strings that are not a C keyword and the
...
@@ -1130,7 +1132,7 @@ class ExternalCOp(COp):
...
@@ -1130,7 +1132,7 @@ class ExternalCOp(COp):
def
get_c_macros
(
def
get_c_macros
(
self
,
node
:
Apply
,
name
:
Text
,
check_input
:
Optional
[
bool
]
=
None
self
,
node
:
Apply
,
name
:
Text
,
check_input
:
Optional
[
bool
]
=
None
)
->
Tuple
[
Text
]:
)
->
Union
[
Tuple
[
str
],
Tuple
[
str
,
str
]
]:
"Construct a pair of C ``#define`` and ``#undef`` code strings."
"Construct a pair of C ``#define`` and ``#undef`` code strings."
define_template
=
"#define
%
s
%
s"
define_template
=
"#define
%
s
%
s"
undef_template
=
"#undef
%
s"
undef_template
=
"#undef
%
s"
...
...
aesara/graph/optdb.py
浏览文件 @
f0392db7
...
@@ -2,6 +2,7 @@ import copy
...
@@ -2,6 +2,7 @@ import copy
import
math
import
math
import
sys
import
sys
from
io
import
StringIO
from
io
import
StringIO
from
typing
import
Dict
,
Optional
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
from
aesara.graph
import
opt
from
aesara.graph
import
opt
...
@@ -192,6 +193,7 @@ class Query:
...
@@ -192,6 +193,7 @@ class Query:
self
.
exclude
=
exclude
or
OrderedSet
()
self
.
exclude
=
exclude
or
OrderedSet
()
self
.
subquery
=
subquery
or
{}
self
.
subquery
=
subquery
or
{}
self
.
position_cutoff
=
position_cutoff
self
.
position_cutoff
=
position_cutoff
self
.
name
:
Optional
[
str
]
=
None
if
extra_optimizations
is
None
:
if
extra_optimizations
is
None
:
extra_optimizations
=
[]
extra_optimizations
=
[]
self
.
extra_optimizations
=
extra_optimizations
self
.
extra_optimizations
=
extra_optimizations
...
@@ -438,14 +440,18 @@ class LocalGroupDB(DB):
...
@@ -438,14 +440,18 @@ class LocalGroupDB(DB):
"""
"""
def
__init__
(
def
__init__
(
self
,
apply_all_opts
=
False
,
profile
=
False
,
local_opt
=
opt
.
LocalOptGroup
self
,
apply_all_opts
:
bool
=
False
,
profile
:
bool
=
False
,
local_opt
=
opt
.
LocalOptGroup
,
):
):
super
()
.
__init__
()
super
()
.
__init__
()
self
.
failure_callback
=
None
self
.
failure_callback
=
None
self
.
apply_all_opts
=
apply_all_opts
self
.
apply_all_opts
=
apply_all_opts
self
.
profile
=
profile
self
.
profile
=
profile
self
.
__position__
=
{}
self
.
__position__
:
Dict
=
{}
self
.
local_opt
=
local_opt
self
.
local_opt
=
local_opt
self
.
__name__
:
str
=
""
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
def
register
(
self
,
name
,
obj
,
*
tags
,
**
kwargs
):
super
()
.
register
(
name
,
obj
,
*
tags
)
super
()
.
register
(
name
,
obj
,
*
tags
)
...
...
aesara/graph/unify.py
浏览文件 @
f0392db7
...
@@ -331,7 +331,7 @@ def unify_walk(a, b, U):
...
@@ -331,7 +331,7 @@ def unify_walk(a, b, U):
return
False
return
False
@comm_guard
(
FreeVariable
,
ANY_TYPE
)
# noqa
@comm_guard
(
FreeVariable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
fv
,
o
,
U
):
def
unify_walk
(
fv
,
o
,
U
):
"""
"""
FreeV is unified to BoundVariable(other_object).
FreeV is unified to BoundVariable(other_object).
...
@@ -341,7 +341,7 @@ def unify_walk(fv, o, U):
...
@@ -341,7 +341,7 @@ def unify_walk(fv, o, U):
return
U
.
merge
(
v
,
fv
)
return
U
.
merge
(
v
,
fv
)
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
# noqa
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
bv
,
o
,
U
):
def
unify_walk
(
bv
,
o
,
U
):
"""
"""
The unification succeed iff BV.value == other_object.
The unification succeed iff BV.value == other_object.
...
@@ -353,7 +353,7 @@ def unify_walk(bv, o, U):
...
@@ -353,7 +353,7 @@ def unify_walk(bv, o, U):
return
False
return
False
@comm_guard
(
OrVariable
,
ANY_TYPE
)
# noqa
@comm_guard
(
OrVariable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
ov
,
o
,
U
):
def
unify_walk
(
ov
,
o
,
U
):
"""
"""
The unification succeeds iff other_object in OrV.options.
The unification succeeds iff other_object in OrV.options.
...
@@ -366,7 +366,7 @@ def unify_walk(ov, o, U):
...
@@ -366,7 +366,7 @@ def unify_walk(ov, o, U):
return
False
return
False
@comm_guard
(
NotVariable
,
ANY_TYPE
)
# noqa
@comm_guard
(
NotVariable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
nv
,
o
,
U
):
def
unify_walk
(
nv
,
o
,
U
):
"""
"""
The unification succeeds iff other_object not in NV.not_options.
The unification succeeds iff other_object not in NV.not_options.
...
@@ -379,7 +379,7 @@ def unify_walk(nv, o, U):
...
@@ -379,7 +379,7 @@ def unify_walk(nv, o, U):
return
U
.
merge
(
v
,
nv
)
return
U
.
merge
(
v
,
nv
)
@comm_guard
(
FreeVariable
,
Variable
)
# noqa
@comm_guard
(
FreeVariable
,
Variable
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
fv
,
v
,
U
):
def
unify_walk
(
fv
,
v
,
U
):
"""
"""
Both variables are unified.
Both variables are unified.
...
@@ -389,7 +389,7 @@ def unify_walk(fv, v, U):
...
@@ -389,7 +389,7 @@ def unify_walk(fv, v, U):
return
U
.
merge
(
v
,
fv
)
return
U
.
merge
(
v
,
fv
)
@comm_guard
(
BoundVariable
,
Variable
)
# noqa
@comm_guard
(
BoundVariable
,
Variable
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
bv
,
v
,
U
):
def
unify_walk
(
bv
,
v
,
U
):
"""
"""
V is unified to BV.value.
V is unified to BV.value.
...
@@ -398,7 +398,7 @@ def unify_walk(bv, v, U):
...
@@ -398,7 +398,7 @@ def unify_walk(bv, v, U):
return
unify_walk
(
v
,
bv
.
value
,
U
)
return
unify_walk
(
v
,
bv
.
value
,
U
)
@comm_guard
(
OrVariable
,
OrVariable
)
# noqa
@comm_guard
(
OrVariable
,
OrVariable
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
a
,
b
,
U
):
def
unify_walk
(
a
,
b
,
U
):
"""
"""
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
...
@@ -414,7 +414,7 @@ def unify_walk(a, b, U):
...
@@ -414,7 +414,7 @@ def unify_walk(a, b, U):
return
U
.
merge
(
v
,
a
,
b
)
return
U
.
merge
(
v
,
a
,
b
)
@comm_guard
(
NotVariable
,
NotVariable
)
# noqa
@comm_guard
(
NotVariable
,
NotVariable
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
a
,
b
,
U
):
def
unify_walk
(
a
,
b
,
U
):
"""
"""
NV(list1) == NV(list2) == NV(union(list1, list2))
NV(list1) == NV(list2) == NV(union(list1, list2))
...
@@ -425,7 +425,7 @@ def unify_walk(a, b, U):
...
@@ -425,7 +425,7 @@ def unify_walk(a, b, U):
return
U
.
merge
(
v
,
a
,
b
)
return
U
.
merge
(
v
,
a
,
b
)
@comm_guard
(
OrVariable
,
NotVariable
)
# noqa
@comm_guard
(
OrVariable
,
NotVariable
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
o
,
n
,
U
):
def
unify_walk
(
o
,
n
,
U
):
r"""
r"""
OrV(list1) == NV(list2) == OrV(list1 \ list2)
OrV(list1) == NV(list2) == OrV(list1 \ list2)
...
@@ -441,7 +441,7 @@ def unify_walk(o, n, U):
...
@@ -441,7 +441,7 @@ def unify_walk(o, n, U):
return
U
.
merge
(
v
,
o
,
n
)
return
U
.
merge
(
v
,
o
,
n
)
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
# noqa
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
vil
,
l
,
U
):
def
unify_walk
(
vil
,
l
,
U
):
"""
"""
Unifies VIL's inner Variable to OrV(list).
Unifies VIL's inner Variable to OrV(list).
...
@@ -452,7 +452,7 @@ def unify_walk(vil, l, U):
...
@@ -452,7 +452,7 @@ def unify_walk(vil, l, U):
return
unify_walk
(
v
,
ov
,
U
)
return
unify_walk
(
v
,
ov
,
U
)
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
# noqa
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
l1
,
l2
,
U
):
def
unify_walk
(
l1
,
l2
,
U
):
"""
"""
Tries to unify each corresponding pair of elements from l1 and l2.
Tries to unify each corresponding pair of elements from l1 and l2.
...
@@ -467,7 +467,7 @@ def unify_walk(l1, l2, U):
...
@@ -467,7 +467,7 @@ def unify_walk(l1, l2, U):
return
U
return
U
@comm_guard
(
dict
,
dict
)
# noqa
@comm_guard
(
dict
,
dict
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
d1
,
d2
,
U
):
def
unify_walk
(
d1
,
d2
,
U
):
"""
"""
Tries to unify values of corresponding keys.
Tries to unify values of corresponding keys.
...
@@ -481,7 +481,7 @@ def unify_walk(d1, d2, U):
...
@@ -481,7 +481,7 @@ def unify_walk(d1, d2, U):
return
U
return
U
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
# noqa
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
a
,
b
,
U
):
def
unify_walk
(
a
,
b
,
U
):
"""
"""
Checks for the existence of the __unify_walk__ method for one of
Checks for the existence of the __unify_walk__ method for one of
...
@@ -498,7 +498,7 @@ def unify_walk(a, b, U):
...
@@ -498,7 +498,7 @@ def unify_walk(a, b, U):
return
FALL_THROUGH
return
FALL_THROUGH
@comm_guard
(
Variable
,
ANY_TYPE
)
# noqa
@comm_guard
(
Variable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_walk
(
v
,
o
,
U
):
def
unify_walk
(
v
,
o
,
U
):
"""
"""
This simply checks if the Var has an unification in U and uses it
This simply checks if the Var has an unification in U and uses it
...
@@ -528,27 +528,27 @@ def unify_merge(a, b, U):
...
@@ -528,27 +528,27 @@ def unify_merge(a, b, U):
return
a
return
a
@comm_guard
(
Variable
,
ANY_TYPE
)
# noqa
@comm_guard
(
Variable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
v
,
o
,
U
):
def
unify_merge
(
v
,
o
,
U
):
return
v
return
v
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
# noqa
@comm_guard
(
BoundVariable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
bv
,
o
,
U
):
def
unify_merge
(
bv
,
o
,
U
):
return
bv
.
value
return
bv
.
value
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
# noqa
@comm_guard
(
VariableInList
,
(
list
,
tuple
))
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
vil
,
l
,
U
):
def
unify_merge
(
vil
,
l
,
U
):
return
[
unify_merge
(
x
,
x
,
U
)
for
x
in
l
]
return
[
unify_merge
(
x
,
x
,
U
)
for
x
in
l
]
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
# noqa
@comm_guard
((
list
,
tuple
),
(
list
,
tuple
))
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
l1
,
l2
,
U
):
def
unify_merge
(
l1
,
l2
,
U
):
return
[
unify_merge
(
x1
,
x2
,
U
)
for
x1
,
x2
in
zip
(
l1
,
l2
)]
return
[
unify_merge
(
x1
,
x2
,
U
)
for
x1
,
x2
in
zip
(
l1
,
l2
)]
@comm_guard
(
dict
,
dict
)
# noqa
@comm_guard
(
dict
,
dict
)
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
d1
,
d2
,
U
):
def
unify_merge
(
d1
,
d2
,
U
):
d
=
d1
.
__class__
()
d
=
d1
.
__class__
()
for
k1
,
v1
in
d1
.
items
():
for
k1
,
v1
in
d1
.
items
():
...
@@ -562,12 +562,12 @@ def unify_merge(d1, d2, U):
...
@@ -562,12 +562,12 @@ def unify_merge(d1, d2, U):
return
d
return
d
@comm_guard
(
FVar
,
ANY_TYPE
)
# noqa
@comm_guard
(
FVar
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
vs
,
o
,
U
):
def
unify_merge
(
vs
,
o
,
U
):
return
vs
(
U
)
return
vs
(
U
)
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
# noqa
@comm_guard
(
ANY_TYPE
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
a
,
b
,
U
):
def
unify_merge
(
a
,
b
,
U
):
if
(
if
(
not
isinstance
(
a
,
Variable
)
not
isinstance
(
a
,
Variable
)
...
@@ -579,7 +579,7 @@ def unify_merge(a, b, U):
...
@@ -579,7 +579,7 @@ def unify_merge(a, b, U):
return
FALL_THROUGH
return
FALL_THROUGH
@comm_guard
(
Variable
,
ANY_TYPE
)
# noqa
@comm_guard
(
Variable
,
ANY_TYPE
)
#
type: ignore[no-redef] #
noqa
def
unify_merge
(
v
,
o
,
U
):
def
unify_merge
(
v
,
o
,
U
):
"""
"""
This simply checks if the Var has an unification in U and uses it
This simply checks if the Var has an unification in U and uses it
...
...
aesara/graph/utils.py
浏览文件 @
f0392db7
...
@@ -3,6 +3,7 @@ import sys
...
@@ -3,6 +3,7 @@ import sys
import
traceback
import
traceback
from
abc
import
ABCMeta
from
abc
import
ABCMeta
from
io
import
StringIO
from
io
import
StringIO
from
typing
import
List
def
simple_extract_stack
(
f
=
None
,
limit
=
None
,
skips
=
None
):
def
simple_extract_stack
(
f
=
None
,
limit
=
None
,
skips
=
None
):
...
@@ -225,7 +226,7 @@ class MetaType(ABCMeta):
...
@@ -225,7 +226,7 @@ class MetaType(ABCMeta):
class
MetaObject
(
metaclass
=
MetaType
):
class
MetaObject
(
metaclass
=
MetaType
):
__slots__
=
[]
__slots__
:
List
=
[]
def
__ne__
(
self
,
other
):
def
__ne__
(
self
,
other
):
return
not
self
==
other
return
not
self
==
other
...
...
aesara/link/c/basic.py
浏览文件 @
f0392db7
...
@@ -8,6 +8,7 @@ import sys
...
@@ -8,6 +8,7 @@ import sys
from
collections
import
defaultdict
from
collections
import
defaultdict
from
copy
import
copy
from
copy
import
copy
from
io
import
StringIO
from
io
import
StringIO
from
typing
import
Dict
import
numpy
as
np
import
numpy
as
np
...
@@ -1795,7 +1796,7 @@ class OpWiseCLinker(LocalLinker):
...
@@ -1795,7 +1796,7 @@ class OpWiseCLinker(LocalLinker):
"""
"""
__cache__
=
{}
__cache__
:
Dict
=
{}
def
__init__
(
def
__init__
(
self
,
fallback_on_perform
=
True
,
allow_gc
=
None
,
nice_errors
=
True
,
schedule
=
None
self
,
fallback_on_perform
=
True
,
allow_gc
=
None
,
nice_errors
=
True
,
schedule
=
None
...
...
aesara/link/c/cmodule.py
浏览文件 @
f0392db7
...
@@ -19,6 +19,7 @@ import textwrap
...
@@ -19,6 +19,7 @@ import textwrap
import
time
import
time
import
warnings
import
warnings
from
io
import
BytesIO
,
StringIO
from
io
import
BytesIO
,
StringIO
from
typing
import
Dict
,
List
,
Set
import
numpy.distutils
import
numpy.distutils
...
@@ -631,38 +632,38 @@ class ModuleCache:
...
@@ -631,38 +632,38 @@ class ModuleCache:
"""
"""
dirname
=
""
dirname
:
str
=
""
"""
"""
The working directory that is managed by this interface.
The working directory that is managed by this interface.
"""
"""
module_from_name
=
{}
module_from_name
:
Dict
=
{}
"""
"""
Maps a module filename to the loaded module object.
Maps a module filename to the loaded module object.
"""
"""
entry_from_key
=
{}
entry_from_key
:
Dict
=
{}
"""
"""
Maps keys to the filename of a .so/.pyd.
Maps keys to the filename of a .so/.pyd.
"""
"""
similar_keys
=
{}
similar_keys
:
Dict
=
{}
"""
"""
Maps a part-of-key to all keys that share this same part.
Maps a part-of-key to all keys that share this same part.
"""
"""
module_hash_to_key_data
=
{}
module_hash_to_key_data
:
Dict
=
{}
"""
"""
Maps a module hash to its corresponding KeyData object.
Maps a module hash to its corresponding KeyData object.
"""
"""
stats
=
[]
stats
:
List
=
[]
"""
"""
A list with counters for the number of hits, loads, compiles issued by
A list with counters for the number of hits, loads, compiles issued by
module_from_key().
module_from_key().
"""
"""
loaded_key_pkl
=
set
()
loaded_key_pkl
:
Set
=
set
()
"""
"""
Set of all key.pkl files that have been loaded.
Set of all key.pkl files that have been loaded.
...
...
aesara/link/c/interface.py
浏览文件 @
f0392db7
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Callable
,
Dict
,
List
,
Text
,
Tuple
from
typing
import
Callable
,
Dict
,
List
,
Text
,
Tuple
,
Union
from
aesara.graph.basic
import
Apply
,
Constant
from
aesara.graph.basic
import
Apply
,
Constant
from
aesara.graph.utils
import
MethodNotDefined
from
aesara.graph.utils
import
MethodNotDefined
...
@@ -129,7 +129,7 @@ class CLinkerObject:
...
@@ -129,7 +129,7 @@ class CLinkerObject:
"""Return a list of code snippets to be inserted in module initialization."""
"""Return a list of code snippets to be inserted in module initialization."""
return
[]
return
[]
def
c_code_cache_version
(
self
)
->
Tuple
[
int
]:
def
c_code_cache_version
(
self
)
->
Union
[
Tuple
[
int
],
Tuple
]:
"""Return a tuple of integers indicating the version of this `Op`.
"""Return a tuple of integers indicating the version of this `Op`.
An empty tuple indicates an 'unversioned' `Op` that will not be cached
An empty tuple indicates an 'unversioned' `Op` that will not be cached
...
@@ -551,7 +551,7 @@ class CLinkerType(CLinkerObject):
...
@@ -551,7 +551,7 @@ class CLinkerType(CLinkerObject):
"""
"""
return
""
return
""
def
c_code_cache_version
(
self
)
->
Tuple
[
int
]:
def
c_code_cache_version
(
self
)
->
Union
[
Tuple
,
Tuple
[
int
]
]:
"""Return a tuple of integers indicating the version of this type.
"""Return a tuple of integers indicating the version of this type.
An empty tuple indicates an 'unversioned' type that will not
An empty tuple indicates an 'unversioned' type that will not
...
...
aesara/link/utils.py
浏览文件 @
f0392db7
...
@@ -3,7 +3,7 @@ import sys
...
@@ -3,7 +3,7 @@ import sys
import
traceback
import
traceback
import
warnings
import
warnings
from
operator
import
itemgetter
from
operator
import
itemgetter
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
NoReturn
,
Optional
,
Tuple
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
NoReturn
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -315,6 +315,10 @@ def raise_with_op(
...
@@ -315,6 +315,10 @@ def raise_with_op(
types
=
[
getattr
(
ipt
,
"type"
,
"No type"
)
for
ipt
in
node
.
inputs
]
types
=
[
getattr
(
ipt
,
"type"
,
"No type"
)
for
ipt
in
node
.
inputs
]
detailed_err_msg
+=
f
"
\n
Inputs types: {types}
\n
"
detailed_err_msg
+=
f
"
\n
Inputs types: {types}
\n
"
shapes
:
Union
[
List
,
str
]
strides
:
Union
[
List
,
str
]
scalar_values
:
Union
[
List
,
str
]
if
thunk
is
not
None
:
if
thunk
is
not
None
:
if
hasattr
(
thunk
,
"inputs"
):
if
hasattr
(
thunk
,
"inputs"
):
shapes
=
[
getattr
(
ipt
[
0
],
"shape"
,
"No shapes"
)
for
ipt
in
thunk
.
inputs
]
shapes
=
[
getattr
(
ipt
[
0
],
"shape"
,
"No shapes"
)
for
ipt
in
thunk
.
inputs
]
...
...
aesara/misc/check_duplicate_key.py
浏览文件 @
f0392db7
import
os
import
os
import
pickle
import
pickle
import
sys
import
sys
from
typing
import
Dict
from
aesara.configdefaults
import
config
from
aesara.configdefaults
import
config
...
@@ -15,8 +16,8 @@ if len(sys.argv) > 1:
...
@@ -15,8 +16,8 @@ if len(sys.argv) > 1:
else
:
else
:
dirs
=
os
.
listdir
(
config
.
compiledir
)
dirs
=
os
.
listdir
(
config
.
compiledir
)
dirs
=
[
os
.
path
.
join
(
config
.
compiledir
,
d
)
for
d
in
dirs
]
dirs
=
[
os
.
path
.
join
(
config
.
compiledir
,
d
)
for
d
in
dirs
]
keys
=
{}
# key -> nb seen
keys
:
Dict
=
{}
# key -> nb seen
mods
=
{}
mods
:
Dict
=
{}
for
dir
in
dirs
:
for
dir
in
dirs
:
key
=
None
key
=
None
...
@@ -48,12 +49,12 @@ if DISPLAY_DUPLICATE_KEYS:
...
@@ -48,12 +49,12 @@ if DISPLAY_DUPLICATE_KEYS:
if
v
>
1
:
if
v
>
1
:
print
(
"Duplicate key (
%
i copies):
%
s"
%
(
v
,
pickle
.
loads
(
k
)))
print
(
"Duplicate key (
%
i copies):
%
s"
%
(
v
,
pickle
.
loads
(
k
)))
nbs_keys
=
{}
# nb seen -> now many key
nbs_keys
:
Dict
=
{}
# nb seen -> now many key
for
val
in
keys
.
values
():
for
val
in
keys
.
values
():
nbs_keys
.
setdefault
(
val
,
0
)
nbs_keys
.
setdefault
(
val
,
0
)
nbs_keys
[
val
]
+=
1
nbs_keys
[
val
]
+=
1
nbs_mod
=
{}
# nb seen -> how many key
nbs_mod
:
Dict
=
{}
# nb seen -> how many key
nbs_mod_to_key
=
{}
# nb seen -> keys
nbs_mod_to_key
=
{}
# nb seen -> keys
more_than_one
=
0
more_than_one
=
0
for
mod
,
kk
in
mods
.
items
():
for
mod
,
kk
in
mods
.
items
():
...
...
aesara/sandbox/multinomial.py
浏览文件 @
f0392db7
import
copy
import
copy
import
warnings
import
warnings
from
typing
import
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -18,7 +19,7 @@ class MultinomialFromUniform(COp):
...
@@ -18,7 +19,7 @@ class MultinomialFromUniform(COp):
TODO : need description for parameter 'odtype'
TODO : need description for parameter 'odtype'
"""
"""
__props__
=
(
"odtype"
,)
__props__
:
Union
[
Tuple
[
str
],
Tuple
[
str
,
str
]]
=
(
"odtype"
,)
def
__init__
(
self
,
odtype
):
def
__init__
(
self
,
odtype
):
self
.
odtype
=
odtype
self
.
odtype
=
odtype
...
...
aesara/scalar/basic.py
浏览文件 @
f0392db7
...
@@ -15,7 +15,7 @@ from collections.abc import Callable
...
@@ -15,7 +15,7 @@ from collections.abc import Callable
from
copy
import
copy
from
copy
import
copy
from
itertools
import
chain
from
itertools
import
chain
from
textwrap
import
dedent
from
textwrap
import
dedent
from
typing
import
Optional
from
typing
import
Dict
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -898,7 +898,7 @@ def upgrade_to_float(*types):
...
@@ -898,7 +898,7 @@ def upgrade_to_float(*types):
Upgrade any int types to float32 or float64 to avoid losing precision.
Upgrade any int types to float32 or float64 to avoid losing precision.
"""
"""
conv
=
{
conv
:
Mapping
[
Type
,
Type
]
=
{
bool
:
float32
,
bool
:
float32
,
int8
:
float32
,
int8
:
float32
,
int16
:
float32
,
int16
:
float32
,
...
@@ -3954,7 +3954,7 @@ class Composite(ScalarOp):
...
@@ -3954,7 +3954,7 @@ class Composite(ScalarOp):
"""
"""
init_param
=
(
"inputs"
,
"outputs"
)
init_param
:
Union
[
Tuple
[
str
,
str
],
Tuple
[
str
]]
=
(
"inputs"
,
"outputs"
)
def
__str__
(
self
):
def
__str__
(
self
):
if
self
.
name
is
None
:
if
self
.
name
is
None
:
...
@@ -4331,7 +4331,7 @@ class Composite(ScalarOp):
...
@@ -4331,7 +4331,7 @@ class Composite(ScalarOp):
class
Compositef32
:
class
Compositef32
:
# This is a dict of scalar op classes that need special handling
# This is a dict of scalar op classes that need special handling
special
=
{}
special
:
Dict
=
{}
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
mapping
=
{}
mapping
=
{}
...
...
aesara/tensor/__init__.py
浏览文件 @
f0392db7
...
@@ -5,12 +5,12 @@ __docformat__ = "restructuredtext en"
...
@@ -5,12 +5,12 @@ __docformat__ = "restructuredtext en"
import
warnings
import
warnings
from
functools
import
singledispatch
from
functools
import
singledispatch
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
NoReturn
,
Optional
def
as_tensor_variable
(
def
as_tensor_variable
(
x
,
name
:
Optional
[
str
]
=
None
,
ndim
:
Optional
[
int
]
=
None
,
**
kwargs
x
,
name
:
Optional
[
str
]
=
None
,
ndim
:
Optional
[
int
]
=
None
,
**
kwargs
):
)
->
Callable
:
"""Convert `x` into the appropriate `TensorType`.
"""Convert `x` into the appropriate `TensorType`.
This function is often used by `make_node` methods of `Op` subclasses to
This function is often used by `make_node` methods of `Op` subclasses to
...
@@ -39,7 +39,9 @@ def as_tensor_variable(
...
@@ -39,7 +39,9 @@ def as_tensor_variable(
@singledispatch
@singledispatch
def
_as_tensor_variable
(
x
,
name
:
Optional
[
str
],
ndim
:
Optional
[
int
],
**
kwargs
):
def
_as_tensor_variable
(
x
,
name
:
Optional
[
str
],
ndim
:
Optional
[
int
],
**
kwargs
)
->
NoReturn
:
raise
NotImplementedError
(
""
)
raise
NotImplementedError
(
""
)
...
...
aesara/tensor/basic.py
浏览文件 @
f0392db7
...
@@ -11,6 +11,7 @@ import warnings
...
@@ -11,6 +11,7 @@ import warnings
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
numbers
import
Number
from
numbers
import
Number
from
typing
import
Dict
import
numpy
as
np
import
numpy
as
np
...
@@ -679,7 +680,7 @@ class Rebroadcast(COp):
...
@@ -679,7 +680,7 @@ class Rebroadcast(COp):
# Mapping from Type to C code (and version) to use.
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
# the output variable is %(oname)s.
c_code_and_version
=
{}
c_code_and_version
:
Dict
=
{}
check_input
=
False
check_input
=
False
__props__
=
(
"axis"
,)
__props__
=
(
"axis"
,)
...
...
aesara/tensor/blas.py
浏览文件 @
f0392db7
...
@@ -140,6 +140,7 @@ except ImportError:
...
@@ -140,6 +140,7 @@ except ImportError:
pass
pass
from
functools
import
reduce
from
functools
import
reduce
from
typing
import
Tuple
,
Union
import
aesara.scalar
import
aesara.scalar
from
aesara.compile.mode
import
optdb
from
aesara.compile.mode
import
optdb
...
@@ -506,7 +507,7 @@ class GemmRelated(COp):
...
@@ -506,7 +507,7 @@ class GemmRelated(COp):
"""
"""
__props__
=
()
__props__
:
Union
[
Tuple
,
Tuple
[
str
]]
=
()
def
c_support_code
(
self
,
**
kwargs
):
def
c_support_code
(
self
,
**
kwargs
):
# return cblas_header_text()
# return cblas_header_text()
...
...
aesara/tensor/elemwise.py
浏览文件 @
f0392db7
from
copy
import
copy
from
copy
import
copy
from
typing
import
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -1297,7 +1298,9 @@ class CAReduce(COp):
...
@@ -1297,7 +1298,9 @@ class CAReduce(COp):
"""
"""
__props__
=
(
"scalar_op"
,
"axis"
)
__props__
:
Union
[
Tuple
[
str
],
Tuple
[
str
,
str
],
Tuple
[
str
,
str
,
str
],
Tuple
[
str
,
str
,
str
,
str
]
]
=
(
"scalar_op"
,
"axis"
)
def
__init__
(
self
,
scalar_op
,
axis
=
None
):
def
__init__
(
self
,
scalar_op
,
axis
=
None
):
if
scalar_op
.
nin
not
in
[
-
1
,
2
]
or
scalar_op
.
nout
!=
1
:
if
scalar_op
.
nin
not
in
[
-
1
,
2
]
or
scalar_op
.
nout
!=
1
:
...
@@ -1682,7 +1685,12 @@ class CAReduceDtype(CAReduce):
...
@@ -1682,7 +1685,12 @@ class CAReduceDtype(CAReduce):
"""
"""
__props__
=
(
"scalar_op"
,
"axis"
,
"dtype"
,
"acc_dtype"
)
__props__
:
Union
[
Tuple
[
str
,
str
,
str
],
Tuple
[
str
,
str
,
str
,
str
]]
=
(
"scalar_op"
,
"axis"
,
"dtype"
,
"acc_dtype"
,
)
def
__init__
(
self
,
scalar_op
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
):
def
__init__
(
self
,
scalar_op
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
):
super
()
.
__init__
(
scalar_op
,
axis
=
axis
)
super
()
.
__init__
(
scalar_op
,
axis
=
axis
)
...
...
aesara/tensor/nlinalg.py
浏览文件 @
f0392db7
import
logging
import
logging
from
functools
import
partial
from
functools
import
partial
from
typing
import
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -224,7 +225,7 @@ class Eig(Op):
...
@@ -224,7 +225,7 @@ class Eig(Op):
"""
"""
_numop
=
staticmethod
(
np
.
linalg
.
eig
)
_numop
=
staticmethod
(
np
.
linalg
.
eig
)
__props__
=
()
__props__
:
Union
[
Tuple
,
Tuple
[
str
]]
=
()
def
make_node
(
self
,
x
):
def
make_node
(
self
,
x
):
x
=
as_tensor_variable
(
x
)
x
=
as_tensor_variable
(
x
)
...
...
aesara/tensor/nnet/blocksparse.py
浏览文件 @
f0392db7
from
typing
import
List
import
numpy
as
np
import
numpy
as
np
import
aesara
import
aesara
...
@@ -26,7 +28,7 @@ class SparseBlockGemv(Op):
...
@@ -26,7 +28,7 @@ class SparseBlockGemv(Op):
__props__
=
(
"inplace"
,)
__props__
=
(
"inplace"
,)
registered_opts
=
[]
registered_opts
:
List
=
[]
def
__init__
(
self
,
inplace
=
False
):
def
__init__
(
self
,
inplace
=
False
):
self
.
inplace
=
inplace
self
.
inplace
=
inplace
...
@@ -147,7 +149,7 @@ class SparseBlockOuter(Op):
...
@@ -147,7 +149,7 @@ class SparseBlockOuter(Op):
__props__
=
(
"inplace"
,)
__props__
=
(
"inplace"
,)
registered_opts
=
[]
registered_opts
:
List
=
[]
def
__init__
(
self
,
inplace
=
False
):
def
__init__
(
self
,
inplace
=
False
):
self
.
inplace
=
inplace
self
.
inplace
=
inplace
...
...
aesara/tensor/shape.py
浏览文件 @
f0392db7
import
warnings
import
warnings
from
typing
import
Dict
import
numpy
as
np
import
numpy
as
np
...
@@ -49,7 +50,7 @@ class Shape(COp):
...
@@ -49,7 +50,7 @@ class Shape(COp):
# Mapping from Type to C code (and version) to use.
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
# the output variable is %(oname)s.
c_code_and_version
=
{}
c_code_and_version
:
Dict
=
{}
check_input
=
False
check_input
=
False
__props__
=
()
__props__
=
()
...
@@ -143,7 +144,7 @@ class Shape_i(COp):
...
@@ -143,7 +144,7 @@ class Shape_i(COp):
# Mapping from Type to C code (and version) to use.
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
# the output variable is %(oname)s.
c_code_and_version
=
{}
c_code_and_version
:
Dict
=
{}
check_input
=
False
check_input
=
False
...
@@ -363,7 +364,7 @@ class SpecifyShape(COp):
...
@@ -363,7 +364,7 @@ class SpecifyShape(COp):
# Mapping from Type to C code (and version) to use.
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
# the output variable is %(oname)s.
c_code_and_version
=
{}
c_code_and_version
:
Dict
=
{}
__props__
=
()
__props__
=
()
_f16_ok
=
True
_f16_ok
=
True
...
...
aesara/utils.py
浏览文件 @
f0392db7
...
@@ -12,6 +12,7 @@ import warnings
...
@@ -12,6 +12,7 @@ import warnings
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
functools
import
partial
,
wraps
from
functools
import
partial
,
wraps
from
typing
import
List
,
Set
__all__
=
[
__all__
=
[
...
@@ -29,7 +30,7 @@ __all__ = [
...
@@ -29,7 +30,7 @@ __all__ = [
]
]
__excepthooks
=
[]
__excepthooks
:
List
=
[]
LOCAL_BITWIDTH
=
struct
.
calcsize
(
"P"
)
*
8
LOCAL_BITWIDTH
=
struct
.
calcsize
(
"P"
)
*
8
...
@@ -374,7 +375,7 @@ def apply_across_args(*fns):
...
@@ -374,7 +375,7 @@ def apply_across_args(*fns):
class
NoDuplicateOptWarningFilter
(
logging
.
Filter
):
class
NoDuplicateOptWarningFilter
(
logging
.
Filter
):
"""Filter to avoid duplicating optimization warnings."""
"""Filter to avoid duplicating optimization warnings."""
prev_msgs
=
set
()
prev_msgs
:
Set
=
set
()
def
filter
(
self
,
record
):
def
filter
(
self
,
record
):
msg
=
record
.
getMessage
()
msg
=
record
.
getMessage
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论