提交 f191a07e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add shape_tuple helper function

上级 0cf56c68
......@@ -42,6 +42,7 @@ from aesara.tensor.shape import (
shape_padaxis,
shape_padleft,
shape_padright,
shape_tuple,
)
from aesara.tensor.type import (
TensorType,
......@@ -4152,7 +4153,7 @@ class Choose(Op):
else:
choice = as_tensor_variable(choices)
(out_shape,) = self.infer_shape(
None, None, [tuple(a.shape), tuple(shape(choice))]
None, None, [shape_tuple(a), shape_tuple(choice)]
)
bcast = []
......
......@@ -20,6 +20,7 @@ from aesara.tensor.elemwise import Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.random.type import RandomType
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from aesara.tensor.shape import shape_tuple
from aesara.tensor.type import TensorType, all_dtypes
......@@ -199,7 +200,7 @@ class RandomVariable(Op):
# Broadcast the parameters
param_shapes = params_broadcast_shapes(
param_shapes or [p.shape for p in dist_params], self.ndims_params
param_shapes or [shape_tuple(p) for p in dist_params], self.ndims_params
)
def slice_ind_dims(p, ps, n):
......
......@@ -129,6 +129,19 @@ shape = Shape()
_shape = shape # was used in the past, now use shape directly.
def shape_tuple(x):
"""Get a tuple of symbolic shape values.
This will return a `ScalarConstant` with the value ``1`` wherever
broadcastable is ``True``.
"""
one_at = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1)
return tuple(
one_at if getattr(sh, "value", sh) == 1 or bcast else sh
for sh, bcast in zip(shape(x), getattr(x, "broadcastable", (False,) * x.ndim))
)
class Shape_i(COp):
"""
L{Op} to return the shape of a matrix.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论