Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
17748b7d
提交
17748b7d
authored
2月 05, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
2月 05, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove accidental print statements
上级
4ac1e637
隐藏空白字符变更
内嵌
并排
正在显示
33 个修改的文件
包含
161 行增加
和
168 行删除
+161
-168
pyproject.toml
pyproject.toml
+9
-2
breakpoint.py
pytensor/breakpoint.py
+8
-8
compiledir.py
pytensor/compile/compiledir.py
+23
-23
debugmode.py
pytensor/compile/debugmode.py
+4
-4
mode.py
pytensor/compile/mode.py
+1
-1
monitormode.py
pytensor/compile/monitormode.py
+3
-3
nanguardmode.py
pytensor/compile/nanguardmode.py
+1
-1
profiling.py
pytensor/compile/profiling.py
+2
-2
features.py
pytensor/graph/features.py
+11
-11
fg.py
pytensor/graph/fg.py
+1
-1
basic.py
pytensor/graph/rewriting/basic.py
+10
-10
utils.py
pytensor/graph/utils.py
+2
-2
basic.py
pytensor/link/c/basic.py
+6
-6
op.py
pytensor/link/c/op.py
+1
-1
printing.py
pytensor/printing.py
+3
-4
basic.py
pytensor/tensor/basic.py
+0
-1
nlinalg.py
pytensor/tensor/nlinalg.py
+4
-6
blas.py
pytensor/tensor/rewriting/blas.py
+1
-1
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+4
-4
math.py
pytensor/tensor/rewriting/math.py
+8
-8
extract-slow-tests.py
scripts/slowest_tests/extract-slow-tests.py
+1
-1
test_d3viz.py
tests/d3viz/test_d3viz.py
+1
-1
test_cmodule.py
tests/link/c/test_cmodule.py
+0
-1
test_basic.py
tests/link/numba/test_basic.py
+0
-1
test_vm.py
tests/link/test_vm.py
+26
-25
test_basic.py
tests/scan/test_basic.py
+1
-2
test_math.py
tests/tensor/rewriting/test_math.py
+19
-17
test_complex.py
tests/tensor/test_complex.py
+3
-9
test_fft.py
tests/tensor/test_fft.py
+0
-1
test_shape.py
tests/tensor/test_shape.py
+0
-1
test_config.py
tests/test_config.py
+1
-1
test_printing.py
tests/test_printing.py
+3
-3
unittest_tools.py
tests/unittest_tools.py
+4
-6
没有找到文件。
pyproject.toml
浏览文件 @
17748b7d
...
@@ -129,7 +129,7 @@ exclude = ["doc/", "pytensor/_version.py"]
...
@@ -129,7 +129,7 @@ exclude = ["doc/", "pytensor/_version.py"]
docstring-code-format
=
true
docstring-code-format
=
true
[tool.ruff.lint]
[tool.ruff.lint]
select
=
[
"B905"
,
"C"
,
"E"
,
"F"
,
"I"
,
"UP"
,
"W"
,
"RUF"
,
"PERF"
,
"PTH"
,
"ISC"
]
select
=
[
"B905"
,
"C"
,
"E"
,
"F"
,
"I"
,
"UP"
,
"W"
,
"RUF"
,
"PERF"
,
"PTH"
,
"ISC"
,
"T20"
]
ignore
=
[
"C408"
,
"C901"
,
"E501"
,
"E741"
,
"RUF012"
,
"PERF203"
,
"ISC001"
]
ignore
=
[
"C408"
,
"C901"
,
"E501"
,
"E741"
,
"RUF012"
,
"PERF203"
,
"ISC001"
]
unfixable
=
[
unfixable
=
[
# zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
# zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
...
@@ -144,7 +144,12 @@ lines-after-imports = 2
...
@@ -144,7 +144,12 @@ lines-after-imports = 2
# TODO: Get rid of these:
# TODO: Get rid of these:
"**/__init__.py"
=
[
"F401"
,
"E402"
,
"F403"
]
"**/__init__.py"
=
[
"F401"
,
"E402"
,
"F403"
]
"pytensor/tensor/linalg.py"
=
["F403"]
"pytensor/tensor/linalg.py"
=
["F403"]
"pytensor/link/c/cmodule.py"
=
["PTH"]
"pytensor/link/c/cmodule.py"
=
[
"PTH"
,
"T201"
]
"pytensor/misc/elemwise_time_test.py"
=
["T201"]
"pytensor/misc/elemwise_openmp_speedup.py"
=
["T201"]
"pytensor/misc/check_duplicate_key.py"
=
["T201"]
"pytensor/misc/check_blas.py"
=
["T201"]
"pytensor/bin/pytensor_cache.py"
=
["T201"]
# For the tests we skip because `pytest.importorskip` is used:
# For the tests we skip because `pytest.importorskip` is used:
"tests/link/jax/test_scalar.py"
=
["E402"]
"tests/link/jax/test_scalar.py"
=
["E402"]
"tests/link/jax/test_tensor_basic.py"
=
["E402"]
"tests/link/jax/test_tensor_basic.py"
=
["E402"]
...
@@ -158,6 +163,8 @@ lines-after-imports = 2
...
@@ -158,6 +163,8 @@ lines-after-imports = 2
"tests/sparse/test_sp2.py"
=
["E402"]
"tests/sparse/test_sp2.py"
=
["E402"]
"tests/sparse/test_utils.py"
=
["E402"]
"tests/sparse/test_utils.py"
=
["E402"]
"tests/sparse/sandbox/test_sp.py"
=
[
"E402"
,
"F401"
]
"tests/sparse/sandbox/test_sp.py"
=
[
"E402"
,
"F401"
]
"tests/compile/test_monitormode.py"
=
["T201"]
"scripts/run_mypy.py"
=
["T201"]
[tool.mypy]
[tool.mypy]
...
...
pytensor/breakpoint.py
浏览文件 @
17748b7d
...
@@ -108,14 +108,14 @@ class PdbBreakpoint(Op):
...
@@ -108,14 +108,14 @@ class PdbBreakpoint(Op):
f
"'{self.name}' could not be casted to NumPy arrays"
f
"'{self.name}' could not be casted to NumPy arrays"
)
)
print
(
"
\n
"
)
print
(
"
\n
"
)
# noqa: T201
print
(
"-------------------------------------------------"
)
print
(
"-------------------------------------------------"
)
# noqa: T201
print
(
f
"Conditional breakpoint '{self.name}' activated
\n
"
)
print
(
f
"Conditional breakpoint '{self.name}' activated
\n
"
)
# noqa: T201
print
(
"The monitored variables are stored, in order,"
)
print
(
"The monitored variables are stored, in order,"
)
# noqa: T201
print
(
"in the list variable 'monitored' as NumPy arrays.
\n
"
)
print
(
"in the list variable 'monitored' as NumPy arrays.
\n
"
)
# noqa: T201
print
(
"Their contents can be altered and, when execution"
)
print
(
"Their contents can be altered and, when execution"
)
# noqa: T201
print
(
"resumes, the updated values will be used."
)
print
(
"resumes, the updated values will be used."
)
# noqa: T201
print
(
"-------------------------------------------------"
)
print
(
"-------------------------------------------------"
)
# noqa: T201
try
:
try
:
import
pudb
import
pudb
...
...
pytensor/compile/compiledir.py
浏览文件 @
17748b7d
...
@@ -95,10 +95,10 @@ def cleanup():
...
@@ -95,10 +95,10 @@ def cleanup():
def
print_title
(
title
,
overline
=
""
,
underline
=
""
):
def
print_title
(
title
,
overline
=
""
,
underline
=
""
):
len_title
=
len
(
title
)
len_title
=
len
(
title
)
if
overline
:
if
overline
:
print
(
str
(
overline
)
*
len_title
)
print
(
str
(
overline
)
*
len_title
)
# noqa: T201
print
(
title
)
print
(
title
)
# noqa: T201
if
underline
:
if
underline
:
print
(
str
(
underline
)
*
len_title
)
print
(
str
(
underline
)
*
len_title
)
# noqa: T201
def
print_compiledir_content
():
def
print_compiledir_content
():
...
@@ -159,7 +159,7 @@ def print_compiledir_content():
...
@@ -159,7 +159,7 @@ def print_compiledir_content():
_logger
.
error
(
f
"Could not read key file '{filename}'."
)
_logger
.
error
(
f
"Could not read key file '{filename}'."
)
print_title
(
f
"PyTensor cache: {compiledir}"
,
overline
=
"="
,
underline
=
"="
)
print_title
(
f
"PyTensor cache: {compiledir}"
,
overline
=
"="
,
underline
=
"="
)
print
()
print
()
# noqa: T201
print_title
(
f
"List of {len(table)} compiled individual ops"
,
underline
=
"+"
)
print_title
(
f
"List of {len(table)} compiled individual ops"
,
underline
=
"+"
)
print_title
(
print_title
(
...
@@ -168,9 +168,9 @@ def print_compiledir_content():
...
@@ -168,9 +168,9 @@ def print_compiledir_content():
)
)
table
=
sorted
(
table
,
key
=
lambda
t
:
str
(
t
[
1
]))
table
=
sorted
(
table
,
key
=
lambda
t
:
str
(
t
[
1
]))
for
dir
,
op
,
types
,
compile_time
in
table
:
for
dir
,
op
,
types
,
compile_time
in
table
:
print
(
dir
,
f
"{compile_time:.3f}s"
,
op
,
types
)
print
(
dir
,
f
"{compile_time:.3f}s"
,
op
,
types
)
# noqa: T201
print
()
print
()
# noqa: T201
print_title
(
print_title
(
f
"List of {len(table_multiple_ops)} compiled sets of ops"
,
underline
=
"+"
f
"List of {len(table_multiple_ops)} compiled sets of ops"
,
underline
=
"+"
)
)
...
@@ -180,9 +180,9 @@ def print_compiledir_content():
...
@@ -180,9 +180,9 @@ def print_compiledir_content():
)
)
table_multiple_ops
=
sorted
(
table_multiple_ops
,
key
=
lambda
t
:
(
t
[
1
],
t
[
2
]))
table_multiple_ops
=
sorted
(
table_multiple_ops
,
key
=
lambda
t
:
(
t
[
1
],
t
[
2
]))
for
dir
,
ops_to_str
,
types_to_str
,
compile_time
in
table_multiple_ops
:
for
dir
,
ops_to_str
,
types_to_str
,
compile_time
in
table_multiple_ops
:
print
(
dir
,
f
"{compile_time:.3f}s"
,
ops_to_str
,
types_to_str
)
print
(
dir
,
f
"{compile_time:.3f}s"
,
ops_to_str
,
types_to_str
)
# noqa: T201
print
()
print
()
# noqa: T201
print_title
(
print_title
(
(
(
f
"List of {len(table_op_class)} compiled Op classes and "
f
"List of {len(table_op_class)} compiled Op classes and "
...
@@ -191,33 +191,33 @@ def print_compiledir_content():
...
@@ -191,33 +191,33 @@ def print_compiledir_content():
underline
=
"+"
,
underline
=
"+"
,
)
)
for
op_class
,
nb
in
reversed
(
table_op_class
.
most_common
()):
for
op_class
,
nb
in
reversed
(
table_op_class
.
most_common
()):
print
(
op_class
,
nb
)
print
(
op_class
,
nb
)
# noqa: T201
if
big_key_files
:
if
big_key_files
:
big_key_files
=
sorted
(
big_key_files
,
key
=
lambda
t
:
str
(
t
[
1
]))
big_key_files
=
sorted
(
big_key_files
,
key
=
lambda
t
:
str
(
t
[
1
]))
big_total_size
=
sum
(
sz
for
_
,
sz
,
_
in
big_key_files
)
big_total_size
=
sum
(
sz
for
_
,
sz
,
_
in
big_key_files
)
print
(
print
(
# noqa: T201
f
"There are directories with key files bigger than {int(max_key_file_size)} bytes "
f
"There are directories with key files bigger than {int(max_key_file_size)} bytes "
"(they probably contain big tensor constants)"
"(they probably contain big tensor constants)"
)
)
print
(
print
(
# noqa: T201
f
"They use {int(big_total_size)} bytes out of {int(total_key_sizes)} (total size "
f
"They use {int(big_total_size)} bytes out of {int(total_key_sizes)} (total size "
"used by all key files)"
"used by all key files)"
)
)
for
dir
,
size
,
ops
in
big_key_files
:
for
dir
,
size
,
ops
in
big_key_files
:
print
(
dir
,
size
,
ops
)
print
(
dir
,
size
,
ops
)
# noqa: T201
nb_keys
=
sorted
(
nb_keys
.
items
())
nb_keys
=
sorted
(
nb_keys
.
items
())
print
()
print
()
# noqa: T201
print_title
(
"Number of keys for a compiled module"
,
underline
=
"+"
)
print_title
(
"Number of keys for a compiled module"
,
underline
=
"+"
)
print_title
(
print_title
(
"number of keys/number of modules with that number of keys"
,
underline
=
"-"
"number of keys/number of modules with that number of keys"
,
underline
=
"-"
)
)
for
n_k
,
n_m
in
nb_keys
:
for
n_k
,
n_m
in
nb_keys
:
print
(
n_k
,
n_m
)
print
(
n_k
,
n_m
)
# noqa: T201
print
()
print
()
# noqa: T201
print
(
print
(
# noqa: T201
f
"Skipped {int(zeros_op)} files that contained 0 op "
f
"Skipped {int(zeros_op)} files that contained 0 op "
"(are they always pytensor.scalar ops?)"
"(are they always pytensor.scalar ops?)"
)
)
...
@@ -242,18 +242,18 @@ def basecompiledir_ls():
...
@@ -242,18 +242,18 @@ def basecompiledir_ls():
subdirs
=
sorted
(
subdirs
)
subdirs
=
sorted
(
subdirs
)
others
=
sorted
(
others
)
others
=
sorted
(
others
)
print
(
f
"Base compile dir is {config.base_compiledir}"
)
print
(
f
"Base compile dir is {config.base_compiledir}"
)
# noqa: T201
print
(
"Sub-directories (possible compile caches):"
)
print
(
"Sub-directories (possible compile caches):"
)
# noqa: T201
for
d
in
subdirs
:
for
d
in
subdirs
:
print
(
f
" {d}"
)
print
(
f
" {d}"
)
# noqa: T201
if
not
subdirs
:
if
not
subdirs
:
print
(
" (None)"
)
print
(
" (None)"
)
# noqa: T201
if
others
:
if
others
:
print
()
print
()
# noqa: T201
print
(
"Other files in base_compiledir:"
)
print
(
"Other files in base_compiledir:"
)
# noqa: T201
for
f
in
others
:
for
f
in
others
:
print
(
f
" {f}"
)
print
(
f
" {f}"
)
# noqa: T201
def
basecompiledir_purge
():
def
basecompiledir_purge
():
...
...
pytensor/compile/debugmode.py
浏览文件 @
17748b7d
...
@@ -1315,9 +1315,9 @@ class _VariableEquivalenceTracker:
...
@@ -1315,9 +1315,9 @@ class _VariableEquivalenceTracker:
def
printstuff
(
self
):
def
printstuff
(
self
):
for
key
in
self
.
equiv
:
for
key
in
self
.
equiv
:
print
(
key
)
print
(
key
)
# noqa: T201
for
e
in
self
.
equiv
[
key
]:
for
e
in
self
.
equiv
[
key
]:
print
(
" "
,
e
)
print
(
" "
,
e
)
# noqa: T201
# List of default version of make thunk.
# List of default version of make thunk.
...
@@ -1569,7 +1569,7 @@ class _Linker(LocalLinker):
...
@@ -1569,7 +1569,7 @@ class _Linker(LocalLinker):
#####
#####
for
r
,
s
in
storage_map
.
items
():
for
r
,
s
in
storage_map
.
items
():
if
s
[
0
]
is
not
None
:
if
s
[
0
]
is
not
None
:
print
(
r
,
s
)
print
(
r
,
s
)
# noqa: T201
assert
s
[
0
]
is
None
assert
s
[
0
]
is
None
# try:
# try:
...
@@ -2079,7 +2079,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
...
@@ -2079,7 +2079,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
raise
StochasticOrder
(
infolog
.
getvalue
())
raise
StochasticOrder
(
infolog
.
getvalue
())
else
:
else
:
if
self
.
verbose
:
if
self
.
verbose
:
print
(
print
(
# noqa: T201
"OPTCHECK: optimization"
,
"OPTCHECK: optimization"
,
i
,
i
,
"of"
,
"of"
,
...
...
pytensor/compile/mode.py
浏览文件 @
17748b7d
...
@@ -178,7 +178,7 @@ class PrintCurrentFunctionGraph(GraphRewriter):
...
@@ -178,7 +178,7 @@ class PrintCurrentFunctionGraph(GraphRewriter):
def
apply
(
self
,
fgraph
):
def
apply
(
self
,
fgraph
):
import
pytensor.printing
import
pytensor.printing
print
(
"PrintCurrentFunctionGraph:"
,
self
.
header
)
print
(
"PrintCurrentFunctionGraph:"
,
self
.
header
)
# noqa: T201
pytensor
.
printing
.
debugprint
(
fgraph
.
outputs
)
pytensor
.
printing
.
debugprint
(
fgraph
.
outputs
)
...
...
pytensor/compile/monitormode.py
浏览文件 @
17748b7d
...
@@ -108,8 +108,8 @@ def detect_nan(fgraph, i, node, fn):
...
@@ -108,8 +108,8 @@ def detect_nan(fgraph, i, node, fn):
not
isinstance
(
output
[
0
],
np
.
random
.
RandomState
|
np
.
random
.
Generator
)
not
isinstance
(
output
[
0
],
np
.
random
.
RandomState
|
np
.
random
.
Generator
)
and
np
.
isnan
(
output
[
0
])
.
any
()
and
np
.
isnan
(
output
[
0
])
.
any
()
):
):
print
(
"*** NaN detected ***"
)
print
(
"*** NaN detected ***"
)
# noqa: T201
debugprint
(
node
)
debugprint
(
node
)
print
(
f
"Inputs : {[input[0] for input in fn.inputs]}"
)
print
(
f
"Inputs : {[input[0] for input in fn.inputs]}"
)
# noqa: T201
print
(
f
"Outputs: {[output[0] for output in fn.outputs]}"
)
print
(
f
"Outputs: {[output[0] for output in fn.outputs]}"
)
# noqa: T201
break
break
pytensor/compile/nanguardmode.py
浏览文件 @
17748b7d
...
@@ -236,7 +236,7 @@ class NanGuardMode(Mode):
...
@@ -236,7 +236,7 @@ class NanGuardMode(Mode):
if
config
.
NanGuardMode__action
==
"raise"
:
if
config
.
NanGuardMode__action
==
"raise"
:
raise
AssertionError
(
msg
)
raise
AssertionError
(
msg
)
elif
config
.
NanGuardMode__action
==
"pdb"
:
elif
config
.
NanGuardMode__action
==
"pdb"
:
print
(
msg
)
print
(
msg
)
# noqa: T201
import
pdb
import
pdb
pdb
.
set_trace
()
pdb
.
set_trace
()
...
...
pytensor/compile/profiling.py
浏览文件 @
17748b7d
...
@@ -82,7 +82,7 @@ def _atexit_print_fn():
...
@@ -82,7 +82,7 @@ def _atexit_print_fn():
to_sum
.
append
(
ps
)
to_sum
.
append
(
ps
)
else
:
else
:
# TODO print the name if there is one!
# TODO print the name if there is one!
print
(
"Skipping empty Profile"
)
print
(
"Skipping empty Profile"
)
# noqa: T201
if
len
(
to_sum
)
>
1
:
if
len
(
to_sum
)
>
1
:
# Make a global profile
# Make a global profile
cum
=
copy
.
copy
(
to_sum
[
0
])
cum
=
copy
.
copy
(
to_sum
[
0
])
...
@@ -125,7 +125,7 @@ def _atexit_print_fn():
...
@@ -125,7 +125,7 @@ def _atexit_print_fn():
assert
len
(
merge
)
==
len
(
cum
.
rewriter_profile
[
1
])
assert
len
(
merge
)
==
len
(
cum
.
rewriter_profile
[
1
])
cum
.
rewriter_profile
=
(
cum
.
rewriter_profile
[
0
],
merge
)
cum
.
rewriter_profile
=
(
cum
.
rewriter_profile
[
0
],
merge
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
e
)
print
(
e
)
# noqa: T201
cum
.
rewriter_profile
=
None
cum
.
rewriter_profile
=
None
else
:
else
:
cum
.
rewriter_profile
=
None
cum
.
rewriter_profile
=
None
...
...
pytensor/graph/features.py
浏览文件 @
17748b7d
...
@@ -491,7 +491,7 @@ class Validator(Feature):
...
@@ -491,7 +491,7 @@ class Validator(Feature):
if
verbose
:
if
verbose
:
r
=
uf
.
f_locals
.
get
(
"r"
,
""
)
r
=
uf
.
f_locals
.
get
(
"r"
,
""
)
reason
=
uf_info
.
function
reason
=
uf_info
.
function
print
(
f
"validate failed on node {r}.
\n
Reason: {reason}, {e}"
)
print
(
f
"validate failed on node {r}.
\n
Reason: {reason}, {e}"
)
# noqa: T201
raise
raise
t1
=
time
.
perf_counter
()
t1
=
time
.
perf_counter
()
if
fgraph
.
profile
:
if
fgraph
.
profile
:
...
@@ -603,13 +603,13 @@ class ReplaceValidate(History, Validator):
...
@@ -603,13 +603,13 @@ class ReplaceValidate(History, Validator):
except
Exception
as
e
:
except
Exception
as
e
:
fgraph
.
revert
(
chk
)
fgraph
.
revert
(
chk
)
if
verbose
:
if
verbose
:
print
(
print
(
# noqa: T201
f
"rewriting: validate failed on node {r}.
\n
Reason: {reason}, {e}"
f
"rewriting: validate failed on node {r}.
\n
Reason: {reason}, {e}"
)
)
raise
raise
if
verbose
:
if
verbose
:
print
(
print
(
# noqa: T201
f
"rewriting: rewrite {reason} replaces {r} of {r.owner} with {new_r} of {new_r.owner}"
f
"rewriting: rewrite {reason} replaces {r} of {r.owner} with {new_r} of {new_r.owner}"
)
)
...
@@ -692,11 +692,11 @@ class NodeFinder(Bookkeeper):
...
@@ -692,11 +692,11 @@ class NodeFinder(Bookkeeper):
except
TypeError
:
# node.op is unhashable
except
TypeError
:
# node.op is unhashable
return
return
except
Exception
as
e
:
except
Exception
as
e
:
print
(
"OFFENDING node"
,
type
(
node
),
type
(
node
.
op
),
file
=
sys
.
stderr
)
print
(
"OFFENDING node"
,
type
(
node
),
type
(
node
.
op
),
file
=
sys
.
stderr
)
# noqa: T201
try
:
try
:
print
(
"OFFENDING node hash"
,
hash
(
node
.
op
),
file
=
sys
.
stderr
)
print
(
"OFFENDING node hash"
,
hash
(
node
.
op
),
file
=
sys
.
stderr
)
# noqa: T201
except
Exception
:
except
Exception
:
print
(
"OFFENDING node not hashable"
,
file
=
sys
.
stderr
)
print
(
"OFFENDING node not hashable"
,
file
=
sys
.
stderr
)
# noqa: T201
raise
e
raise
e
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
...
@@ -725,7 +725,7 @@ class PrintListener(Feature):
...
@@ -725,7 +725,7 @@ class PrintListener(Feature):
def
on_attach
(
self
,
fgraph
):
def
on_attach
(
self
,
fgraph
):
if
self
.
active
:
if
self
.
active
:
print
(
"-- attaching to: "
,
fgraph
)
print
(
"-- attaching to: "
,
fgraph
)
# noqa: T201
def
on_detach
(
self
,
fgraph
):
def
on_detach
(
self
,
fgraph
):
"""
"""
...
@@ -733,19 +733,19 @@ class PrintListener(Feature):
...
@@ -733,19 +733,19 @@ class PrintListener(Feature):
that it installed into the function_graph
that it installed into the function_graph
"""
"""
if
self
.
active
:
if
self
.
active
:
print
(
"-- detaching from: "
,
fgraph
)
print
(
"-- detaching from: "
,
fgraph
)
# noqa: T201
def
on_import
(
self
,
fgraph
,
node
,
reason
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
if
self
.
active
:
print
(
f
"-- importing: {node}, reason: {reason}"
)
print
(
f
"-- importing: {node}, reason: {reason}"
)
# noqa: T201
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
if
self
.
active
:
if
self
.
active
:
print
(
f
"-- pruning: {node}, reason: {reason}"
)
print
(
f
"-- pruning: {node}, reason: {reason}"
)
# noqa: T201
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
=
None
):
if
self
.
active
:
if
self
.
active
:
print
(
f
"-- changing ({node}.inputs[{i}]) from {r} to {new_r}"
)
print
(
f
"-- changing ({node}.inputs[{i}]) from {r} to {new_r}"
)
# noqa: T201
class
PreserveVariableAttributes
(
Feature
):
class
PreserveVariableAttributes
(
Feature
):
...
...
pytensor/graph/fg.py
浏览文件 @
17748b7d
...
@@ -491,7 +491,7 @@ class FunctionGraph(MetaObject):
...
@@ -491,7 +491,7 @@ class FunctionGraph(MetaObject):
if
verbose
is
None
:
if
verbose
is
None
:
verbose
=
config
.
optimizer_verbose
verbose
=
config
.
optimizer_verbose
if
verbose
:
if
verbose
:
print
(
print
(
# noqa: T201
f
"rewriting: rewrite {reason} replaces {var} of {var.owner} with {new_var} of {new_var.owner}"
f
"rewriting: rewrite {reason} replaces {var} of {var.owner} with {new_var} of {new_var.owner}"
)
)
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
17748b7d
...
@@ -1002,7 +1002,7 @@ class MetaNodeRewriter(NodeRewriter):
...
@@ -1002,7 +1002,7 @@ class MetaNodeRewriter(NodeRewriter):
# ensure we have data for all input variables that need it
# ensure we have data for all input variables that need it
if
missing
:
if
missing
:
if
self
.
verbose
>
0
:
if
self
.
verbose
>
0
:
print
(
print
(
# noqa: T201
f
"{self.__class__.__name__} cannot meta-rewrite {node}, "
f
"{self.__class__.__name__} cannot meta-rewrite {node}, "
f
"{len(missing)} of {int(node.nin)} input shapes unknown"
f
"{len(missing)} of {int(node.nin)} input shapes unknown"
)
)
...
@@ -1010,7 +1010,7 @@ class MetaNodeRewriter(NodeRewriter):
...
@@ -1010,7 +1010,7 @@ class MetaNodeRewriter(NodeRewriter):
# now we can apply the different rewrites in turn,
# now we can apply the different rewrites in turn,
# compile the resulting subgraphs and time their execution
# compile the resulting subgraphs and time their execution
if
self
.
verbose
>
1
:
if
self
.
verbose
>
1
:
print
(
print
(
# noqa: T201
f
"{self.__class__.__name__} meta-rewriting {node} ({len(self.get_rewrites(node))} choices):"
f
"{self.__class__.__name__} meta-rewriting {node} ({len(self.get_rewrites(node))} choices):"
)
)
timings
=
[]
timings
=
[]
...
@@ -1027,20 +1027,20 @@ class MetaNodeRewriter(NodeRewriter):
...
@@ -1027,20 +1027,20 @@ class MetaNodeRewriter(NodeRewriter):
continue
continue
except
Exception
as
e
:
except
Exception
as
e
:
if
self
.
verbose
>
0
:
if
self
.
verbose
>
0
:
print
(
f
"* {node_rewriter}: exception"
,
e
)
print
(
f
"* {node_rewriter}: exception"
,
e
)
# noqa: T201
continue
continue
else
:
else
:
if
self
.
verbose
>
1
:
if
self
.
verbose
>
1
:
print
(
f
"* {node_rewriter}: {timing:.5g} sec"
)
print
(
f
"* {node_rewriter}: {timing:.5g} sec"
)
# noqa: T201
timings
.
append
((
timing
,
outputs
,
node_rewriter
))
timings
.
append
((
timing
,
outputs
,
node_rewriter
))
else
:
else
:
if
self
.
verbose
>
0
:
if
self
.
verbose
>
0
:
print
(
f
"* {node_rewriter}: not applicable"
)
print
(
f
"* {node_rewriter}: not applicable"
)
# noqa: T201
# finally, we choose the fastest one
# finally, we choose the fastest one
if
timings
:
if
timings
:
timings
.
sort
()
timings
.
sort
()
if
self
.
verbose
>
1
:
if
self
.
verbose
>
1
:
print
(
f
"= {timings[0][2]}"
)
print
(
f
"= {timings[0][2]}"
)
# noqa: T201
return
timings
[
0
][
1
]
return
timings
[
0
][
1
]
return
return
...
@@ -1305,7 +1305,7 @@ class SequentialNodeRewriter(NodeRewriter):
...
@@ -1305,7 +1305,7 @@ class SequentialNodeRewriter(NodeRewriter):
new_vars
=
list
(
new_repl
.
values
())
new_vars
=
list
(
new_repl
.
values
())
if
config
.
optimizer_verbose
:
if
config
.
optimizer_verbose
:
print
(
print
(
# noqa: T201
f
"rewriting: rewrite {rewrite} replaces node {node} with {new_repl}"
f
"rewriting: rewrite {rewrite} replaces node {node} with {new_repl}"
)
)
...
@@ -2641,21 +2641,21 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
...
@@ -2641,21 +2641,21 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
try
:
try
:
o
.
print_profile
(
stream
,
prof
,
level
+
2
)
o
.
print_profile
(
stream
,
prof
,
level
+
2
)
except
NotImplementedError
:
except
NotImplementedError
:
print
(
blanc
,
"merge not implemented for "
,
o
)
print
(
blanc
,
"merge not implemented for "
,
o
)
# noqa: T201
for
o
,
prof
in
zip
(
for
o
,
prof
in
zip
(
rewrite
.
final_rewriters
,
final_sub_profs
[
i
],
strict
=
True
rewrite
.
final_rewriters
,
final_sub_profs
[
i
],
strict
=
True
):
):
try
:
try
:
o
.
print_profile
(
stream
,
prof
,
level
+
2
)
o
.
print_profile
(
stream
,
prof
,
level
+
2
)
except
NotImplementedError
:
except
NotImplementedError
:
print
(
blanc
,
"merge not implemented for "
,
o
)
print
(
blanc
,
"merge not implemented for "
,
o
)
# noqa: T201
for
o
,
prof
in
zip
(
for
o
,
prof
in
zip
(
rewrite
.
cleanup_rewriters
,
cleanup_sub_profs
[
i
],
strict
=
True
rewrite
.
cleanup_rewriters
,
cleanup_sub_profs
[
i
],
strict
=
True
):
):
try
:
try
:
o
.
print_profile
(
stream
,
prof
,
level
+
2
)
o
.
print_profile
(
stream
,
prof
,
level
+
2
)
except
NotImplementedError
:
except
NotImplementedError
:
print
(
blanc
,
"merge not implemented for "
,
o
)
print
(
blanc
,
"merge not implemented for "
,
o
)
# noqa: T201
@staticmethod
@staticmethod
def
merge_profile
(
prof1
,
prof2
):
def
merge_profile
(
prof1
,
prof2
):
...
...
pytensor/graph/utils.py
浏览文件 @
17748b7d
...
@@ -274,9 +274,9 @@ class Scratchpad:
...
@@ -274,9 +274,9 @@ class Scratchpad:
return
"scratchpad"
+
str
(
self
.
__dict__
)
return
"scratchpad"
+
str
(
self
.
__dict__
)
def
info
(
self
):
def
info
(
self
):
print
(
f
"<pytensor.graph.utils.scratchpad instance at {id(self)}>"
)
print
(
f
"<pytensor.graph.utils.scratchpad instance at {id(self)}>"
)
# noqa: T201
for
k
,
v
in
self
.
__dict__
.
items
():
for
k
,
v
in
self
.
__dict__
.
items
():
print
(
f
" {k}: {v}"
)
print
(
f
" {k}: {v}"
)
# noqa: T201
# These two methods have been added to help Mypy
# These two methods have been added to help Mypy
def
__getattribute__
(
self
,
name
):
def
__getattribute__
(
self
,
name
):
...
...
pytensor/link/c/basic.py
浏览文件 @
17748b7d
...
@@ -875,10 +875,10 @@ class CLinker(Linker):
...
@@ -875,10 +875,10 @@ class CLinker(Linker):
self
.
c_init_code_apply
=
c_init_code_apply
self
.
c_init_code_apply
=
c_init_code_apply
if
(
self
.
init_tasks
,
self
.
tasks
)
!=
self
.
get_init_tasks
():
if
(
self
.
init_tasks
,
self
.
tasks
)
!=
self
.
get_init_tasks
():
print
(
"init_tasks
\n
"
,
self
.
init_tasks
,
file
=
sys
.
stderr
)
print
(
"init_tasks
\n
"
,
self
.
init_tasks
,
file
=
sys
.
stderr
)
# noqa: T201
print
(
self
.
get_init_tasks
()[
0
],
file
=
sys
.
stderr
)
print
(
self
.
get_init_tasks
()[
0
],
file
=
sys
.
stderr
)
# noqa: T201
print
(
"tasks
\n
"
,
self
.
tasks
,
file
=
sys
.
stderr
)
print
(
"tasks
\n
"
,
self
.
tasks
,
file
=
sys
.
stderr
)
# noqa: T201
print
(
self
.
get_init_tasks
()[
1
],
file
=
sys
.
stderr
)
print
(
self
.
get_init_tasks
()[
1
],
file
=
sys
.
stderr
)
# noqa: T201
assert
(
self
.
init_tasks
,
self
.
tasks
)
==
self
.
get_init_tasks
()
assert
(
self
.
init_tasks
,
self
.
tasks
)
==
self
.
get_init_tasks
()
# List of indices that should be ignored when passing the arguments
# List of indices that should be ignored when passing the arguments
...
@@ -1756,7 +1756,7 @@ class _CThunk:
...
@@ -1756,7 +1756,7 @@ class _CThunk:
exc_value
=
exc_type
(
_exc_value
)
exc_value
=
exc_type
(
_exc_value
)
exc_value
.
__thunk_trace__
=
trace
exc_value
.
__thunk_trace__
=
trace
except
Exception
:
except
Exception
:
print
(
print
(
# noqa: T201
(
(
"ERROR retrieving error_storage."
"ERROR retrieving error_storage."
"Was the error set in the c code?"
"Was the error set in the c code?"
...
@@ -1764,7 +1764,7 @@ class _CThunk:
...
@@ -1764,7 +1764,7 @@ class _CThunk:
end
=
" "
,
end
=
" "
,
file
=
sys
.
stderr
,
file
=
sys
.
stderr
,
)
)
print
(
self
.
error_storage
,
file
=
sys
.
stderr
)
print
(
self
.
error_storage
,
file
=
sys
.
stderr
)
# noqa: T201
raise
raise
raise
exc_value
.
with_traceback
(
exc_trace
)
raise
exc_value
.
with_traceback
(
exc_trace
)
...
...
pytensor/link/c/op.py
浏览文件 @
17748b7d
...
@@ -79,7 +79,7 @@ class COp(Op, CLinkerOp):
...
@@ -79,7 +79,7 @@ class COp(Op, CLinkerOp):
# that don't implement c code. In those cases, we
# that don't implement c code. In those cases, we
# don't want to print a warning.
# don't want to print a warning.
cl
.
get_dynamic_module
()
cl
.
get_dynamic_module
()
print
(
f
"Disabling C code for {self} due to unsupported float16"
)
warnings
.
warn
(
f
"Disabling C code for {self} due to unsupported float16"
)
raise
NotImplementedError
(
"float16"
)
raise
NotImplementedError
(
"float16"
)
outputs
=
cl
.
make_thunk
(
outputs
=
cl
.
make_thunk
(
input_storage
=
node_input_storage
,
output_storage
=
node_output_storage
input_storage
=
node_input_storage
,
output_storage
=
node_output_storage
...
...
pytensor/printing.py
浏览文件 @
17748b7d
...
@@ -726,7 +726,7 @@ def _print_fn(op, xin):
...
@@ -726,7 +726,7 @@ def _print_fn(op, xin):
pmsg
=
temp
()
pmsg
=
temp
()
else
:
else
:
pmsg
=
temp
pmsg
=
temp
print
(
op
.
message
,
attr
,
"="
,
pmsg
)
print
(
op
.
message
,
attr
,
"="
,
pmsg
)
# noqa: T201
class
Print
(
Op
):
class
Print
(
Op
):
...
@@ -1657,7 +1657,7 @@ def pydotprint(
...
@@ -1657,7 +1657,7 @@ def pydotprint(
raise
raise
if
print_output_file
:
if
print_output_file
:
print
(
"The output file is available at"
,
outfile
)
print
(
"The output file is available at"
,
outfile
)
# noqa: T201
class
_TagGenerator
:
class
_TagGenerator
:
...
@@ -1824,8 +1824,7 @@ def var_descriptor(obj, _prev_obs: dict | None = None, _tag_generator=None) -> s
...
@@ -1824,8 +1824,7 @@ def var_descriptor(obj, _prev_obs: dict | None = None, _tag_generator=None) -> s
# The __str__ method is encoding the object's id in its str
# The __str__ method is encoding the object's id in its str
name
=
position_independent_str
(
obj
)
name
=
position_independent_str
(
obj
)
if
" at 0x"
in
name
:
if
" at 0x"
in
name
:
print
(
name
)
raise
AssertionError
(
name
)
raise
AssertionError
()
prefix
=
cur_tag
+
"="
prefix
=
cur_tag
+
"="
...
...
pytensor/tensor/basic.py
浏览文件 @
17748b7d
...
@@ -613,7 +613,6 @@ def get_scalar_constant_value(
...
@@ -613,7 +613,6 @@ def get_scalar_constant_value(
"""
"""
if
isinstance
(
v
,
TensorVariable
|
np
.
ndarray
):
if
isinstance
(
v
,
TensorVariable
|
np
.
ndarray
):
if
v
.
ndim
!=
0
:
if
v
.
ndim
!=
0
:
print
(
v
,
v
.
ndim
)
raise
NotScalarConstantError
(
"Input ndim != 0"
)
raise
NotScalarConstantError
(
"Input ndim != 0"
)
return
get_underlying_scalar_constant_value
(
return
get_underlying_scalar_constant_value
(
v
,
v
,
...
...
pytensor/tensor/nlinalg.py
浏览文件 @
17748b7d
...
@@ -216,9 +216,8 @@ class Det(Op):
...
@@ -216,9 +216,8 @@ class Det(Op):
(
z
,)
=
outputs
(
z
,)
=
outputs
try
:
try
:
z
[
0
]
=
np
.
asarray
(
np
.
linalg
.
det
(
x
),
dtype
=
x
.
dtype
)
z
[
0
]
=
np
.
asarray
(
np
.
linalg
.
det
(
x
),
dtype
=
x
.
dtype
)
except
Exception
:
except
Exception
as
e
:
print
(
"Failed to compute determinant"
,
x
)
raise
ValueError
(
"Failed to compute determinant"
,
x
)
from
e
raise
def
grad
(
self
,
inputs
,
g_outputs
):
def
grad
(
self
,
inputs
,
g_outputs
):
(
gz
,)
=
g_outputs
(
gz
,)
=
g_outputs
...
@@ -256,9 +255,8 @@ class SLogDet(Op):
...
@@ -256,9 +255,8 @@ class SLogDet(Op):
(
sign
,
det
)
=
outputs
(
sign
,
det
)
=
outputs
try
:
try
:
sign
[
0
],
det
[
0
]
=
(
np
.
array
(
z
,
dtype
=
x
.
dtype
)
for
z
in
np
.
linalg
.
slogdet
(
x
))
sign
[
0
],
det
[
0
]
=
(
np
.
array
(
z
,
dtype
=
x
.
dtype
)
for
z
in
np
.
linalg
.
slogdet
(
x
))
except
Exception
:
except
Exception
as
e
:
print
(
"Failed to compute determinant"
,
x
)
raise
ValueError
(
"Failed to compute determinant"
,
x
)
from
e
raise
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
return
[(),
()]
return
[(),
()]
...
...
pytensor/tensor/rewriting/blas.py
浏览文件 @
17748b7d
...
@@ -573,7 +573,7 @@ class GemmOptimizer(GraphRewriter):
...
@@ -573,7 +573,7 @@ class GemmOptimizer(GraphRewriter):
print
(
blanc
,
" callbacks_time"
,
file
=
stream
)
print
(
blanc
,
" callbacks_time"
,
file
=
stream
)
for
i
in
sorted
(
prof
[
12
]
.
items
(),
key
=
lambda
a
:
a
[
1
]):
for
i
in
sorted
(
prof
[
12
]
.
items
(),
key
=
lambda
a
:
a
[
1
]):
if
i
[
1
]
>
0
:
if
i
[
1
]
>
0
:
print
(
i
)
print
(
i
)
# noqa: T201
@node_rewriter
([
Dot
])
@node_rewriter
([
Dot
])
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
17748b7d
...
@@ -314,14 +314,14 @@ class InplaceElemwiseOptimizer(GraphRewriter):
...
@@ -314,14 +314,14 @@ class InplaceElemwiseOptimizer(GraphRewriter):
except
(
ValueError
,
InconsistencyError
)
as
e
:
except
(
ValueError
,
InconsistencyError
)
as
e
:
prof
[
"nb_inconsistent"
]
+=
1
prof
[
"nb_inconsistent"
]
+=
1
if
check_each_change
!=
1
and
not
raised_warning
:
if
check_each_change
!=
1
and
not
raised_warning
:
print
(
print
(
# noqa: T201
(
(
"Some inplace rewriting was not "
"Some inplace rewriting was not "
"performed due to an unexpected error:"
"performed due to an unexpected error:"
),
),
file
=
sys
.
stderr
,
file
=
sys
.
stderr
,
)
)
print
(
e
,
file
=
sys
.
stderr
)
print
(
e
,
file
=
sys
.
stderr
)
# noqa: T201
raised_warning
=
True
raised_warning
=
True
fgraph
.
revert
(
chk
)
fgraph
.
revert
(
chk
)
continue
continue
...
@@ -335,7 +335,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
...
@@ -335,7 +335,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
fgraph
.
validate
()
fgraph
.
validate
()
except
Exception
:
except
Exception
:
if
not
raised_warning
:
if
not
raised_warning
:
print
(
print
(
# noqa: T201
(
(
"Some inplace rewriting was not "
"Some inplace rewriting was not "
"performed due to an unexpected error"
"performed due to an unexpected error"
...
@@ -1080,7 +1080,7 @@ class FusionOptimizer(GraphRewriter):
...
@@ -1080,7 +1080,7 @@ class FusionOptimizer(GraphRewriter):
print
(
blanc
,
" callbacks_time"
,
file
=
stream
)
print
(
blanc
,
" callbacks_time"
,
file
=
stream
)
for
i
in
sorted
(
prof
[
6
]
.
items
(),
key
=
lambda
a
:
a
[
1
])[::
-
1
]:
for
i
in
sorted
(
prof
[
6
]
.
items
(),
key
=
lambda
a
:
a
[
1
])[::
-
1
]:
if
i
[
1
]
>
0
:
if
i
[
1
]
>
0
:
print
(
blanc
,
" "
,
i
)
print
(
blanc
,
" "
,
i
)
# noqa: T201
print
(
blanc
,
" time_toposort"
,
prof
[
7
],
file
=
stream
)
print
(
blanc
,
" time_toposort"
,
prof
[
7
],
file
=
stream
)
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
17748b7d
...
@@ -3434,14 +3434,14 @@ def perform_sigm_times_exp(
...
@@ -3434,14 +3434,14 @@ def perform_sigm_times_exp(
sigm_minus_x
=
[]
sigm_minus_x
=
[]
if
full_tree
is
None
:
if
full_tree
is
None
:
full_tree
=
tree
full_tree
=
tree
if
False
:
# Debug code.
#
if False: # Debug code.
print
(
"<perform_sigm_times_exp>"
)
#
print("<perform_sigm_times_exp>")
print
(
f
" full_tree = {full_tree}"
)
#
print(f" full_tree = {full_tree}")
print
(
f
" tree = {tree}"
)
#
print(f" tree = {tree}")
print
(
f
" exp_x = {exp_x}"
)
#
print(f" exp_x = {exp_x}")
print
(
f
" exp_minus_x = {exp_minus_x}"
)
#
print(f" exp_minus_x = {exp_minus_x}")
print
(
f
" sigm_x = {sigm_x}"
)
#
print(f" sigm_x = {sigm_x}")
print
(
f
" sigm_minus_x= {sigm_minus_x}"
)
#
print(f" sigm_minus_x= {sigm_minus_x}")
neg
,
inputs
=
tree
neg
,
inputs
=
tree
if
isinstance
(
inputs
,
list
):
if
isinstance
(
inputs
,
list
):
# Recurse through inputs of the multiplication.
# Recurse through inputs of the multiplication.
...
...
scripts/slowest_tests/extract-slow-tests.py
浏览文件 @
17748b7d
...
@@ -72,7 +72,7 @@ def main(read_lines):
...
@@ -72,7 +72,7 @@ def main(read_lines):
lines
=
read_lines
()
lines
=
read_lines
()
times
=
extract_lines
(
lines
)
times
=
extract_lines
(
lines
)
parsed_times
=
format_times
(
times
)
parsed_times
=
format_times
(
times
)
print
(
"
\n
"
.
join
(
parsed_times
))
print
(
"
\n
"
.
join
(
parsed_times
))
# noqa: T201
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tests/d3viz/test_d3viz.py
浏览文件 @
17748b7d
...
@@ -28,7 +28,7 @@ class TestD3Viz:
...
@@ -28,7 +28,7 @@ class TestD3Viz:
tmp_dir
=
Path
(
tempfile
.
mkdtemp
())
tmp_dir
=
Path
(
tempfile
.
mkdtemp
())
html_file
=
tmp_dir
/
"index.html"
html_file
=
tmp_dir
/
"index.html"
if
verbose
:
if
verbose
:
print
(
html_file
)
print
(
html_file
)
# noqa: T201
d3v
.
d3viz
(
f
,
html_file
)
d3v
.
d3viz
(
f
,
html_file
)
assert
html_file
.
stat
()
.
st_size
>
0
assert
html_file
.
stat
()
.
st_size
>
0
if
reference
:
if
reference
:
...
...
tests/link/c/test_cmodule.py
浏览文件 @
17748b7d
...
@@ -258,7 +258,6 @@ def test_default_blas_ldflags(
...
@@ -258,7 +258,6 @@ def test_default_blas_ldflags(
def
patched_compile_tmp
(
*
args
,
**
kwargs
):
def
patched_compile_tmp
(
*
args
,
**
kwargs
):
def
wrapped
(
test_code
,
tmp_prefix
,
flags
,
try_run
,
output
):
def
wrapped
(
test_code
,
tmp_prefix
,
flags
,
try_run
,
output
):
if
len
(
flags
)
>=
2
and
flags
[:
2
]
==
[
"-framework"
,
"Accelerate"
]:
if
len
(
flags
)
>=
2
and
flags
[:
2
]
==
[
"-framework"
,
"Accelerate"
]:
print
(
enabled_accelerate_framework
)
if
enabled_accelerate_framework
:
if
enabled_accelerate_framework
:
return
(
True
,
True
)
return
(
True
,
True
)
else
:
else
:
...
...
tests/link/numba/test_basic.py
浏览文件 @
17748b7d
...
@@ -836,7 +836,6 @@ def test_config_options_fastmath():
...
@@ -836,7 +836,6 @@ def test_config_options_fastmath():
with
config
.
change_flags
(
numba__fastmath
=
True
):
with
config
.
change_flags
(
numba__fastmath
=
True
):
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
pytensor_numba_fn
=
function
([
x
],
pt
.
sum
(
x
),
mode
=
numba_mode
)
print
(
list
(
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
))
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
numba_mul_fn
=
pytensor_numba_fn
.
vm
.
jit_fn
.
py_func
.
__globals__
[
"impl_sum"
]
assert
numba_mul_fn
.
targetoptions
[
"fastmath"
]
==
{
assert
numba_mul_fn
.
targetoptions
[
"fastmath"
]
==
{
"afn"
,
"afn"
,
...
...
tests/link/test_vm.py
浏览文件 @
17748b7d
import
time
from
collections
import
Counter
from
collections
import
Counter
import
numpy
as
np
import
numpy
as
np
...
@@ -108,23 +107,25 @@ def test_speed():
...
@@ -108,23 +107,25 @@ def test_speed():
return
z
return
z
def
time_numpy
():
def
time_numpy
():
# TODO: Make this a benchmark test
steps_a
=
5
steps_a
=
5
steps_b
=
100
steps_b
=
100
x
=
np
.
asarray
([
2.0
,
3.0
],
dtype
=
config
.
floatX
)
x
=
np
.
asarray
([
2.0
,
3.0
],
dtype
=
config
.
floatX
)
numpy_version
(
x
,
steps_a
)
numpy_version
(
x
,
steps_a
)
t0
=
time
.
perf_counter
()
#
t0 = time.perf_counter()
# print
numpy_version(x, steps_a)
numpy_version
(
x
,
steps_a
)
t1
=
time
.
perf_counter
()
#
t1 = time.perf_counter()
t2
=
time
.
perf_counter
()
#
t2 = time.perf_counter()
# print
numpy_version(x, steps_b)
numpy_version
(
x
,
steps_b
)
t3
=
time
.
perf_counter
()
#
t3 = time.perf_counter()
t_a
=
t1
-
t0
#
t_a = t1 - t0
t_b
=
t3
-
t2
#
t_b = t3 - t2
print
(
f
"numpy takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop"
)
#
print(f"numpy takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop")
def
time_linker
(
name
,
linker
):
def
time_linker
(
name
,
linker
):
# TODO: Make this a benchmark test
steps_a
=
5
steps_a
=
5
steps_b
=
100
steps_b
=
100
x
=
vector
()
x
=
vector
()
...
@@ -135,20 +136,20 @@ def test_speed():
...
@@ -135,20 +136,20 @@ def test_speed():
f_b
=
function
([
x
],
b
,
mode
=
Mode
(
optimizer
=
None
,
linker
=
linker
()))
f_b
=
function
([
x
],
b
,
mode
=
Mode
(
optimizer
=
None
,
linker
=
linker
()))
f_a
([
2.0
,
3.0
])
f_a
([
2.0
,
3.0
])
t0
=
time
.
perf_counter
()
#
t0 = time.perf_counter()
f_a
([
2.0
,
3.0
])
f_a
([
2.0
,
3.0
])
t1
=
time
.
perf_counter
()
#
t1 = time.perf_counter()
f_b
([
2.0
,
3.0
])
f_b
([
2.0
,
3.0
])
t2
=
time
.
perf_counter
()
#
t2 = time.perf_counter()
f_b
([
2.0
,
3.0
])
f_b
([
2.0
,
3.0
])
t3
=
time
.
perf_counter
()
#
t3 = time.perf_counter()
t_a
=
t1
-
t0
#
t_a = t1 - t0
t_b
=
t3
-
t2
#
t_b = t3 - t2
print
(
f
"{name} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop"
)
#
print(f"{name} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop")
time_linker
(
"c|py"
,
OpWiseCLinker
)
time_linker
(
"c|py"
,
OpWiseCLinker
)
time_linker
(
"vmLinker"
,
VMLinker
)
time_linker
(
"vmLinker"
,
VMLinker
)
...
@@ -167,7 +168,7 @@ def test_speed():
...
@@ -167,7 +168,7 @@ def test_speed():
],
],
)
)
def
test_speed_lazy
(
linker
):
def
test_speed_lazy
(
linker
):
# TODO FIXME: This isn't a real test.
# TODO FIXME: This isn't a real test.
Make this a benchmark test
def
build_graph
(
x
,
depth
=
5
):
def
build_graph
(
x
,
depth
=
5
):
z
=
x
z
=
x
...
@@ -185,20 +186,20 @@ def test_speed_lazy(linker):
...
@@ -185,20 +186,20 @@ def test_speed_lazy(linker):
f_b
=
function
([
x
],
b
,
mode
=
Mode
(
optimizer
=
None
,
linker
=
linker
))
f_b
=
function
([
x
],
b
,
mode
=
Mode
(
optimizer
=
None
,
linker
=
linker
))
f_a
([
2.0
])
f_a
([
2.0
])
t0
=
time
.
perf_counter
()
#
t0 = time.perf_counter()
f_a
([
2.0
])
f_a
([
2.0
])
t1
=
time
.
perf_counter
()
#
t1 = time.perf_counter()
f_b
([
2.0
])
f_b
([
2.0
])
t2
=
time
.
perf_counter
()
#
t2 = time.perf_counter()
f_b
([
2.0
])
f_b
([
2.0
])
t3
=
time
.
perf_counter
()
#
t3 = time.perf_counter()
t_a
=
t1
-
t0
#
t_a = t1 - t0
t_b
=
t3
-
t2
#
t_b = t3 - t2
print
(
f
"{linker} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop"
)
#
print(f"{linker} takes {1000 * (t_b - t_a) / (steps_b - steps_a):f} s/Kop")
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
...
tests/scan/test_basic.py
浏览文件 @
17748b7d
...
@@ -12,7 +12,6 @@ Questions and notes about scan that should be answered :
...
@@ -12,7 +12,6 @@ Questions and notes about scan that should be answered :
import
os
import
os
import
pickle
import
pickle
import
shutil
import
shutil
import
sys
from
pathlib
import
Path
from
pathlib
import
Path
from
tempfile
import
mkdtemp
from
tempfile
import
mkdtemp
...
@@ -3076,7 +3075,7 @@ class TestExamples:
...
@@ -3076,7 +3075,7 @@ class TestExamples:
cost
=
result_outer
[
0
][
-
1
]
cost
=
result_outer
[
0
][
-
1
]
H
=
hessian
(
cost
,
W
)
H
=
hessian
(
cost
,
W
)
print
(
"."
,
file
=
sys
.
stderr
)
#
print(".", file=sys.stderr)
f
=
function
([
W
,
n_steps
],
H
)
f
=
function
([
W
,
n_steps
],
H
)
benchmark
(
f
,
np
.
ones
((
8
,),
dtype
=
"float32"
),
1
)
benchmark
(
f
,
np
.
ones
((
8
,),
dtype
=
"float32"
),
1
)
...
...
tests/tensor/rewriting/test_math.py
浏览文件 @
17748b7d
...
@@ -1628,6 +1628,7 @@ def test_local_mul_specialize():
...
@@ -1628,6 +1628,7 @@ def test_local_mul_specialize():
def
speed_local_pow_specialize_range
():
def
speed_local_pow_specialize_range
():
# TODO: This should be a benchmark test
val
=
np
.
random
.
random
(
1e7
)
val
=
np
.
random
.
random
(
1e7
)
v
=
vector
()
v
=
vector
()
mode
=
get_default_mode
()
mode
=
get_default_mode
()
...
@@ -1641,9 +1642,9 @@ def speed_local_pow_specialize_range():
...
@@ -1641,9 +1642,9 @@ def speed_local_pow_specialize_range():
t2
=
time
.
perf_counter
()
t2
=
time
.
perf_counter
()
f2
(
val
)
f2
(
val
)
t3
=
time
.
perf_counter
()
t3
=
time
.
perf_counter
()
print
(
i
,
t2
-
t1
,
t3
-
t2
,
t2
-
t1
<
t3
-
t2
)
#
print(i, t2 - t1, t3 - t2, t2 - t1 < t3 - t2)
if
not
t2
-
t1
<
t3
-
t2
:
if
not
t2
-
t1
<
t3
-
t2
:
print
(
"WARNING WE ARE SLOWER"
)
raise
ValueError
(
"WARNING WE ARE SLOWER"
)
for
i
in
range
(
-
3
,
-
1500
,
-
1
):
for
i
in
range
(
-
3
,
-
1500
,
-
1
):
f1
=
function
([
v
],
v
**
i
,
mode
=
mode
)
f1
=
function
([
v
],
v
**
i
,
mode
=
mode
)
f2
=
function
([
v
],
v
**
i
,
mode
=
mode_without_pow_rewrite
)
f2
=
function
([
v
],
v
**
i
,
mode
=
mode_without_pow_rewrite
)
...
@@ -1653,9 +1654,9 @@ def speed_local_pow_specialize_range():
...
@@ -1653,9 +1654,9 @@ def speed_local_pow_specialize_range():
t2
=
time
.
perf_counter
()
t2
=
time
.
perf_counter
()
f2
(
val
)
f2
(
val
)
t3
=
time
.
perf_counter
()
t3
=
time
.
perf_counter
()
print
(
i
,
t2
-
t1
,
t3
-
t2
,
t2
-
t1
<
t3
-
t2
)
#
print(i, t2 - t1, t3 - t2, t2 - t1 < t3 - t2)
if
not
t2
-
t1
<
t3
-
t2
:
if
not
t2
-
t1
<
t3
-
t2
:
print
(
"WARNING WE ARE SLOWER"
)
raise
ValueError
(
"WARNING WE ARE SLOWER"
)
def
test_local_pow_specialize
():
def
test_local_pow_specialize
():
...
@@ -2483,19 +2484,20 @@ class TestLocalErfc:
...
@@ -2483,19 +2484,20 @@ class TestLocalErfc:
assert
f
.
maker
.
fgraph
.
outputs
[
0
]
.
dtype
==
config
.
floatX
assert
f
.
maker
.
fgraph
.
outputs
[
0
]
.
dtype
==
config
.
floatX
def
speed_local_log_erfc
(
self
):
def
speed_local_log_erfc
(
self
):
# TODO: Make this a benchmark test!
val
=
np
.
random
.
random
(
1e6
)
val
=
np
.
random
.
random
(
1e6
)
x
=
vector
()
x
=
vector
()
mode
=
get_mode
(
"FAST_RUN"
)
mode
=
get_mode
(
"FAST_RUN"
)
f1
=
function
([
x
],
log
(
erfc
(
x
)),
mode
=
mode
.
excluding
(
"local_log_erfc"
))
f1
=
function
([
x
],
log
(
erfc
(
x
)),
mode
=
mode
.
excluding
(
"local_log_erfc"
))
f2
=
function
([
x
],
log
(
erfc
(
x
)),
mode
=
mode
)
f2
=
function
([
x
],
log
(
erfc
(
x
)),
mode
=
mode
)
print
(
f1
.
maker
.
fgraph
.
toposort
())
#
print(f1.maker.fgraph.toposort())
print
(
f2
.
maker
.
fgraph
.
toposort
())
#
print(f2.maker.fgraph.toposort())
t0
=
time
.
perf_counter
()
#
t0 = time.perf_counter()
f1
(
val
)
f1
(
val
)
t1
=
time
.
perf_counter
()
#
t1 = time.perf_counter()
f2
(
val
)
f2
(
val
)
t2
=
time
.
perf_counter
()
#
t2 = time.perf_counter()
print
(
t1
-
t0
,
t2
-
t1
)
#
print(t1 - t0, t2 - t1)
class
TestLocalMergeSwitchSameCond
:
class
TestLocalMergeSwitchSameCond
:
...
@@ -4144,13 +4146,13 @@ class TestSigmoidRewrites:
...
@@ -4144,13 +4146,13 @@ class TestSigmoidRewrites:
perform_sigm_times_exp
(
trees
[
0
])
perform_sigm_times_exp
(
trees
[
0
])
trees
[
0
]
=
simplify_mul
(
trees
[
0
])
trees
[
0
]
=
simplify_mul
(
trees
[
0
])
good
=
is_same_graph
(
compute_mul
(
trees
[
0
]),
compute_mul
(
trees
[
1
]))
good
=
is_same_graph
(
compute_mul
(
trees
[
0
]),
compute_mul
(
trees
[
1
]))
if
not
good
:
#
if not good:
print
(
trees
[
0
])
#
print(trees[0])
print
(
trees
[
1
])
#
print(trees[1])
print
(
"***"
)
#
print("***")
pytensor
.
printing
.
debugprint
(
compute_mul
(
trees
[
0
]))
#
pytensor.printing.debugprint(compute_mul(trees[0]))
print
(
"***"
)
#
print("***")
pytensor
.
printing
.
debugprint
(
compute_mul
(
trees
[
1
]))
#
pytensor.printing.debugprint(compute_mul(trees[1]))
assert
good
assert
good
check
(
sigmoid
(
x
)
*
exp_op
(
-
x
),
sigmoid
(
-
x
))
check
(
sigmoid
(
x
)
*
exp_op
(
-
x
),
sigmoid
(
-
x
))
...
...
tests/tensor/test_complex.py
浏览文件 @
17748b7d
...
@@ -73,9 +73,7 @@ class TestRealImag:
...
@@ -73,9 +73,7 @@ class TestRealImag:
try
:
try
:
utt
.
verify_grad
(
f
,
[
aval
])
utt
.
verify_grad
(
f
,
[
aval
])
except
GradientError
as
e
:
except
GradientError
as
e
:
print
(
e
.
num_grad
.
gf
)
raise
ValueError
(
f
"Failed: {e.num_grad.gf=} {e.analytic_grad=}"
)
from
e
print
(
e
.
analytic_grad
)
raise
@pytest.mark.skip
(
reason
=
"Complex grads not enabled, see #178"
)
@pytest.mark.skip
(
reason
=
"Complex grads not enabled, see #178"
)
def
test_mul_mixed1
(
self
):
def
test_mul_mixed1
(
self
):
...
@@ -88,9 +86,7 @@ class TestRealImag:
...
@@ -88,9 +86,7 @@ class TestRealImag:
try
:
try
:
utt
.
verify_grad
(
f
,
[
aval
])
utt
.
verify_grad
(
f
,
[
aval
])
except
GradientError
as
e
:
except
GradientError
as
e
:
print
(
e
.
num_grad
.
gf
)
raise
ValueError
(
f
"Failed: {e.num_grad.gf=} {e.analytic_grad=}"
)
from
e
print
(
e
.
analytic_grad
)
raise
@pytest.mark.skip
(
reason
=
"Complex grads not enabled, see #178"
)
@pytest.mark.skip
(
reason
=
"Complex grads not enabled, see #178"
)
def
test_mul_mixed
(
self
):
def
test_mul_mixed
(
self
):
...
@@ -104,9 +100,7 @@ class TestRealImag:
...
@@ -104,9 +100,7 @@ class TestRealImag:
try
:
try
:
utt
.
verify_grad
(
f
,
[
aval
,
bval
])
utt
.
verify_grad
(
f
,
[
aval
,
bval
])
except
GradientError
as
e
:
except
GradientError
as
e
:
print
(
e
.
num_grad
.
gf
)
raise
ValueError
(
f
"Failed: {e.num_grad.gf=} {e.analytic_grad=}"
)
from
e
print
(
e
.
analytic_grad
)
raise
@pytest.mark.skip
(
reason
=
"Complex grads not enabled, see #178"
)
@pytest.mark.skip
(
reason
=
"Complex grads not enabled, see #178"
)
def
test_polar_grads
(
self
):
def
test_polar_grads
(
self
):
...
...
tests/tensor/test_fft.py
浏览文件 @
17748b7d
...
@@ -43,7 +43,6 @@ class TestFFT:
...
@@ -43,7 +43,6 @@ class TestFFT:
utt
.
assert_allclose
(
rfft_ref
,
res_rfft_comp
)
utt
.
assert_allclose
(
rfft_ref
,
res_rfft_comp
)
m
=
rfft
.
type
()
m
=
rfft
.
type
()
print
(
m
.
ndim
)
irfft
=
fft
.
irfft
(
m
)
irfft
=
fft
.
irfft
(
m
)
f_irfft
=
pytensor
.
function
([
m
],
irfft
)
f_irfft
=
pytensor
.
function
([
m
],
irfft
)
res_irfft
=
f_irfft
(
res_rfft
)
res_irfft
=
f_irfft
(
res_rfft
)
...
...
tests/tensor/test_shape.py
浏览文件 @
17748b7d
...
@@ -797,7 +797,6 @@ class TestVectorize:
...
@@ -797,7 +797,6 @@ class TestVectorize:
assert
equal_computations
([
vect_out
],
[
reshape
(
mat
,
new_shape
)])
assert
equal_computations
([
vect_out
],
[
reshape
(
mat
,
new_shape
)])
new_shape
=
stack
([[
-
1
,
x
],
[
x
-
1
,
-
1
]],
axis
=
0
)
new_shape
=
stack
([[
-
1
,
x
],
[
x
-
1
,
-
1
]],
axis
=
0
)
print
(
new_shape
.
type
)
[
vect_out
]
=
vectorize_node
(
node
,
vec
,
new_shape
)
.
outputs
[
vect_out
]
=
vectorize_node
(
node
,
vec
,
new_shape
)
.
outputs
vec_test_value
=
np
.
arange
(
6
)
vec_test_value
=
np
.
arange
(
6
)
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
...
...
tests/test_config.py
浏览文件 @
17748b7d
...
@@ -192,7 +192,7 @@ def test_invalid_configvar_access():
...
@@ -192,7 +192,7 @@ def test_invalid_configvar_access():
# But we can make sure that nothing crazy happens when we access it:
# But we can make sure that nothing crazy happens when we access it:
with
pytest
.
raises
(
configparser
.
ConfigAccessViolation
,
match
=
"different instance"
):
with
pytest
.
raises
(
configparser
.
ConfigAccessViolation
,
match
=
"different instance"
):
print
(
root
.
test__on_test_instance
)
assert
root
.
test__on_test_instance
is
not
None
def
test_no_more_dotting
():
def
test_no_more_dotting
():
...
...
tests/test_printing.py
浏览文件 @
17748b7d
...
@@ -138,9 +138,9 @@ def test_min_informative_str():
...
@@ -138,9 +138,9 @@ def test_min_informative_str():
D. D
D. D
E. E"""
E. E"""
if
mis
!=
reference
:
#
if mis != reference:
print
(
"--"
+
mis
+
"--"
)
#
print("--" + mis + "--")
print
(
"--"
+
reference
+
"--"
)
#
print("--" + reference + "--")
assert
mis
==
reference
assert
mis
==
reference
...
...
tests/unittest_tools.py
浏览文件 @
17748b7d
import
logging
import
logging
import
sys
import
sys
import
warnings
from
copy
import
copy
,
deepcopy
from
copy
import
copy
,
deepcopy
from
functools
import
wraps
from
functools
import
wraps
...
@@ -41,12 +42,9 @@ def fetch_seed(pseed=None):
...
@@ -41,12 +42,9 @@ def fetch_seed(pseed=None):
else
:
else
:
seed
=
None
seed
=
None
except
ValueError
:
except
ValueError
:
print
(
warnings
.
warn
(
(
"Error: config.unittests__rseed contains "
"Error: config.unittests__rseed contains "
"invalid seed, using None instead"
"invalid seed, using None instead"
),
file
=
sys
.
stderr
,
)
)
seed
=
None
seed
=
None
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论