Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
0b4d684f
提交
0b4d684f
authored
1月 31, 2026
作者:
ricardoV94
提交者:
Ricardo Vieira
2月 05, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add empty-safe unzip helper
上级
f4196f9c
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
37 行增加
和
26 行删除
+37
-26
elemwise.py
pytensor/tensor/elemwise.py
+3
-3
optimize.py
pytensor/tensor/optimize.py
+2
-1
slinalg.py
pytensor/tensor/slinalg.py
+2
-1
subtensor.py
pytensor/tensor/subtensor.py
+2
-1
utils.py
pytensor/utils.py
+11
-0
shape.py
pytensor/xtensor/shape.py
+2
-1
vectorization.py
pytensor/xtensor/vectorization.py
+15
-19
没有找到文件。
pytensor/tensor/elemwise.py
浏览文件 @
0b4d684f
...
@@ -37,7 +37,7 @@ from pytensor.tensor.utils import (
...
@@ -37,7 +37,7 @@ from pytensor.tensor.utils import (
normalize_reduce_axis
,
normalize_reduce_axis
,
)
)
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.utils
import
uniq
from
pytensor.utils
import
uniq
,
unzip
class
DimShuffle
(
ExternalCOp
):
class
DimShuffle
(
ExternalCOp
):
...
@@ -765,8 +765,8 @@ class Elemwise(OpenMPOp):
...
@@ -765,8 +765,8 @@ class Elemwise(OpenMPOp):
# assert that inames and inputs order stay consistent.
# assert that inames and inputs order stay consistent.
# This is to protect against future changes of uniq.
# This is to protect against future changes of uniq.
assert
len
(
inames
)
==
len
(
inputs
)
assert
len
(
inames
)
==
len
(
inputs
)
ii
,
iii
=
list
(
ii
,
iii
=
unzip
(
zip
(
*
uniq
(
list
(
zip
(
_inames
,
node
.
inputs
,
strict
=
True
))),
strict
=
True
)
uniq
(
list
(
zip
(
_inames
,
node
.
inputs
,
strict
=
True
))),
n
=
2
,
strict
=
True
)
)
assert
all
(
x
==
y
for
x
,
y
in
zip
(
ii
,
inames
,
strict
=
True
))
assert
all
(
x
==
y
for
x
,
y
in
zip
(
ii
,
inames
,
strict
=
True
))
assert
all
(
x
==
y
for
x
,
y
in
zip
(
iii
,
inputs
,
strict
=
True
))
assert
all
(
x
==
y
for
x
,
y
in
zip
(
iii
,
inputs
,
strict
=
True
))
...
...
pytensor/tensor/optimize.py
浏览文件 @
0b4d684f
...
@@ -35,6 +35,7 @@ from pytensor.tensor.math import tensordot
...
@@ -35,6 +35,7 @@ from pytensor.tensor.math import tensordot
from
pytensor.tensor.reshape
import
pack
,
unpack
from
pytensor.tensor.reshape
import
pack
,
unpack
from
pytensor.tensor.slinalg
import
solve
from
pytensor.tensor.slinalg
import
solve
from
pytensor.tensor.variable
import
TensorVariable
,
Variable
from
pytensor.tensor.variable
import
TensorVariable
,
Variable
from
pytensor.utils
import
unzip
# scipy.optimize can be slow to import, and will not be used by most users
# scipy.optimize can be slow to import, and will not be used by most users
...
@@ -297,7 +298,7 @@ class ScipyScalarWrapperOp(ScipyWrapperOp):
...
@@ -297,7 +298,7 @@ class ScipyScalarWrapperOp(ScipyWrapperOp):
# No differentiable arguments, return disconnected gradients
# No differentiable arguments, return disconnected gradients
return
arg_grads
return
arg_grads
outer_args_to_diff
,
df_dthetas
=
zip
(
*
valid_args_and_grads
)
outer_args_to_diff
,
df_dthetas
=
unzip
(
valid_args_and_grads
,
n
=
2
)
replace
=
dict
(
zip
(
fgraph
.
inputs
,
(
x_star
,
*
args
),
strict
=
True
))
replace
=
dict
(
zip
(
fgraph
.
inputs
,
(
x_star
,
*
args
),
strict
=
True
))
df_dx_star
,
*
df_dthetas_stars
=
graph_replace
(
df_dx_star
,
*
df_dthetas_stars
=
graph_replace
(
...
...
pytensor/tensor/slinalg.py
浏览文件 @
0b4d684f
...
@@ -22,6 +22,7 @@ from pytensor.tensor.basic import as_tensor_variable, diagonal
...
@@ -22,6 +22,7 @@ from pytensor.tensor.basic import as_tensor_variable, diagonal
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.type
import
matrix
,
tensor
,
vector
from
pytensor.tensor.type
import
matrix
,
tensor
,
vector
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.utils
import
unzip
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -1323,7 +1324,7 @@ class BaseBlockDiagonal(Op):
...
@@ -1323,7 +1324,7 @@ class BaseBlockDiagonal(Op):
return
[
gout
[
0
][
slc
]
for
slc
in
slices
]
return
[
gout
[
0
][
slc
]
for
slc
in
slices
]
def
infer_shape
(
self
,
fgraph
,
nodes
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
nodes
,
shapes
):
first
,
second
=
zip
(
*
shapes
,
strict
=
True
)
first
,
second
=
unzip
(
shapes
,
n
=
2
,
strict
=
True
)
return
[(
pt
.
add
(
*
first
),
pt
.
add
(
*
second
))]
return
[(
pt
.
add
(
*
first
),
pt
.
add
(
*
second
))]
def
_validate_and_prepare_inputs
(
self
,
matrices
,
as_tensor_func
):
def
_validate_and_prepare_inputs
(
self
,
matrices
,
as_tensor_func
):
...
...
pytensor/tensor/subtensor.py
浏览文件 @
0b4d684f
...
@@ -70,6 +70,7 @@ from pytensor.tensor.type_other import (
...
@@ -70,6 +70,7 @@ from pytensor.tensor.type_other import (
make_slice
,
make_slice
,
)
)
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
from
pytensor.utils
import
unzip
_logger
=
logging
.
getLogger
(
"pytensor.tensor.subtensor"
)
_logger
=
logging
.
getLogger
(
"pytensor.tensor.subtensor"
)
...
@@ -650,7 +651,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
...
@@ -650,7 +651,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
)
)
for
basic
,
grp_dim_indices
in
idx_groups
:
for
basic
,
grp_dim_indices
in
idx_groups
:
dim_nums
,
grp_indices
=
zip
(
*
grp_dim_indices
,
strict
=
True
)
dim_nums
,
grp_indices
=
unzip
(
grp_dim_indices
,
n
=
2
,
strict
=
True
)
remaining_dims
=
tuple
(
dim
for
dim
in
remaining_dims
if
dim
not
in
dim_nums
)
remaining_dims
=
tuple
(
dim
for
dim
in
remaining_dims
if
dim
not
in
dim_nums
)
if
basic
:
if
basic
:
...
...
pytensor/utils.py
浏览文件 @
0b4d684f
...
@@ -338,3 +338,14 @@ class Singleton:
...
@@ -338,3 +338,14 @@ class Singleton:
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
type
(
self
))
return
hash
(
type
(
self
))
def
unzip
(
iterable
,
n
:
int
,
strict
:
bool
=
False
):
"""Unzip a nested iterable, returns n empty tuples if empty.
It can be safely unpacked into n variables.
"""
res
=
tuple
(
zip
(
*
iterable
,
strict
=
strict
))
if
not
res
:
return
((),)
*
n
return
res
pytensor/xtensor/shape.py
浏览文件 @
0b4d684f
...
@@ -12,6 +12,7 @@ from pytensor.tensor import as_tensor, get_scalar_constant_value
...
@@ -12,6 +12,7 @@ from pytensor.tensor import as_tensor, get_scalar_constant_value
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.type
import
integer_dtypes
from
pytensor.tensor.type
import
integer_dtypes
from
pytensor.tensor.utils
import
get_static_shape_from_size_variables
from
pytensor.tensor.utils
import
get_static_shape_from_size_variables
from
pytensor.utils
import
unzip
from
pytensor.xtensor.basic
import
XOp
from
pytensor.xtensor.basic
import
XOp
from
pytensor.xtensor.math
import
cast
,
second
from
pytensor.xtensor.math
import
cast
,
second
from
pytensor.xtensor.type
import
XTensorVariable
,
as_xtensor
,
xtensor
from
pytensor.xtensor.type
import
XTensorVariable
,
as_xtensor
,
xtensor
...
@@ -296,7 +297,7 @@ class Concat(XOp):
...
@@ -296,7 +297,7 @@ class Concat(XOp):
if
concat_dim
not
in
inp
.
type
.
dims
:
if
concat_dim
not
in
inp
.
type
.
dims
:
dims_and_shape
[
concat_dim
]
+=
1
dims_and_shape
[
concat_dim
]
+=
1
dims
,
shape
=
zip
(
*
dims_and_shape
.
items
()
)
dims
,
shape
=
unzip
(
dims_and_shape
.
items
(),
n
=
2
)
dtype
=
upcast
(
*
[
x
.
type
.
dtype
for
x
in
inputs
])
dtype
=
upcast
(
*
[
x
.
type
.
dtype
for
x
in
inputs
])
output
=
xtensor
(
dtype
=
dtype
,
dims
=
dims
,
shape
=
shape
)
output
=
xtensor
(
dtype
=
dtype
,
dims
=
dims
,
shape
=
shape
)
return
Apply
(
self
,
inputs
,
[
output
])
return
Apply
(
self
,
inputs
,
[
output
])
...
...
pytensor/xtensor/vectorization.py
浏览文件 @
0b4d684f
...
@@ -13,6 +13,7 @@ from pytensor.tensor.random.type import RandomType
...
@@ -13,6 +13,7 @@ from pytensor.tensor.random.type import RandomType
from
pytensor.tensor.utils
import
(
from
pytensor.tensor.utils
import
(
get_static_shape_from_size_variables
,
get_static_shape_from_size_variables
,
)
)
from
pytensor.utils
import
unzip
from
pytensor.xtensor.basic
import
XOp
from
pytensor.xtensor.basic
import
XOp
from
pytensor.xtensor.type
import
XTensorVariable
,
as_xtensor
,
xtensor
from
pytensor.xtensor.type
import
XTensorVariable
,
as_xtensor
,
xtensor
...
@@ -57,12 +58,7 @@ class XElemwise(XOp):
...
@@ -57,12 +58,7 @@ class XElemwise(XOp):
f
"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}"
f
"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}"
)
)
dims_and_shape
=
combine_dims_and_shape
(
inputs
)
output_dims
,
output_shape
=
unzip
(
combine_dims_and_shape
(
inputs
)
.
items
(),
n
=
2
)
if
dims_and_shape
:
output_dims
,
output_shape
=
zip
(
*
dims_and_shape
.
items
())
else
:
output_dims
,
output_shape
=
(),
()
dummy_scalars
=
[
ps
.
get_scalar_type
(
inp
.
type
.
dtype
)()
for
inp
in
inputs
]
dummy_scalars
=
[
ps
.
get_scalar_type
(
inp
.
type
.
dtype
)()
for
inp
in
inputs
]
output_dtypes
=
[
output_dtypes
=
[
out
.
type
.
dtype
for
out
in
self
.
scalar_op
.
make_node
(
*
dummy_scalars
)
.
outputs
out
.
type
.
dtype
for
out
in
self
.
scalar_op
.
make_node
(
*
dummy_scalars
)
.
outputs
...
@@ -99,8 +95,9 @@ class XBlockwise(XOp):
...
@@ -99,8 +95,9 @@ class XBlockwise(XOp):
core_inputs_dims
,
core_outputs_dims
=
self
.
core_dims
core_inputs_dims
,
core_outputs_dims
=
self
.
core_dims
core_input_dims_set
=
set
(
chain
.
from_iterable
(
core_inputs_dims
))
core_input_dims_set
=
set
(
chain
.
from_iterable
(
core_inputs_dims
))
batch_dims
,
batch_shape
=
zip
(
batch_dims
,
batch_shape
=
unzip
(
*
((
k
,
v
)
for
k
,
v
in
dims_and_shape
.
items
()
if
k
not
in
core_input_dims_set
)
((
k
,
v
)
for
k
,
v
in
dims_and_shape
.
items
()
if
k
not
in
core_input_dims_set
),
n
=
2
,
)
)
dummy_core_inputs
=
[]
dummy_core_inputs
=
[]
...
@@ -236,17 +233,16 @@ class XRV(XOp, RNGConsumerOp):
...
@@ -236,17 +233,16 @@ class XRV(XOp, RNGConsumerOp):
f
"Size dimensions {sorted(conflict_dims)} conflict with parameter dimensions. They should be unique."
f
"Size dimensions {sorted(conflict_dims)} conflict with parameter dimensions. They should be unique."
)
)
batch_dims_and_shape
=
[
batch_output_dims
,
batch_output_shape
=
unzip
(
(
dim
,
dim_length
)
(
for
dim
,
dim_length
in
(
(
dim
,
dim_length
)
extra_dims_and_shape
|
params_dims_and_shape
for
dim
,
dim_length
in
(
)
.
items
()
extra_dims_and_shape
|
params_dims_and_shape
if
dim
not
in
input_core_dims_set
)
.
items
()
]
if
dim
not
in
input_core_dims_set
if
batch_dims_and_shape
:
),
batch_output_dims
,
batch_output_shape
=
zip
(
*
batch_dims_and_shape
)
n
=
2
,
else
:
)
batch_output_dims
,
batch_output_shape
=
(),
()
dummy_core_inputs
=
[]
dummy_core_inputs
=
[]
for
param
,
core_param_dims
in
zip
(
params
,
param_core_dims
):
for
param
,
core_param_dims
in
zip
(
params
,
param_core_dims
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论