Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
10105bea
提交
10105bea
authored
5月 07, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
5月 09, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Don't specify zip strict kwarg in hot loops
It seems to add a non-trivial 100ns
上级
5335a680
隐藏空白字符变更
内嵌
并排
正在显示
20 个修改的文件
包含
60 行增加
和
72 行删除
+60
-72
pyproject.toml
pyproject.toml
+1
-1
builders.py
pytensor/compile/builders.py
+2
-3
types.py
pytensor/compile/function/types.py
+4
-2
ifelse.py
pytensor/ifelse.py
+4
-4
basic.py
pytensor/link/basic.py
+6
-6
basic.py
pytensor/link/c/basic.py
+6
-10
basic.py
pytensor/link/numba/dispatch/basic.py
+2
-2
cython_support.py
pytensor/link/numba/dispatch/cython_support.py
+1
-4
extra_ops.py
pytensor/link/numba/dispatch/extra_ops.py
+1
-1
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+1
-1
subtensor.py
pytensor/link/numba/dispatch/subtensor.py
+5
-5
utils.py
pytensor/link/utils.py
+2
-2
basic.py
pytensor/scalar/basic.py
+2
-2
loop.py
pytensor/scalar/loop.py
+2
-2
basic.py
pytensor/tensor/basic.py
+2
-2
elemwise.py
pytensor/tensor/elemwise.py
+5
-5
basic.py
pytensor/tensor/random/basic.py
+2
-2
utils.py
pytensor/tensor/random/utils.py
+6
-7
shape.py
pytensor/tensor/shape.py
+2
-4
type.py
pytensor/tensor/type.py
+4
-7
没有找到文件。
pyproject.toml
浏览文件 @
10105bea
...
...
@@ -130,7 +130,7 @@ exclude = ["doc/", "pytensor/_version.py"]
docstring-code-format
=
true
[tool.ruff.lint]
select
=
[
"
B905"
,
"
C"
,
"E"
,
"F"
,
"I"
,
"UP"
,
"W"
,
"RUF"
,
"PERF"
,
"PTH"
,
"ISC"
,
"T20"
,
"NPY201"
]
select
=
[
"C"
,
"E"
,
"F"
,
"I"
,
"UP"
,
"W"
,
"RUF"
,
"PERF"
,
"PTH"
,
"ISC"
,
"T20"
,
"NPY201"
]
ignore
=
[
"C408"
,
"C901"
,
"E501"
,
"E741"
,
"RUF012"
,
"PERF203"
,
"ISC001"
]
unfixable
=
[
# zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
...
...
pytensor/compile/builders.py
浏览文件 @
10105bea
...
...
@@ -873,7 +873,6 @@ class OpFromGraph(Op, HasInnerGraph):
def
perform
(
self
,
node
,
inputs
,
outputs
):
variables
=
self
.
fn
(
*
inputs
)
assert
len
(
variables
)
==
len
(
outputs
)
# strict=False because asserted above
for
output
,
variable
in
zip
(
outputs
,
variables
,
strict
=
False
):
# zip strict not specified because we are in a hot loop
for
output
,
variable
in
zip
(
outputs
,
variables
):
output
[
0
]
=
variable
pytensor/compile/function/types.py
浏览文件 @
10105bea
...
...
@@ -924,7 +924,8 @@ class Function:
# Reinitialize each container's 'provided' counter
if
trust_input
:
for
arg_container
,
arg
in
zip
(
input_storage
,
args
,
strict
=
False
):
# zip strict not specified because we are in a hot loop
for
arg_container
,
arg
in
zip
(
input_storage
,
args
):
arg_container
.
storage
[
0
]
=
arg
else
:
for
arg_container
in
input_storage
:
...
...
@@ -934,7 +935,8 @@ class Function:
raise
TypeError
(
"Too many parameter passed to pytensor function"
)
# Set positional arguments
for
arg_container
,
arg
in
zip
(
input_storage
,
args
,
strict
=
False
):
# zip strict not specified because we are in a hot loop
for
arg_container
,
arg
in
zip
(
input_storage
,
args
):
# See discussion about None as input
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
if
arg
is
None
:
...
...
pytensor/ifelse.py
浏览文件 @
10105bea
...
...
@@ -305,8 +305,8 @@ class IfElse(_NoPythonOp):
if
len
(
ls
)
>
0
:
return
ls
else
:
#
strict=False
because we are in a hot loop
for
out
,
t
in
zip
(
outputs
,
input_true_branch
,
strict
=
False
):
#
zip strict not specified
because we are in a hot loop
for
out
,
t
in
zip
(
outputs
,
input_true_branch
):
compute_map
[
out
][
0
]
=
1
val
=
storage_map
[
t
][
0
]
if
self
.
as_view
:
...
...
@@ -326,8 +326,8 @@ class IfElse(_NoPythonOp):
if
len
(
ls
)
>
0
:
return
ls
else
:
#
strict=False
because we are in a hot loop
for
out
,
f
in
zip
(
outputs
,
inputs_false_branch
,
strict
=
False
):
#
zip strict not specified
because we are in a hot loop
for
out
,
f
in
zip
(
outputs
,
inputs_false_branch
):
compute_map
[
out
][
0
]
=
1
# can't view both outputs unless destroyhandler
# improves
...
...
pytensor/link/basic.py
浏览文件 @
10105bea
...
...
@@ -539,14 +539,14 @@ class WrapLinker(Linker):
def
f
():
for
inputs
in
input_lists
[
1
:]:
#
strict=False
because we are in a hot loop
for
input1
,
input2
in
zip
(
inputs0
,
inputs
,
strict
=
False
):
#
zip strict not specified
because we are in a hot loop
for
input1
,
input2
in
zip
(
inputs0
,
inputs
):
input2
.
storage
[
0
]
=
copy
(
input1
.
storage
[
0
])
for
x
in
to_reset
:
x
[
0
]
=
None
pre
(
self
,
[
input
.
data
for
input
in
input_lists
[
0
]],
order
,
thunk_groups
)
#
strict=False
because we are in a hot loop
for
i
,
(
thunks
,
node
)
in
enumerate
(
zip
(
thunk_groups
,
order
,
strict
=
False
)):
#
zip strict not specified
because we are in a hot loop
for
i
,
(
thunks
,
node
)
in
enumerate
(
zip
(
thunk_groups
,
order
)):
try
:
wrapper
(
self
.
fgraph
,
i
,
node
,
*
thunks
)
except
Exception
:
...
...
@@ -668,8 +668,8 @@ class JITLinker(PerformLinker):
# since the error may come from any of them?
raise_with_op
(
self
.
fgraph
,
output_nodes
[
0
],
thunk
)
#
strict=False
because we are in a hot loop
for
o_storage
,
o_val
in
zip
(
thunk_outputs
,
outputs
,
strict
=
False
):
#
zip strict not specified
because we are in a hot loop
for
o_storage
,
o_val
in
zip
(
thunk_outputs
,
outputs
):
o_storage
[
0
]
=
o_val
thunk
.
inputs
=
thunk_inputs
...
...
pytensor/link/c/basic.py
浏览文件 @
10105bea
...
...
@@ -1988,27 +1988,23 @@ class DualLinker(Linker):
)
def
f
():
#
strict=False
because we are in a hot loop
for
input1
,
input2
in
zip
(
i1
,
i2
,
strict
=
False
):
#
zip strict not specified
because we are in a hot loop
for
input1
,
input2
in
zip
(
i1
,
i2
):
# Set the inputs to be the same in both branches.
# The copy is necessary in order for inplace ops not to
# interfere.
input2
.
storage
[
0
]
=
copy
(
input1
.
storage
[
0
])
for
thunk1
,
thunk2
,
node1
,
node2
in
zip
(
thunks1
,
thunks2
,
order1
,
order2
,
strict
=
False
):
for
output
,
storage
in
zip
(
node1
.
outputs
,
thunk1
.
outputs
,
strict
=
False
):
for
thunk1
,
thunk2
,
node1
,
node2
in
zip
(
thunks1
,
thunks2
,
order1
,
order2
):
for
output
,
storage
in
zip
(
node1
.
outputs
,
thunk1
.
outputs
):
if
output
in
no_recycling
:
storage
[
0
]
=
None
for
output
,
storage
in
zip
(
node2
.
outputs
,
thunk2
.
outputs
,
strict
=
False
):
for
output
,
storage
in
zip
(
node2
.
outputs
,
thunk2
.
outputs
):
if
output
in
no_recycling
:
storage
[
0
]
=
None
try
:
thunk1
()
thunk2
()
for
output1
,
output2
in
zip
(
thunk1
.
outputs
,
thunk2
.
outputs
,
strict
=
False
):
for
output1
,
output2
in
zip
(
thunk1
.
outputs
,
thunk2
.
outputs
):
self
.
checker
(
output1
,
output2
)
except
Exception
:
raise_with_op
(
fgraph
,
node1
)
...
...
pytensor/link/numba/dispatch/basic.py
浏览文件 @
10105bea
...
...
@@ -312,10 +312,10 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
else
:
def
py_perform_return
(
inputs
):
#
strict=False
because we are in a hot loop
#
zip strict not specified
because we are in a hot loop
return
tuple
(
out_type
.
filter
(
out
[
0
])
for
out_type
,
out
in
zip
(
output_types
,
py_perform
(
inputs
)
,
strict
=
False
)
for
out_type
,
out
in
zip
(
output_types
,
py_perform
(
inputs
))
)
@numba_njit
...
...
pytensor/link/numba/dispatch/cython_support.py
浏览文件 @
10105bea
...
...
@@ -166,10 +166,7 @@ class _CythonWrapper(numba.types.WrapperAddressProtocol):
def
__call__
(
self
,
*
args
,
**
kwargs
):
# no strict argument because of the JIT
# TODO: check
args
=
[
dtype
(
arg
)
for
arg
,
dtype
in
zip
(
args
,
self
.
_signature
.
arg_dtypes
)
# noqa: B905
]
args
=
[
dtype
(
arg
)
for
arg
,
dtype
in
zip
(
args
,
self
.
_signature
.
arg_dtypes
)]
if
self
.
has_pyx_skip_dispatch
():
output
=
self
.
_pyfunc
(
*
args
[:
-
1
],
**
kwargs
)
else
:
...
...
pytensor/link/numba/dispatch/extra_ops.py
浏览文件 @
10105bea
...
...
@@ -186,7 +186,7 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
new_arr
=
arr
.
T
.
astype
(
np
.
float64
)
.
copy
()
for
i
,
b
in
enumerate
(
new_arr
):
# no strict argument to this zip because numba doesn't support it
for
j
,
(
d
,
v
)
in
enumerate
(
zip
(
shape
,
b
)):
# noqa: B905
for
j
,
(
d
,
v
)
in
enumerate
(
zip
(
shape
,
b
)):
if
v
<
0
or
v
>=
d
:
mode_fn
(
new_arr
,
i
,
j
,
v
,
d
)
...
...
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
10105bea
...
...
@@ -183,7 +183,7 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):
r
,
c
=
0
,
0
# no strict argument because it is incompatible with numba
for
arr
,
shape
in
zip
(
arrs
,
shapes
):
# noqa: B905
for
arr
,
shape
in
zip
(
arrs
,
shapes
):
rr
,
cc
=
shape
out
[
r
:
r
+
rr
,
c
:
c
+
cc
]
=
arr
r
+=
rr
...
...
pytensor/link/numba/dispatch/subtensor.py
浏览文件 @
10105bea
...
...
@@ -219,7 +219,7 @@ def numba_funcify_multiple_integer_vector_indexing(
shape_aft
=
x_shape
[
after_last_axis
:]
out_shape
=
(
*
shape_bef
,
*
idx_shape
,
*
shape_aft
)
out_buffer
=
np
.
empty
(
out_shape
,
dtype
=
x
.
dtype
)
for
i
,
scalar_idxs
in
enumerate
(
zip
(
*
vec_idxs
)):
# noqa: B905
for
i
,
scalar_idxs
in
enumerate
(
zip
(
*
vec_idxs
)):
out_buffer
[(
*
none_slices
,
i
)]
=
x
[(
*
none_slices
,
*
scalar_idxs
)]
return
out_buffer
...
...
@@ -253,7 +253,7 @@ def numba_funcify_multiple_integer_vector_indexing(
y
=
np
.
broadcast_to
(
y
,
x_shape
[:
first_axis
]
+
x_shape
[
last_axis
:])
for
outer
in
np
.
ndindex
(
x_shape
[:
first_axis
]):
for
i
,
scalar_idxs
in
enumerate
(
zip
(
*
vec_idxs
)):
# noqa: B905
for
i
,
scalar_idxs
in
enumerate
(
zip
(
*
vec_idxs
)):
out
[(
*
outer
,
*
scalar_idxs
)]
=
y
[(
*
outer
,
i
)]
return
out
...
...
@@ -275,7 +275,7 @@ def numba_funcify_multiple_integer_vector_indexing(
y
=
np
.
broadcast_to
(
y
,
x_shape
[:
first_axis
]
+
x_shape
[
last_axis
:])
for
outer
in
np
.
ndindex
(
x_shape
[:
first_axis
]):
for
i
,
scalar_idxs
in
enumerate
(
zip
(
*
vec_idxs
)):
# noqa: B905
for
i
,
scalar_idxs
in
enumerate
(
zip
(
*
vec_idxs
)):
out
[(
*
outer
,
*
scalar_idxs
)]
+=
y
[(
*
outer
,
i
)]
return
out
...
...
@@ -314,7 +314,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
if
not
len
(
idxs
)
==
len
(
vals
):
raise
ValueError
(
"The number of indices and values must match."
)
# no strict argument because incompatible with numba
for
idx
,
val
in
zip
(
idxs
,
vals
):
# noqa: B905
for
idx
,
val
in
zip
(
idxs
,
vals
):
x
[
idx
]
=
val
return
x
else
:
...
...
@@ -342,7 +342,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
raise
ValueError
(
"The number of indices and values must match."
)
# no strict argument because unsupported by numba
# TODO: this doesn't come up in tests
for
idx
,
val
in
zip
(
idxs
,
vals
):
# noqa: B905
for
idx
,
val
in
zip
(
idxs
,
vals
):
x
[
idx
]
+=
val
return
x
...
...
pytensor/link/utils.py
浏览文件 @
10105bea
...
...
@@ -207,8 +207,8 @@ def streamline(
for
x
in
no_recycling
:
x
[
0
]
=
None
try
:
#
strict=False
because we are in a hot loop
for
thunk
,
node
in
zip
(
thunks
,
order
,
strict
=
False
):
#
zip strict not specified
because we are in a hot loop
for
thunk
,
node
in
zip
(
thunks
,
order
):
thunk
()
except
Exception
:
raise_with_op
(
fgraph
,
node
,
thunk
)
...
...
pytensor/scalar/basic.py
浏览文件 @
10105bea
...
...
@@ -4416,8 +4416,8 @@ class Composite(ScalarInnerGraphOp):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
outputs
=
self
.
py_perform_fn
(
*
inputs
)
#
strict=False
because we are in a hot loop
for
storage
,
out_val
in
zip
(
output_storage
,
outputs
,
strict
=
False
):
#
zip strict not specified
because we are in a hot loop
for
storage
,
out_val
in
zip
(
output_storage
,
outputs
):
storage
[
0
]
=
out_val
def
grad
(
self
,
inputs
,
output_grads
):
...
...
pytensor/scalar/loop.py
浏览文件 @
10105bea
...
...
@@ -196,8 +196,8 @@ class ScalarLoop(ScalarInnerGraphOp):
for
i
in
range
(
n_steps
):
carry
=
inner_fn
(
*
carry
,
*
constant
)
#
strict=False
because we are in a hot loop
for
storage
,
out_val
in
zip
(
output_storage
,
carry
,
strict
=
False
):
#
zip strict not specified
because we are in a hot loop
for
storage
,
out_val
in
zip
(
output_storage
,
carry
):
storage
[
0
]
=
out_val
@property
...
...
pytensor/tensor/basic.py
浏览文件 @
10105bea
...
...
@@ -3589,8 +3589,8 @@ class PermuteRowElements(Op):
# Make sure the output is big enough
out_s
=
[]
#
strict=False
because we are in a hot loop
for
xdim
,
ydim
in
zip
(
x_s
,
y_s
,
strict
=
False
):
#
zip strict not specified
because we are in a hot loop
for
xdim
,
ydim
in
zip
(
x_s
,
y_s
):
if
xdim
==
ydim
:
outdim
=
xdim
elif
xdim
==
1
:
...
...
pytensor/tensor/elemwise.py
浏览文件 @
10105bea
...
...
@@ -712,9 +712,9 @@ class Elemwise(OpenMPOp):
if
nout
==
1
:
variables
=
[
variables
]
#
strict=False
because we are in a hot loop
#
zip strict not specified
because we are in a hot loop
for
i
,
(
variable
,
storage
,
nout
)
in
enumerate
(
zip
(
variables
,
output_storage
,
node
.
outputs
,
strict
=
False
)
zip
(
variables
,
output_storage
,
node
.
outputs
)
):
storage
[
0
]
=
variable
=
np
.
asarray
(
variable
,
dtype
=
nout
.
dtype
)
...
...
@@ -729,11 +729,11 @@ class Elemwise(OpenMPOp):
@staticmethod
def
_check_runtime_broadcast
(
node
,
inputs
):
#
strict=False
because we are in a hot loop
#
zip strict not specified
because we are in a hot loop
for
dims_and_bcast
in
zip
(
*
[
zip
(
input
.
shape
,
sinput
.
type
.
broadcastable
,
strict
=
False
)
for
input
,
sinput
in
zip
(
inputs
,
node
.
inputs
,
strict
=
False
)
zip
(
input
.
shape
,
sinput
.
type
.
broadcastable
)
for
input
,
sinput
in
zip
(
inputs
,
node
.
inputs
)
],
strict
=
False
,
):
...
...
pytensor/tensor/random/basic.py
浏览文件 @
10105bea
...
...
@@ -1865,8 +1865,8 @@ class CategoricalRV(RandomVariable):
# to `p.shape[:-1]` in the call to `vsearchsorted` below.
if
len
(
size
)
<
(
p
.
ndim
-
1
):
raise
ValueError
(
"`size` is incompatible with the shape of `p`"
)
#
strict=False
because we are in a hot loop
for
s
,
ps
in
zip
(
reversed
(
size
),
reversed
(
p
.
shape
[:
-
1
])
,
strict
=
False
):
#
zip strict not specified
because we are in a hot loop
for
s
,
ps
in
zip
(
reversed
(
size
),
reversed
(
p
.
shape
[:
-
1
])):
if
s
==
1
and
ps
!=
1
:
raise
ValueError
(
"`size` is incompatible with the shape of `p`"
)
...
...
pytensor/tensor/random/utils.py
浏览文件 @
10105bea
...
...
@@ -44,8 +44,8 @@ def params_broadcast_shapes(
max_fn
=
maximum
if
use_pytensor
else
max
rev_extra_dims
:
list
[
int
]
=
[]
#
strict=False
because we are in a hot loop
for
ndim_param
,
param_shape
in
zip
(
ndims_params
,
param_shapes
,
strict
=
False
):
#
zip strict not specified
because we are in a hot loop
for
ndim_param
,
param_shape
in
zip
(
ndims_params
,
param_shapes
):
# We need this in order to use `len`
param_shape
=
tuple
(
param_shape
)
extras
=
tuple
(
param_shape
[:
(
len
(
param_shape
)
-
ndim_param
)])
...
...
@@ -64,12 +64,12 @@ def params_broadcast_shapes(
extra_dims
=
tuple
(
reversed
(
rev_extra_dims
))
#
strict=False
because we are in a hot loop
#
zip strict not specified
because we are in a hot loop
bcast_shapes
=
[
(
extra_dims
+
tuple
(
param_shape
)[
-
ndim_param
:])
if
ndim_param
>
0
else
extra_dims
for
ndim_param
,
param_shape
in
zip
(
ndims_params
,
param_shapes
,
strict
=
False
)
for
ndim_param
,
param_shape
in
zip
(
ndims_params
,
param_shapes
)
]
return
bcast_shapes
...
...
@@ -127,10 +127,9 @@ def broadcast_params(
)
broadcast_to_fn
=
broadcast_to
if
use_pytensor
else
np
.
broadcast_to
#
strict=False
because we are in a hot loop
#
zip strict not specified
because we are in a hot loop
bcast_params
=
[
broadcast_to_fn
(
param
,
shape
)
for
shape
,
param
in
zip
(
shapes
,
params
,
strict
=
False
)
broadcast_to_fn
(
param
,
shape
)
for
shape
,
param
in
zip
(
shapes
,
params
)
]
return
bcast_params
...
...
pytensor/tensor/shape.py
浏览文件 @
10105bea
...
...
@@ -447,10 +447,8 @@ class SpecifyShape(COp):
raise
AssertionError
(
f
"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
)
# strict=False because we are in a hot loop
if
not
all
(
xs
==
s
for
xs
,
s
in
zip
(
x
.
shape
,
shape
,
strict
=
False
)
if
s
is
not
None
):
# zip strict not specified because we are in a hot loop
if
not
all
(
xs
==
s
for
xs
,
s
in
zip
(
x
.
shape
,
shape
)
if
s
is
not
None
):
raise
AssertionError
(
f
"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
)
...
...
pytensor/tensor/type.py
浏览文件 @
10105bea
...
...
@@ -261,10 +261,10 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
" PyTensor C code does not support that."
,
)
#
strict=False
because we are in a hot loop
#
zip strict not specified
because we are in a hot loop
if
not
all
(
ds
==
ts
if
ts
is
not
None
else
True
for
ds
,
ts
in
zip
(
data
.
shape
,
self
.
shape
,
strict
=
False
)
for
ds
,
ts
in
zip
(
data
.
shape
,
self
.
shape
)
):
raise
TypeError
(
f
"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
...
...
@@ -333,17 +333,14 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
return
False
def
is_super
(
self
,
otype
):
#
strict=False
because we are in a hot loop
#
zip strict not specified
because we are in a hot loop
if
(
isinstance
(
otype
,
type
(
self
))
and
otype
.
dtype
==
self
.
dtype
and
otype
.
ndim
==
self
.
ndim
# `otype` is allowed to be as or more shape-specific than `self`,
# but not less
and
all
(
sb
==
ob
or
sb
is
None
for
sb
,
ob
in
zip
(
self
.
shape
,
otype
.
shape
,
strict
=
False
)
)
and
all
(
sb
==
ob
or
sb
is
None
for
sb
,
ob
in
zip
(
self
.
shape
,
otype
.
shape
))
):
return
True
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论