提交 f191a07e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add shape_tuple helper function

上级 0cf56c68
...@@ -42,6 +42,7 @@ from aesara.tensor.shape import ( ...@@ -42,6 +42,7 @@ from aesara.tensor.shape import (
shape_padaxis, shape_padaxis,
shape_padleft, shape_padleft,
shape_padright, shape_padright,
shape_tuple,
) )
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
...@@ -4152,7 +4153,7 @@ class Choose(Op): ...@@ -4152,7 +4153,7 @@ class Choose(Op):
else: else:
choice = as_tensor_variable(choices) choice = as_tensor_variable(choices)
(out_shape,) = self.infer_shape( (out_shape,) = self.infer_shape(
None, None, [tuple(a.shape), tuple(shape(choice))] None, None, [shape_tuple(a), shape_tuple(choice)]
) )
bcast = [] bcast = []
......
...@@ -20,6 +20,7 @@ from aesara.tensor.elemwise import Elemwise ...@@ -20,6 +20,7 @@ from aesara.tensor.elemwise import Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.random.type import RandomType from aesara.tensor.random.type import RandomType
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from aesara.tensor.shape import shape_tuple
from aesara.tensor.type import TensorType, all_dtypes from aesara.tensor.type import TensorType, all_dtypes
...@@ -199,7 +200,7 @@ class RandomVariable(Op): ...@@ -199,7 +200,7 @@ class RandomVariable(Op):
# Broadcast the parameters # Broadcast the parameters
param_shapes = params_broadcast_shapes( param_shapes = params_broadcast_shapes(
param_shapes or [p.shape for p in dist_params], self.ndims_params param_shapes or [shape_tuple(p) for p in dist_params], self.ndims_params
) )
def slice_ind_dims(p, ps, n): def slice_ind_dims(p, ps, n):
......
...@@ -129,6 +129,19 @@ shape = Shape() ...@@ -129,6 +129,19 @@ shape = Shape()
_shape = shape # was used in the past, now use shape directly. _shape = shape # was used in the past, now use shape directly.
def shape_tuple(x):
"""Get a tuple of symbolic shape values.
This will return a `ScalarConstant` with the value ``1`` wherever
broadcastable is ``True``.
"""
one_at = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1)
return tuple(
one_at if getattr(sh, "value", sh) == 1 or bcast else sh
for sh, bcast in zip(shape(x), getattr(x, "broadcastable", (False,) * x.ndim))
)
class Shape_i(COp): class Shape_i(COp):
""" """
L{Op} to return the shape of a matrix. L{Op} to return the shape of a matrix.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论