提交 bb92b35b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix non-symbolic input issues in aesara.tensor.extra_ops helper functions

上级 60cb58f3
...@@ -586,19 +586,21 @@ def squeeze(x, axis=None): ...@@ -586,19 +586,21 @@ def squeeze(x, axis=None):
`x` without `axis` dimensions. `x` without `axis` dimensions.
""" """
_x = at.as_tensor_variable(x)
if axis is None: if axis is None:
# By default exclude all broadcastable (length=1) axes # By default exclude all broadcastable (length=1) axes
axis = (i for i in range(x.ndim) if x.broadcastable[i]) axis = (i for i in range(_x.ndim) if _x.broadcastable[i])
elif not isinstance(axis, Collection): elif not isinstance(axis, Collection):
axis = (axis,) axis = (axis,)
# scalar inputs are treated as 1D regarding axis in this `Op` # scalar inputs are treated as 1D regarding axis in this `Op`
try: try:
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, x.ndim)) axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, _x.ndim))
except np.AxisError: except np.AxisError:
raise np.AxisError(axis, ndim=x.ndim) raise np.AxisError(axis, ndim=_x.ndim)
return x.dimshuffle([i for i in range(x.ndim) if i not in axis]) return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis])
def compress(condition, x, axis=None): def compress(condition, x, axis=None):
...@@ -626,8 +628,9 @@ def compress(condition, x, axis=None): ...@@ -626,8 +628,9 @@ def compress(condition, x, axis=None):
`x` with selected slices. `x` with selected slices.
""" """
_x = at.as_tensor_variable(x)
indices = at.flatnonzero(condition) indices = at.flatnonzero(condition)
return x.take(indices, axis=axis) return _x.take(indices, axis=axis)
class Repeat(Op): class Repeat(Op):
...@@ -1494,12 +1497,14 @@ def broadcast_shape_iter( ...@@ -1494,12 +1497,14 @@ def broadcast_shape_iter(
else: else:
max_dims = max(a.ndim for a in arrays) max_dims = max(a.ndim for a in arrays)
_arrays = tuple(at.as_tensor_variable(a) for a in arrays)
array_shapes = [ array_shapes = [
(one_at,) * (max_dims - a.ndim) (one_at,) * (max_dims - a.ndim)
+ tuple( + tuple(
one_at if bcast else sh for sh, bcast in zip(a.shape, a.broadcastable) one_at if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape)
) )
for a in arrays for a in _arrays
] ]
result_dims = [] result_dims = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论