Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6a295b9f
提交
6a295b9f
authored
6月 18, 2024
作者:
Virgile Andreani
提交者:
Virgile Andreani
7月 09, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove unused stuff and type pytensor/utils.py
上级
b083fb91
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
36 行增加
和
53 行删除
+36
-53
cmodule.py
pytensor/link/c/cmodule.py
+3
-3
blas.py
pytensor/tensor/blas.py
+6
-5
utils.py
pytensor/utils.py
+16
-44
test_types.py
tests/compile/function/test_types.py
+11
-1
没有找到文件。
pytensor/link/c/cmodule.py
浏览文件 @
6a295b9f
...
@@ -426,7 +426,7 @@ def is_same_entry(entry_1, entry_2):
...
@@ -426,7 +426,7 @@ def is_same_entry(entry_1, entry_2):
return
False
return
False
def
get_module_hash
(
src_code
,
key
)
:
def
get_module_hash
(
src_code
:
str
,
key
)
->
str
:
"""
"""
Return a SHA256 hash that uniquely identifies a module.
Return a SHA256 hash that uniquely identifies a module.
...
@@ -466,13 +466,13 @@ def get_module_hash(src_code, key):
...
@@ -466,13 +466,13 @@ def get_module_hash(src_code, key):
if
isinstance
(
key_element
,
tuple
):
if
isinstance
(
key_element
,
tuple
):
# This should be the C++ compilation command line parameters or the
# This should be the C++ compilation command line parameters or the
# libraries to link against.
# libraries to link against.
to_hash
+=
list
(
key_element
)
to_hash
+=
[
str
(
e
)
for
e
in
key_element
]
elif
isinstance
(
key_element
,
str
):
elif
isinstance
(
key_element
,
str
):
if
key_element
.
startswith
(
"md5:"
)
or
key_element
.
startswith
(
"hash:"
):
if
key_element
.
startswith
(
"md5:"
)
or
key_element
.
startswith
(
"hash:"
):
# This is actually a sha256 hash of the config options.
# This is actually a sha256 hash of the config options.
# Currently, we still keep md5 to don't break old PyTensor.
# Currently, we still keep md5 to don't break old PyTensor.
# We add 'hash:' so that when we change it in
# We add 'hash:' so that when we change it in
# the futur, it won't break this version of PyTensor.
# the futur
e
, it won't break this version of PyTensor.
break
break
elif
key_element
.
startswith
(
"NPY_ABI_VERSION=0x"
)
or
key_element
.
startswith
(
elif
key_element
.
startswith
(
"NPY_ABI_VERSION=0x"
)
or
key_element
.
startswith
(
"c_compiler_str="
"c_compiler_str="
...
...
pytensor/tensor/blas.py
浏览文件 @
6a295b9f
...
@@ -75,6 +75,7 @@ Optimizations associated with these BLAS Ops are in tensor.rewriting.blas
...
@@ -75,6 +75,7 @@ Optimizations associated with these BLAS Ops are in tensor.rewriting.blas
"""
"""
import
functools
import
logging
import
logging
import
os
import
os
import
time
import
time
...
@@ -104,7 +105,6 @@ from pytensor.tensor.elemwise import DimShuffle
...
@@ -104,7 +105,6 @@ from pytensor.tensor.elemwise import DimShuffle
from
pytensor.tensor.math
import
add
,
mul
,
neg
,
sub
from
pytensor.tensor.math
import
add
,
mul
,
neg
,
sub
from
pytensor.tensor.shape
import
shape_padright
,
specify_broadcastable
from
pytensor.tensor.shape
import
shape_padright
,
specify_broadcastable
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
,
integer_dtypes
,
tensor
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
,
integer_dtypes
,
tensor
from
pytensor.utils
import
memoize
_logger
=
logging
.
getLogger
(
"pytensor.tensor.blas"
)
_logger
=
logging
.
getLogger
(
"pytensor.tensor.blas"
)
...
@@ -365,8 +365,10 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
...
@@ -365,8 +365,10 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
)
)
@memoize
@functools.cache
def
_ldflags
(
ldflags_str
,
libs
,
flags
,
libs_dir
,
include_dir
):
def
_ldflags
(
ldflags_str
:
str
,
libs
:
bool
,
flags
:
bool
,
libs_dir
:
bool
,
include_dir
:
bool
)
->
list
[
str
]:
"""Extract list of compilation flags from a string.
"""Extract list of compilation flags from a string.
Depending on the options, different type of flags will be kept.
Depending on the options, different type of flags will be kept.
...
@@ -422,7 +424,7 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
...
@@ -422,7 +424,7 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
t
=
t
[
1
:
-
1
]
t
=
t
[
1
:
-
1
]
try
:
try
:
t0
,
t1
,
t2
=
t
[
0
:
3
]
t0
,
t1
=
t
[
0
],
t
[
1
]
assert
t0
==
"-"
assert
t0
==
"-"
except
Exception
:
except
Exception
:
raise
ValueError
(
f
'invalid token "{t}" in ldflags_str: "{ldflags_str}"'
)
raise
ValueError
(
f
'invalid token "{t}" in ldflags_str: "{ldflags_str}"'
)
...
@@ -435,7 +437,6 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
...
@@ -435,7 +437,6 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
" is not wanted."
,
" is not wanted."
,
t
,
t
,
)
)
rval
.
append
(
t
[
2
:])
elif
libs
and
t1
==
"l"
:
# example -lmkl
elif
libs
and
t1
==
"l"
:
# example -lmkl
rval
.
append
(
t
[
2
:])
rval
.
append
(
t
[
2
:])
elif
flags
and
t1
not
in
(
"L"
,
"I"
,
"l"
):
# example -openmp
elif
flags
and
t1
not
in
(
"L"
,
"I"
,
"l"
):
# example -openmp
...
...
pytensor/utils.py
浏览文件 @
6a295b9f
...
@@ -6,7 +6,9 @@ import os
...
@@ -6,7 +6,9 @@ import os
import
struct
import
struct
import
subprocess
import
subprocess
import
sys
import
sys
from
collections.abc
import
Iterable
,
Sequence
from
functools
import
partial
from
functools
import
partial
from
pathlib
import
Path
__all__
=
[
__all__
=
[
...
@@ -85,18 +87,6 @@ def add_excepthook(hook):
...
@@ -85,18 +87,6 @@ def add_excepthook(hook):
sys
.
excepthook
=
__call_excepthooks
sys
.
excepthook
=
__call_excepthooks
def
exc_message
(
e
):
"""
In python 3.x, when an exception is reraised it saves original
exception in its args, therefore in order to find the actual
message, we need to unpack arguments recursively.
"""
msg
=
e
.
args
[
0
]
if
isinstance
(
msg
,
Exception
):
return
exc_message
(
msg
)
return
msg
def
get_unbound_function
(
unbound
):
def
get_unbound_function
(
unbound
):
# Op.make_thunk isn't bound, so don't have a __func__ attr.
# Op.make_thunk isn't bound, so don't have a __func__ attr.
# But bound method, have a __func__ method that point to the
# But bound method, have a __func__ method that point to the
...
@@ -106,8 +96,9 @@ def get_unbound_function(unbound):
...
@@ -106,8 +96,9 @@ def get_unbound_function(unbound):
return
unbound
return
unbound
def
maybe_add_to_os_environ_pathlist
(
var
,
newpath
):
def
maybe_add_to_os_environ_pathlist
(
var
:
str
,
newpath
:
Path
|
str
)
->
None
:
"""Unfortunately, Conda offers to make itself the default Python
"""
Unfortunately, Conda offers to make itself the default Python
and those who use it that way will probably not activate envs
and those who use it that way will probably not activate envs
correctly meaning e.g. mingw-w64 g++ may not be on their PATH.
correctly meaning e.g. mingw-w64 g++ may not be on their PATH.
...
@@ -118,15 +109,15 @@ def maybe_add_to_os_environ_pathlist(var, newpath):
...
@@ -118,15 +109,15 @@ def maybe_add_to_os_environ_pathlist(var, newpath):
The reason we check first is because Windows environment vars
The reason we check first is because Windows environment vars
are limited to 8191 characters and it is easy to hit that.
are limited to 8191 characters and it is easy to hit that.
`var` will typically be 'PATH'."""
`var` will typically be 'PATH'.
"""
import
os
if
not
Path
(
newpath
)
.
is_absolute
():
return
if
os
.
path
.
isabs
(
newpath
):
try
:
try
:
oldpaths
=
os
.
environ
[
var
]
.
split
(
os
.
pathsep
)
oldpaths
=
os
.
environ
[
var
]
.
split
(
os
.
pathsep
)
if
newpath
not
in
oldpaths
:
if
str
(
newpath
)
not
in
oldpaths
:
newpaths
=
os
.
pathsep
.
join
([
newpath
,
*
oldpaths
])
newpaths
=
os
.
pathsep
.
join
([
str
(
newpath
)
,
*
oldpaths
])
os
.
environ
[
var
]
=
newpaths
os
.
environ
[
var
]
=
newpaths
except
Exception
:
except
Exception
:
pass
pass
...
@@ -210,7 +201,7 @@ def output_subprocess_Popen(command, **params):
...
@@ -210,7 +201,7 @@ def output_subprocess_Popen(command, **params):
return
(
*
out
,
p
.
returncode
)
return
(
*
out
,
p
.
returncode
)
def
hash_from_code
(
msg
)
:
def
hash_from_code
(
msg
:
str
|
bytes
)
->
str
:
"""Return the SHA256 hash of a string or bytes."""
"""Return the SHA256 hash of a string or bytes."""
# hashlib.sha256() requires an object that supports buffer interface,
# hashlib.sha256() requires an object that supports buffer interface,
# but Python 3 (unicode) strings don't.
# but Python 3 (unicode) strings don't.
...
@@ -221,27 +212,7 @@ def hash_from_code(msg):
...
@@ -221,27 +212,7 @@ def hash_from_code(msg):
return
"m"
+
hashlib
.
sha256
(
msg
)
.
hexdigest
()
return
"m"
+
hashlib
.
sha256
(
msg
)
.
hexdigest
()
def
memoize
(
f
):
def
uniq
(
seq
:
Sequence
)
->
list
:
"""
Cache the return value for each tuple of arguments (which must be hashable).
"""
cache
=
{}
def
rval
(
*
args
,
**
kwargs
):
kwtup
=
tuple
(
kwargs
.
items
())
key
=
(
args
,
kwtup
)
if
key
not
in
cache
:
val
=
f
(
*
args
,
**
kwargs
)
cache
[
key
]
=
val
else
:
val
=
cache
[
key
]
return
val
return
rval
def
uniq
(
seq
):
"""
"""
Do not use set, this must always return the same value at the same index.
Do not use set, this must always return the same value at the same index.
If we just exchange other values, but keep the same pattern of duplication,
If we just exchange other values, but keep the same pattern of duplication,
...
@@ -253,11 +224,12 @@ def uniq(seq):
...
@@ -253,11 +224,12 @@ def uniq(seq):
return
[
x
for
i
,
x
in
enumerate
(
seq
)
if
seq
.
index
(
x
)
==
i
]
return
[
x
for
i
,
x
in
enumerate
(
seq
)
if
seq
.
index
(
x
)
==
i
]
def
difference
(
seq1
,
seq2
):
def
difference
(
seq1
:
Iterable
,
seq2
:
Iterable
):
r"""
r"""
Returns all elements in seq1 which are not in seq2: i.e ``seq1\seq2``.
Returns all elements in seq1 which are not in seq2: i.e ``seq1\seq2``.
"""
"""
seq2
=
list
(
seq2
)
try
:
try
:
# try to use O(const * len(seq1)) algo
# try to use O(const * len(seq1)) algo
if
len
(
seq2
)
<
4
:
# I'm guessing this threshold -JB
if
len
(
seq2
)
<
4
:
# I'm guessing this threshold -JB
...
@@ -285,7 +257,7 @@ def from_return_values(values):
...
@@ -285,7 +257,7 @@ def from_return_values(values):
return
[
values
]
return
[
values
]
def
flatten
(
a
):
def
flatten
(
a
)
->
list
:
"""
"""
Recursively flatten tuple, list and set in a list.
Recursively flatten tuple, list and set in a list.
...
...
tests/compile/function/test_types.py
浏览文件 @
6a295b9f
...
@@ -31,7 +31,6 @@ from pytensor.tensor.type import (
...
@@ -31,7 +31,6 @@ from pytensor.tensor.type import (
scalars
,
scalars
,
vector
,
vector
,
)
)
from
pytensor.utils
import
exc_message
def
PatternOptimizer
(
p1
,
p2
,
ign
=
True
):
def
PatternOptimizer
(
p1
,
p2
,
ign
=
True
):
...
@@ -1182,6 +1181,17 @@ class TestPicklefunction:
...
@@ -1182,6 +1181,17 @@ class TestPicklefunction:
def
pers_load
(
id
):
def
pers_load
(
id
):
return
saves
[
id
]
return
saves
[
id
]
def
exc_message
(
e
):
"""
In Python 3, when an exception is reraised it saves the original
exception in its args, therefore in order to find the actual
message, we need to unpack arguments recursively.
"""
msg
=
e
.
args
[
0
]
if
isinstance
(
msg
,
Exception
):
return
exc_message
(
msg
)
return
msg
b
=
np
.
random
.
random
((
5
,
4
))
b
=
np
.
random
.
random
((
5
,
4
))
x
=
matrix
()
x
=
matrix
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论