提交 0b4d684f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Add empty-safe unzip helper

上级 f4196f9c
......@@ -37,7 +37,7 @@ from pytensor.tensor.utils import (
normalize_reduce_axis,
)
from pytensor.tensor.variable import TensorVariable
from pytensor.utils import uniq
from pytensor.utils import uniq, unzip
class DimShuffle(ExternalCOp):
......@@ -765,8 +765,8 @@ class Elemwise(OpenMPOp):
# assert that inames and inputs order stay consistent.
# This is to protect against future changes of uniq.
assert len(inames) == len(inputs)
ii, iii = list(
zip(*uniq(list(zip(_inames, node.inputs, strict=True))), strict=True)
ii, iii = unzip(
uniq(list(zip(_inames, node.inputs, strict=True))), n=2, strict=True
)
assert all(x == y for x, y in zip(ii, inames, strict=True))
assert all(x == y for x, y in zip(iii, inputs, strict=True))
......
......@@ -35,6 +35,7 @@ from pytensor.tensor.math import tensordot
from pytensor.tensor.reshape import pack, unpack
from pytensor.tensor.slinalg import solve
from pytensor.tensor.variable import TensorVariable, Variable
from pytensor.utils import unzip
# scipy.optimize can be slow to import, and will not be used by most users
......@@ -297,7 +298,7 @@ class ScipyScalarWrapperOp(ScipyWrapperOp):
# No differentiable arguments, return disconnected gradients
return arg_grads
outer_args_to_diff, df_dthetas = zip(*valid_args_and_grads)
outer_args_to_diff, df_dthetas = unzip(valid_args_and_grads, n=2)
replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
df_dx_star, *df_dthetas_stars = graph_replace(
......
......@@ -22,6 +22,7 @@ from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.variable import TensorVariable
from pytensor.utils import unzip
logger = logging.getLogger(__name__)
......@@ -1323,7 +1324,7 @@ class BaseBlockDiagonal(Op):
return [gout[0][slc] for slc in slices]
def infer_shape(self, fgraph, nodes, shapes):
first, second = zip(*shapes, strict=True)
first, second = unzip(shapes, n=2, strict=True)
return [(pt.add(*first), pt.add(*second))]
def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
......
......@@ -70,6 +70,7 @@ from pytensor.tensor.type_other import (
make_slice,
)
from pytensor.tensor.variable import TensorConstant, TensorVariable
from pytensor.utils import unzip
_logger = logging.getLogger("pytensor.tensor.subtensor")
......@@ -650,7 +651,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
)
for basic, grp_dim_indices in idx_groups:
dim_nums, grp_indices = zip(*grp_dim_indices, strict=True)
dim_nums, grp_indices = unzip(grp_dim_indices, n=2, strict=True)
remaining_dims = tuple(dim for dim in remaining_dims if dim not in dim_nums)
if basic:
......
......@@ -338,3 +338,14 @@ class Singleton:
def __hash__(self):
return hash(type(self))
def unzip(iterable, n: int, strict: bool = False):
"""Unzip a nested iterable, returns n empty tuples if empty.
It can be safely unpacked into n variables.
"""
res = tuple(zip(*iterable, strict=strict))
if not res:
return ((),) * n
return res
......@@ -12,6 +12,7 @@ from pytensor.tensor import as_tensor, get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.utils import get_static_shape_from_size_variables
from pytensor.utils import unzip
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.math import cast, second
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
......@@ -296,7 +297,7 @@ class Concat(XOp):
if concat_dim not in inp.type.dims:
dims_and_shape[concat_dim] += 1
dims, shape = zip(*dims_and_shape.items())
dims, shape = unzip(dims_and_shape.items(), n=2)
dtype = upcast(*[x.type.dtype for x in inputs])
output = xtensor(dtype=dtype, dims=dims, shape=shape)
return Apply(self, inputs, [output])
......
......@@ -13,6 +13,7 @@ from pytensor.tensor.random.type import RandomType
from pytensor.tensor.utils import (
get_static_shape_from_size_variables,
)
from pytensor.utils import unzip
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
......@@ -57,12 +58,7 @@ class XElemwise(XOp):
f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}"
)
dims_and_shape = combine_dims_and_shape(inputs)
if dims_and_shape:
output_dims, output_shape = zip(*dims_and_shape.items())
else:
output_dims, output_shape = (), ()
output_dims, output_shape = unzip(combine_dims_and_shape(inputs).items(), n=2)
dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs]
output_dtypes = [
out.type.dtype for out in self.scalar_op.make_node(*dummy_scalars).outputs
......@@ -99,8 +95,9 @@ class XBlockwise(XOp):
core_inputs_dims, core_outputs_dims = self.core_dims
core_input_dims_set = set(chain.from_iterable(core_inputs_dims))
batch_dims, batch_shape = zip(
*((k, v) for k, v in dims_and_shape.items() if k not in core_input_dims_set)
batch_dims, batch_shape = unzip(
((k, v) for k, v in dims_and_shape.items() if k not in core_input_dims_set),
n=2,
)
dummy_core_inputs = []
......@@ -236,17 +233,16 @@ class XRV(XOp, RNGConsumerOp):
f"Size dimensions {sorted(conflict_dims)} conflict with parameter dimensions. They should be unique."
)
batch_dims_and_shape = [
batch_output_dims, batch_output_shape = unzip(
(
(dim, dim_length)
for dim, dim_length in (
extra_dims_and_shape | params_dims_and_shape
).items()
if dim not in input_core_dims_set
]
if batch_dims_and_shape:
batch_output_dims, batch_output_shape = zip(*batch_dims_and_shape)
else:
batch_output_dims, batch_output_shape = (), ()
),
n=2,
)
dummy_core_inputs = []
for param, core_param_dims in zip(params, param_core_dims):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论