Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
3fcf6369
提交
3fcf6369
authored
2月 07, 2023
作者:
Virgile Andreani
提交者:
Ricardo Vieira
2月 08, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add pyupgrade to pre-commit and apply it
上级
f254492b
显示空白字符变更
内嵌
并排
正在显示
15 个修改的文件
包含
58 行增加
和
46 行删除
+58
-46
.pre-commit-config.yaml
.pre-commit-config.yaml
+5
-0
mode.py
pytensor/compile/mode.py
+1
-3
gradient.py
pytensor/gradient.py
+1
-1
fg.py
pytensor/graph/fg.py
+1
-2
op.py
pytensor/graph/op.py
+1
-2
basic.py
pytensor/graph/rewriting/basic.py
+2
-4
cmodule.py
pytensor/link/c/cmodule.py
+11
-2
elemwise_codegen.py
pytensor/link/numba/dispatch/elemwise_codegen.py
+17
-19
printing.py
pytensor/printing.py
+12
-2
type.py
pytensor/sparse/type.py
+1
-2
basic.py
pytensor/tensor/random/rewriting/basic.py
+1
-1
utils.py
pytensor/tensor/random/utils.py
+1
-2
slinalg.py
pytensor/tensor/slinalg.py
+1
-2
type.py
pytensor/tensor/type.py
+1
-2
test_scan.py
tests/link/numba/test_scan.py
+2
-2
没有找到文件。
.pre-commit-config.yaml
浏览文件 @
3fcf6369
...
@@ -19,6 +19,11 @@ repos:
...
@@ -19,6 +19,11 @@ repos:
pytensor/tensor/var\.py|
pytensor/tensor/var\.py|
)$
)$
-
id
:
check-merge-conflict
-
id
:
check-merge-conflict
-
repo
:
https://github.com/asottile/pyupgrade
rev
:
v3.3.1
hooks
:
-
id
:
pyupgrade
args
:
[
--py38-plus
]
-
repo
:
https://github.com/psf/black
-
repo
:
https://github.com/psf/black
rev
:
22.10.0
rev
:
22.10.0
hooks
:
hooks
:
...
...
pytensor/compile/mode.py
浏览文件 @
3fcf6369
...
@@ -5,9 +5,7 @@ WRITEME
...
@@ -5,9 +5,7 @@ WRITEME
import
logging
import
logging
import
warnings
import
warnings
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Literal
,
Optional
,
Tuple
,
Union
from
typing_extensions
import
Literal
from
pytensor.compile.function.types
import
Supervisor
from
pytensor.compile.function.types
import
Supervisor
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
...
...
pytensor/gradient.py
浏览文件 @
3fcf6369
...
@@ -8,6 +8,7 @@ from typing import (
...
@@ -8,6 +8,7 @@ from typing import (
Callable
,
Callable
,
Dict
,
Dict
,
List
,
List
,
Literal
,
Mapping
,
Mapping
,
MutableSequence
,
MutableSequence
,
Optional
,
Optional
,
...
@@ -18,7 +19,6 @@ from typing import (
...
@@ -18,7 +19,6 @@ from typing import (
)
)
import
numpy
as
np
import
numpy
as
np
from
typing_extensions
import
Literal
import
pytensor
import
pytensor
from
pytensor.compile.ops
import
ViewOp
from
pytensor.compile.ops
import
ViewOp
...
...
pytensor/graph/fg.py
浏览文件 @
3fcf6369
...
@@ -7,6 +7,7 @@ from typing import (
...
@@ -7,6 +7,7 @@ from typing import (
Dict
,
Dict
,
Iterable
,
Iterable
,
List
,
List
,
Literal
,
Optional
,
Optional
,
Sequence
,
Sequence
,
Set
,
Set
,
...
@@ -15,8 +16,6 @@ from typing import (
...
@@ -15,8 +16,6 @@ from typing import (
cast
,
cast
,
)
)
from
typing_extensions
import
Literal
import
pytensor
import
pytensor
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Apply
,
AtomicVariable
,
Variable
,
applys_between
from
pytensor.graph.basic
import
Apply
,
AtomicVariable
,
Variable
,
applys_between
...
...
pytensor/graph/op.py
浏览文件 @
3fcf6369
...
@@ -9,6 +9,7 @@ from typing import (
...
@@ -9,6 +9,7 @@ from typing import (
Dict
,
Dict
,
List
,
List
,
Optional
,
Optional
,
Protocol
,
Sequence
,
Sequence
,
Tuple
,
Tuple
,
TypeVar
,
TypeVar
,
...
@@ -16,8 +17,6 @@ from typing import (
...
@@ -16,8 +17,6 @@ from typing import (
cast
,
cast
,
)
)
from
typing_extensions
import
Protocol
import
pytensor
import
pytensor
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Apply
,
NoParams
,
Variable
from
pytensor.graph.basic
import
Apply
,
NoParams
,
Variable
...
...
pytensor/graph/rewriting/basic.py
浏览文件 @
3fcf6369
...
@@ -15,9 +15,7 @@ from functools import _compose_mro, partial, reduce # type: ignore
...
@@ -15,9 +15,7 @@ from functools import _compose_mro, partial, reduce # type: ignore
from
itertools
import
chain
from
itertools
import
chain
from
typing
import
TYPE_CHECKING
,
Callable
,
Dict
from
typing
import
TYPE_CHECKING
,
Callable
,
Dict
from
typing
import
Iterable
as
IterableType
from
typing
import
Iterable
as
IterableType
from
typing
import
List
,
Optional
,
Sequence
,
Tuple
,
Union
,
cast
from
typing
import
List
,
Literal
,
Optional
,
Sequence
,
Tuple
,
Union
,
cast
from
typing_extensions
import
Literal
import
pytensor
import
pytensor
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
...
@@ -1185,7 +1183,7 @@ class OpToRewriterTracker:
...
@@ -1185,7 +1183,7 @@ class OpToRewriterTracker:
matches
.
extend
(
match
)
matches
.
extend
(
match
)
return
matches
return
matches
@functools.lru_cache
()
@functools.lru_cache
def
get_trackers
(
self
,
op
:
Op
)
->
List
[
NodeRewriter
]:
def
get_trackers
(
self
,
op
:
Op
)
->
List
[
NodeRewriter
]:
"""Get all the rewrites applicable to `op`."""
"""Get all the rewrites applicable to `op`."""
return
(
return
(
...
...
pytensor/link/c/cmodule.py
浏览文件 @
3fcf6369
...
@@ -19,7 +19,17 @@ import textwrap
...
@@ -19,7 +19,17 @@ import textwrap
import
time
import
time
import
warnings
import
warnings
from
io
import
BytesIO
,
StringIO
from
io
import
BytesIO
,
StringIO
from
typing
import
TYPE_CHECKING
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
cast
from
typing
import
(
TYPE_CHECKING
,
Callable
,
Dict
,
List
,
Optional
,
Protocol
,
Set
,
Tuple
,
cast
,
)
import
numpy
as
np
import
numpy
as
np
from
setuptools._distutils.sysconfig
import
(
from
setuptools._distutils.sysconfig
import
(
...
@@ -28,7 +38,6 @@ from setuptools._distutils.sysconfig import (
...
@@ -28,7 +38,6 @@ from setuptools._distutils.sysconfig import (
get_python_inc
,
get_python_inc
,
get_python_lib
,
get_python_lib
,
)
)
from
typing_extensions
import
Protocol
# we will abuse the lockfile mechanism when reading and writing the registry
# we will abuse the lockfile mechanism when reading and writing the registry
from
pytensor.compile.compilelock
import
lock_ctx
from
pytensor.compile.compilelock
import
lock_ctx
...
...
pytensor/link/numba/dispatch/elemwise_codegen.py
浏览文件 @
3fcf6369
from
__future__
import
annotations
from
__future__
import
annotations
from
typing
import
Any
,
List
,
Optional
,
Tuple
from
typing
import
Any
import
numba
import
numba
import
numpy
as
np
import
numpy
as
np
...
@@ -14,8 +14,8 @@ from numba.np import arrayobj
...
@@ -14,8 +14,8 @@ from numba.np import arrayobj
def
compute_itershape
(
def
compute_itershape
(
ctx
:
BaseContext
,
ctx
:
BaseContext
,
builder
:
ir
.
IRBuilder
,
builder
:
ir
.
IRBuilder
,
in_shapes
:
T
uple
[
ir
.
Instruction
,
...
],
in_shapes
:
t
uple
[
ir
.
Instruction
,
...
],
broadcast_pattern
:
Tuple
[
T
uple
[
bool
,
...
],
...
],
broadcast_pattern
:
tuple
[
t
uple
[
bool
,
...
],
...
],
):
):
one
=
ir
.
IntType
(
64
)(
1
)
one
=
ir
.
IntType
(
64
)(
1
)
ndim
=
len
(
in_shapes
[
0
])
ndim
=
len
(
in_shapes
[
0
])
...
@@ -63,12 +63,12 @@ def compute_itershape(
...
@@ -63,12 +63,12 @@ def compute_itershape(
def
make_outputs
(
def
make_outputs
(
ctx
:
numba
.
core
.
base
.
BaseContext
,
ctx
:
numba
.
core
.
base
.
BaseContext
,
builder
:
ir
.
IRBuilder
,
builder
:
ir
.
IRBuilder
,
iter_shape
:
T
uple
[
ir
.
Instruction
,
...
],
iter_shape
:
t
uple
[
ir
.
Instruction
,
...
],
out_bc
:
Tuple
[
T
uple
[
bool
,
...
],
...
],
out_bc
:
tuple
[
t
uple
[
bool
,
...
],
...
],
dtypes
:
T
uple
[
Any
,
...
],
dtypes
:
t
uple
[
Any
,
...
],
inplace
:
Tuple
[
T
uple
[
int
,
int
],
...
],
inplace
:
tuple
[
t
uple
[
int
,
int
],
...
],
inputs
:
T
uple
[
Any
,
...
],
inputs
:
t
uple
[
Any
,
...
],
input_types
:
T
uple
[
Any
,
...
],
input_types
:
t
uple
[
Any
,
...
],
):
):
arrays
=
[]
arrays
=
[]
ar_types
:
list
[
types
.
Array
]
=
[]
ar_types
:
list
[
types
.
Array
]
=
[]
...
@@ -106,13 +106,13 @@ def make_loop_call(
...
@@ -106,13 +106,13 @@ def make_loop_call(
builder
:
ir
.
IRBuilder
,
builder
:
ir
.
IRBuilder
,
scalar_func
:
Any
,
scalar_func
:
Any
,
scalar_signature
:
types
.
FunctionType
,
scalar_signature
:
types
.
FunctionType
,
iter_shape
:
T
uple
[
ir
.
Instruction
,
...
],
iter_shape
:
t
uple
[
ir
.
Instruction
,
...
],
inputs
:
T
uple
[
ir
.
Instruction
,
...
],
inputs
:
t
uple
[
ir
.
Instruction
,
...
],
outputs
:
T
uple
[
ir
.
Instruction
,
...
],
outputs
:
t
uple
[
ir
.
Instruction
,
...
],
input_bc
:
Tuple
[
T
uple
[
bool
,
...
],
...
],
input_bc
:
tuple
[
t
uple
[
bool
,
...
],
...
],
output_bc
:
Tuple
[
T
uple
[
bool
,
...
],
...
],
output_bc
:
tuple
[
t
uple
[
bool
,
...
],
...
],
input_types
:
T
uple
[
Any
,
...
],
input_types
:
t
uple
[
Any
,
...
],
output_types
:
T
uple
[
Any
,
...
],
output_types
:
t
uple
[
Any
,
...
],
):
):
safe
=
(
False
,
False
)
safe
=
(
False
,
False
)
n_outputs
=
len
(
outputs
)
n_outputs
=
len
(
outputs
)
...
@@ -150,9 +150,7 @@ def make_loop_call(
...
@@ -150,9 +150,7 @@ def make_loop_call(
# This part corresponds to opening the loops
# This part corresponds to opening the loops
loop_stack
=
[]
loop_stack
=
[]
loops
=
[]
loops
=
[]
output_accumulator
:
List
[
Tuple
[
Optional
[
Any
],
Optional
[
int
]]]
=
[
output_accumulator
:
list
[
tuple
[
Any
|
None
,
int
|
None
]]
=
[(
None
,
None
)]
*
n_outputs
(
None
,
None
)
]
*
n_outputs
for
dim
,
length
in
enumerate
(
iter_shape
):
for
dim
,
length
in
enumerate
(
iter_shape
):
# Find outputs that only have accumulations left
# Find outputs that only have accumulations left
for
output
in
range
(
n_outputs
):
for
output
in
range
(
n_outputs
):
...
...
pytensor/printing.py
浏览文件 @
3fcf6369
...
@@ -9,10 +9,20 @@ from contextlib import contextmanager
...
@@ -9,10 +9,20 @@ from contextlib import contextmanager
from
copy
import
copy
from
copy
import
copy
from
functools
import
reduce
,
singledispatch
from
functools
import
reduce
,
singledispatch
from
io
import
StringIO
from
io
import
StringIO
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
TextIO
,
Tuple
,
Union
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Literal
,
Optional
,
Sequence
,
TextIO
,
Tuple
,
Union
,
)
import
numpy
as
np
import
numpy
as
np
from
typing_extensions
import
Literal
from
pytensor.compile
import
Function
,
SharedVariable
from
pytensor.compile
import
Function
,
SharedVariable
from
pytensor.compile.io
import
In
,
Out
from
pytensor.compile.io
import
In
,
Out
...
...
pytensor/sparse/type.py
浏览文件 @
3fcf6369
from
typing
import
Iterable
,
Optional
,
Union
from
typing
import
Iterable
,
Literal
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
scipy.sparse
import
scipy.sparse
from
typing_extensions
import
Literal
import
pytensor
import
pytensor
from
pytensor
import
scalar
as
aes
from
pytensor
import
scalar
as
aes
...
...
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
3fcf6369
...
@@ -143,7 +143,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
...
@@ -143,7 +143,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
# Check that Dimshuffle does not affect support dims
# Check that Dimshuffle does not affect support dims
supp_dims
=
set
(
range
(
rv
.
ndim
-
rv_op
.
ndim_supp
,
rv
.
ndim
))
supp_dims
=
set
(
range
(
rv
.
ndim
-
rv_op
.
ndim_supp
,
rv
.
ndim
))
shuffled_dims
=
{
dim
for
i
,
dim
in
enumerate
(
ds_op
.
shuffle
)
if
dim
!=
i
}
shuffled_dims
=
{
dim
for
i
,
dim
in
enumerate
(
ds_op
.
shuffle
)
if
dim
!=
i
}
augmented_dims
=
set
(
d
-
rv_op
.
ndim_supp
for
d
in
ds_op
.
augment
)
augmented_dims
=
{
d
-
rv_op
.
ndim_supp
for
d
in
ds_op
.
augment
}
if
(
shuffled_dims
|
augmented_dims
)
&
supp_dims
:
if
(
shuffled_dims
|
augmented_dims
)
&
supp_dims
:
return
False
return
False
...
...
pytensor/tensor/random/utils.py
浏览文件 @
3fcf6369
...
@@ -2,10 +2,9 @@ from collections.abc import Sequence
...
@@ -2,10 +2,9 @@ from collections.abc import Sequence
from
functools
import
wraps
from
functools
import
wraps
from
itertools
import
zip_longest
from
itertools
import
zip_longest
from
types
import
ModuleType
from
types
import
ModuleType
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Literal
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
from
typing_extensions
import
Literal
from
pytensor.compile.sharedvalue
import
shared
from
pytensor.compile.sharedvalue
import
shared
from
pytensor.graph.basic
import
Constant
,
Variable
from
pytensor.graph.basic
import
Constant
,
Variable
...
...
pytensor/tensor/slinalg.py
浏览文件 @
3fcf6369
import
logging
import
logging
import
warnings
import
warnings
from
typing
import
TYPE_CHECKING
,
Union
from
typing
import
TYPE_CHECKING
,
Literal
,
Union
import
numpy
as
np
import
numpy
as
np
import
scipy.linalg
import
scipy.linalg
from
typing_extensions
import
Literal
import
pytensor
import
pytensor
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
...
...
pytensor/tensor/type.py
浏览文件 @
3fcf6369
import
logging
import
logging
import
warnings
import
warnings
from
typing
import
TYPE_CHECKING
,
Iterable
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Iterable
,
Literal
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
from
typing_extensions
import
Literal
import
pytensor
import
pytensor
from
pytensor
import
scalar
as
aes
from
pytensor
import
scalar
as
aes
...
...
tests/link/numba/test_scan.py
浏览文件 @
3fcf6369
...
@@ -435,9 +435,9 @@ def test_inner_graph_optimized():
...
@@ -435,9 +435,9 @@ def test_inner_graph_optimized():
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise
f
=
function
([
xs
],
seq
,
mode
=
get_mode
(
"NUMBA"
)
.
excluding
(
"scan_pushout"
))
f
=
function
([
xs
],
seq
,
mode
=
get_mode
(
"NUMBA"
)
.
excluding
(
"scan_pushout"
))
(
scan_node
,)
=
[
(
scan_node
,)
=
(
node
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
node
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
]
)
inner_scan_nodes
=
scan_node
.
op
.
fgraph
.
apply_nodes
inner_scan_nodes
=
scan_node
.
op
.
fgraph
.
apply_nodes
assert
len
(
inner_scan_nodes
)
==
1
assert
len
(
inner_scan_nodes
)
==
1
(
inner_scan_node
,)
=
scan_node
.
op
.
fgraph
.
apply_nodes
(
inner_scan_node
,)
=
scan_node
.
op
.
fgraph
.
apply_nodes
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论