提交 db449148 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix non-symbolic input issues in aesara.tensor.basic helper functions

上级 1704f22a
...@@ -926,13 +926,15 @@ def ones_like(model, dtype=None, opt=False): ...@@ -926,13 +926,15 @@ def ones_like(model, dtype=None, opt=False):
tensor tensor
tensor the shape of model containing ones of the type of dtype. tensor the shape of model containing ones of the type of dtype.
""" """
_model = as_tensor_variable(model)
if dtype is None: if dtype is None:
dtype = model.type.dtype dtype = _model.type.dtype
ret = constant(1.0, dtype=dtype) ret = constant(1.0, dtype=dtype)
# TODO: Remove this weird option # TODO: Remove this weird option
if opt and ret.type == model.type: if opt and ret.type == _model.type:
return ret return ret
return fill(model, ret) return fill(_model, ret)
def zeros_like(model, dtype=None, opt=False): def zeros_like(model, dtype=None, opt=False):
...@@ -951,13 +953,15 @@ def zeros_like(model, dtype=None, opt=False): ...@@ -951,13 +953,15 @@ def zeros_like(model, dtype=None, opt=False):
tensor the shape of model containing zeros of the type of dtype. tensor the shape of model containing zeros of the type of dtype.
""" """
_model = as_tensor_variable(model)
if dtype is None: if dtype is None:
dtype = model.type.dtype dtype = _model.type.dtype
ret = constant(0.0, dtype=dtype) ret = constant(0.0, dtype=dtype)
# TODO: Remove this weird option # TODO: Remove this weird option
if opt and ret.type == model.type: if opt and ret.type == _model.type:
return ret return ret
return fill(model, ret) return fill(_model, ret)
def zeros(shape, dtype=None): def zeros(shape, dtype=None):
...@@ -1122,7 +1126,8 @@ def nonzero_values(a): ...@@ -1122,7 +1126,8 @@ def nonzero_values(a):
flattened input array. flattened input array.
""" """
return a.flatten()[flatnonzero(a)] _a = as_tensor_variable(a)
return _a.flatten()[flatnonzero(_a)]
class Tri(Op): class Tri(Op):
...@@ -1915,11 +1920,12 @@ def transpose(x, axes=None): ...@@ -1915,11 +1920,12 @@ def transpose(x, axes=None):
This is a macro around dimshuffle that matches the numpy.transpose function. This is a macro around dimshuffle that matches the numpy.transpose function.
""" """
_x = as_tensor_variable(x)
if axes is None: if axes is None:
axes = list(range((x.ndim - 1), -1, -1)) axes = list(range((_x.ndim - 1), -1, -1))
ret = DimShuffle(x.broadcastable, axes)(x) ret = DimShuffle(_x.broadcastable, axes)(_x)
if x.name and axes == list(range((x.ndim - 1), -1, -1)): if _x.name and axes == list(range((_x.ndim - 1), -1, -1)):
ret.name = x.name + ".T" ret.name = _x.name + ".T"
return ret return ret
...@@ -2802,31 +2808,41 @@ def concatenate(tensor_list, axis=0): ...@@ -2802,31 +2808,41 @@ def concatenate(tensor_list, axis=0):
def horizontal_stack(*args): def horizontal_stack(*args):
""" r"""Stack arrays in sequence horizontally (column wise)."""
Horizontally stack two L{TensorType}s.
Stack two L{TensorType}s along the second axis (column wise). These
L{TensorType}s must have the same shape along all dimensions but the
second.
"""
# Note: 'horizontal_stack' and 'vertical_stack' do not behave exactly like # Note: 'horizontal_stack' and 'vertical_stack' do not behave exactly like
# Numpy's hstack and vstack functions. This is intended, because Numpy's # Numpy's hstack and vstack functions. This is intended, because Numpy's
# functions have potentially confusing/incoherent behavior (try them on 1D # functions have potentially confusing/incoherent behavior (try them on 1D
# arrays). If this is fixed in a future version of Numpy, it may be worth # arrays). If this is fixed in a future version of Numpy, it may be worth
# trying to get closer to Numpy's way of doing things. In the meantime, # trying to get closer to Numpy's way of doing things. In the meantime,
# better keep different names to emphasize the implementation divergences. # better keep different names to emphasize the implementation divergences.
assert len(args) >= 2
if len(args) < 2:
raise ValueError("Too few arguments")
_args = []
for arg in args: for arg in args:
assert arg.type.ndim == 2 _arg = as_tensor_variable(arg)
return concatenate(args, axis=1) if _arg.type.ndim != 2:
raise ValueError("All arguments must have two dimensions")
_args.append(_arg)
return concatenate(_args, axis=1)
def vertical_stack(*args): def vertical_stack(*args):
assert len(args) >= 2 r"""Stack arrays in sequence vertically (row wise)."""
if len(args) < 2:
raise ValueError("Too few arguments")
_args = []
for arg in args: for arg in args:
assert arg.type.ndim == 2 _arg = as_tensor_variable(arg)
return concatenate(args, axis=0) if _arg.type.ndim != 2:
raise ValueError("All arguments must have two dimensions")
_args.append(_arg)
return concatenate(_args, axis=0)
class Flatten(COp): class Flatten(COp):
...@@ -3042,19 +3058,21 @@ def flatten(x, ndim=1): ...@@ -3042,19 +3058,21 @@ def flatten(x, ndim=1):
if ndim is None: if ndim is None:
ndim = 1 ndim = 1
_x = as_tensor_variable(x)
# Any input variable can be flattened to have ndim of 1, # Any input variable can be flattened to have ndim of 1,
# even if it's a scalar. Otherwise, ndim must be positive # even if it's a scalar. Otherwise, ndim must be positive
# and smaller than x.ndim. # and smaller than x.ndim.
if ndim < 1 or (ndim > 1 and ndim > x.ndim): if ndim < 1 or (ndim > 1 and ndim > _x.ndim):
raise ValueError(f"ndim {ndim} out of bound [1, {x.ndim + 1})") raise ValueError(f"ndim {ndim} out of bound [1, {_x.ndim + 1})")
if ndim > 1: if ndim > 1:
dims = tuple(x.shape[: ndim - 1]) + (-1,) dims = tuple(_x.shape[: ndim - 1]) + (-1,)
else: else:
dims = (-1,) dims = (-1,)
x_reshaped = x.reshape(dims) x_reshaped = _x.reshape(dims)
bcast_kept_dims = x.broadcastable[: ndim - 1] bcast_kept_dims = _x.broadcastable[: ndim - 1]
bcast_new_dim = builtins.all(x.broadcastable[ndim - 1 :]) bcast_new_dim = builtins.all(_x.broadcastable[ndim - 1 :])
broadcastable = bcast_kept_dims + (bcast_new_dim,) broadcastable = bcast_kept_dims + (bcast_new_dim,)
x_reshaped = addbroadcast(x_reshaped, *[i for i in range(ndim) if broadcastable[i]]) x_reshaped = addbroadcast(x_reshaped, *[i for i in range(ndim) if broadcastable[i]])
return x_reshaped return x_reshaped
...@@ -4260,7 +4278,8 @@ def empty(shape, dtype=None): ...@@ -4260,7 +4278,8 @@ def empty(shape, dtype=None):
def empty_like( def empty_like(
prototype: TensorVariable, dtype: Optional[Union[str, np.generic, np.dtype]] = None prototype: Union[np.ndarray, TensorVariable],
dtype: Optional[Union[str, np.generic, np.dtype]] = None,
) -> TensorVariable: ) -> TensorVariable:
"""Return a new array with the same shape and type as a given array. """Return a new array with the same shape and type as a given array.
......
...@@ -4529,3 +4529,14 @@ def test_full_like(inp, shape): ...@@ -4529,3 +4529,14 @@ def test_full_like(inp, shape):
y.eval({x: np.zeros(shape, dtype=dtype)}), y.eval({x: np.zeros(shape, dtype=dtype)}),
np.full(shape, fill_value, dtype=dtype), np.full(shape, fill_value, dtype=dtype),
) )
@pytest.mark.parametrize("func", [horizontal_stack, vertical_stack])
def test_oriented_stack_functions(func):
with pytest.raises(ValueError):
func()
a = at.tensor(np.float64, shape=(None, None, None))
with pytest.raises(ValueError):
func(a, a)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论