提交 0b4d684f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Add empty-safe unzip helper

上级 f4196f9c
...@@ -37,7 +37,7 @@ from pytensor.tensor.utils import ( ...@@ -37,7 +37,7 @@ from pytensor.tensor.utils import (
normalize_reduce_axis, normalize_reduce_axis,
) )
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
from pytensor.utils import uniq from pytensor.utils import uniq, unzip
class DimShuffle(ExternalCOp): class DimShuffle(ExternalCOp):
...@@ -765,8 +765,8 @@ class Elemwise(OpenMPOp): ...@@ -765,8 +765,8 @@ class Elemwise(OpenMPOp):
# assert that inames and inputs order stay consistent. # assert that inames and inputs order stay consistent.
# This is to protect against future changes of uniq. # This is to protect against future changes of uniq.
assert len(inames) == len(inputs) assert len(inames) == len(inputs)
ii, iii = list( ii, iii = unzip(
zip(*uniq(list(zip(_inames, node.inputs, strict=True))), strict=True) uniq(list(zip(_inames, node.inputs, strict=True))), n=2, strict=True
) )
assert all(x == y for x, y in zip(ii, inames, strict=True)) assert all(x == y for x, y in zip(ii, inames, strict=True))
assert all(x == y for x, y in zip(iii, inputs, strict=True)) assert all(x == y for x, y in zip(iii, inputs, strict=True))
......
...@@ -35,6 +35,7 @@ from pytensor.tensor.math import tensordot ...@@ -35,6 +35,7 @@ from pytensor.tensor.math import tensordot
from pytensor.tensor.reshape import pack, unpack from pytensor.tensor.reshape import pack, unpack
from pytensor.tensor.slinalg import solve from pytensor.tensor.slinalg import solve
from pytensor.tensor.variable import TensorVariable, Variable from pytensor.tensor.variable import TensorVariable, Variable
from pytensor.utils import unzip
# scipy.optimize can be slow to import, and will not be used by most users # scipy.optimize can be slow to import, and will not be used by most users
...@@ -297,7 +298,7 @@ class ScipyScalarWrapperOp(ScipyWrapperOp): ...@@ -297,7 +298,7 @@ class ScipyScalarWrapperOp(ScipyWrapperOp):
# No differentiable arguments, return disconnected gradients # No differentiable arguments, return disconnected gradients
return arg_grads return arg_grads
outer_args_to_diff, df_dthetas = zip(*valid_args_and_grads) outer_args_to_diff, df_dthetas = unzip(valid_args_and_grads, n=2)
replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True)) replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
df_dx_star, *df_dthetas_stars = graph_replace( df_dx_star, *df_dthetas_stars = graph_replace(
......
...@@ -22,6 +22,7 @@ from pytensor.tensor.basic import as_tensor_variable, diagonal ...@@ -22,6 +22,7 @@ from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import matrix, tensor, vector from pytensor.tensor.type import matrix, tensor, vector
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
from pytensor.utils import unzip
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -1323,7 +1324,7 @@ class BaseBlockDiagonal(Op): ...@@ -1323,7 +1324,7 @@ class BaseBlockDiagonal(Op):
return [gout[0][slc] for slc in slices] return [gout[0][slc] for slc in slices]
def infer_shape(self, fgraph, nodes, shapes): def infer_shape(self, fgraph, nodes, shapes):
first, second = zip(*shapes, strict=True) first, second = unzip(shapes, n=2, strict=True)
return [(pt.add(*first), pt.add(*second))] return [(pt.add(*first), pt.add(*second))]
def _validate_and_prepare_inputs(self, matrices, as_tensor_func): def _validate_and_prepare_inputs(self, matrices, as_tensor_func):
......
...@@ -70,6 +70,7 @@ from pytensor.tensor.type_other import ( ...@@ -70,6 +70,7 @@ from pytensor.tensor.type_other import (
make_slice, make_slice,
) )
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
from pytensor.utils import unzip
_logger = logging.getLogger("pytensor.tensor.subtensor") _logger = logging.getLogger("pytensor.tensor.subtensor")
...@@ -650,7 +651,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): ...@@ -650,7 +651,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
) )
for basic, grp_dim_indices in idx_groups: for basic, grp_dim_indices in idx_groups:
dim_nums, grp_indices = zip(*grp_dim_indices, strict=True) dim_nums, grp_indices = unzip(grp_dim_indices, n=2, strict=True)
remaining_dims = tuple(dim for dim in remaining_dims if dim not in dim_nums) remaining_dims = tuple(dim for dim in remaining_dims if dim not in dim_nums)
if basic: if basic:
......
...@@ -338,3 +338,14 @@ class Singleton: ...@@ -338,3 +338,14 @@ class Singleton:
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def unzip(iterable, n: int, strict: bool = False):
"""Unzip a nested iterable, returns n empty tuples if empty.
It can be safely unpacked into n variables.
"""
res = tuple(zip(*iterable, strict=strict))
if not res:
return ((),) * n
return res
...@@ -12,6 +12,7 @@ from pytensor.tensor import as_tensor, get_scalar_constant_value ...@@ -12,6 +12,7 @@ from pytensor.tensor import as_tensor, get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import integer_dtypes from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.utils import get_static_shape_from_size_variables from pytensor.tensor.utils import get_static_shape_from_size_variables
from pytensor.utils import unzip
from pytensor.xtensor.basic import XOp from pytensor.xtensor.basic import XOp
from pytensor.xtensor.math import cast, second from pytensor.xtensor.math import cast, second
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
...@@ -296,7 +297,7 @@ class Concat(XOp): ...@@ -296,7 +297,7 @@ class Concat(XOp):
if concat_dim not in inp.type.dims: if concat_dim not in inp.type.dims:
dims_and_shape[concat_dim] += 1 dims_and_shape[concat_dim] += 1
dims, shape = zip(*dims_and_shape.items()) dims, shape = unzip(dims_and_shape.items(), n=2)
dtype = upcast(*[x.type.dtype for x in inputs]) dtype = upcast(*[x.type.dtype for x in inputs])
output = xtensor(dtype=dtype, dims=dims, shape=shape) output = xtensor(dtype=dtype, dims=dims, shape=shape)
return Apply(self, inputs, [output]) return Apply(self, inputs, [output])
......
...@@ -13,6 +13,7 @@ from pytensor.tensor.random.type import RandomType ...@@ -13,6 +13,7 @@ from pytensor.tensor.random.type import RandomType
from pytensor.tensor.utils import ( from pytensor.tensor.utils import (
get_static_shape_from_size_variables, get_static_shape_from_size_variables,
) )
from pytensor.utils import unzip
from pytensor.xtensor.basic import XOp from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
...@@ -57,12 +58,7 @@ class XElemwise(XOp): ...@@ -57,12 +58,7 @@ class XElemwise(XOp):
f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}" f"Wrong number of inputs, expected {self.scalar_op.nin}, got {len(inputs)}"
) )
dims_and_shape = combine_dims_and_shape(inputs) output_dims, output_shape = unzip(combine_dims_and_shape(inputs).items(), n=2)
if dims_and_shape:
output_dims, output_shape = zip(*dims_and_shape.items())
else:
output_dims, output_shape = (), ()
dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs] dummy_scalars = [ps.get_scalar_type(inp.type.dtype)() for inp in inputs]
output_dtypes = [ output_dtypes = [
out.type.dtype for out in self.scalar_op.make_node(*dummy_scalars).outputs out.type.dtype for out in self.scalar_op.make_node(*dummy_scalars).outputs
...@@ -99,8 +95,9 @@ class XBlockwise(XOp): ...@@ -99,8 +95,9 @@ class XBlockwise(XOp):
core_inputs_dims, core_outputs_dims = self.core_dims core_inputs_dims, core_outputs_dims = self.core_dims
core_input_dims_set = set(chain.from_iterable(core_inputs_dims)) core_input_dims_set = set(chain.from_iterable(core_inputs_dims))
batch_dims, batch_shape = zip( batch_dims, batch_shape = unzip(
*((k, v) for k, v in dims_and_shape.items() if k not in core_input_dims_set) ((k, v) for k, v in dims_and_shape.items() if k not in core_input_dims_set),
n=2,
) )
dummy_core_inputs = [] dummy_core_inputs = []
...@@ -236,17 +233,16 @@ class XRV(XOp, RNGConsumerOp): ...@@ -236,17 +233,16 @@ class XRV(XOp, RNGConsumerOp):
f"Size dimensions {sorted(conflict_dims)} conflict with parameter dimensions. They should be unique." f"Size dimensions {sorted(conflict_dims)} conflict with parameter dimensions. They should be unique."
) )
batch_dims_and_shape = [ batch_output_dims, batch_output_shape = unzip(
(dim, dim_length) (
for dim, dim_length in ( (dim, dim_length)
extra_dims_and_shape | params_dims_and_shape for dim, dim_length in (
).items() extra_dims_and_shape | params_dims_and_shape
if dim not in input_core_dims_set ).items()
] if dim not in input_core_dims_set
if batch_dims_and_shape: ),
batch_output_dims, batch_output_shape = zip(*batch_dims_and_shape) n=2,
else: )
batch_output_dims, batch_output_shape = (), ()
dummy_core_inputs = [] dummy_core_inputs = []
for param, core_param_dims in zip(params, param_core_dims): for param, core_param_dims in zip(params, param_core_dims):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论