Unverified 提交 e40c8274 authored 作者: ltoniazzi's avatar ltoniazzi 提交者: GitHub

Add tensor conversion to `flatnonzero`, `nonzero_values`, `tile`, `inverse_permutation`, and `diag`

上级 c2ed818e
...@@ -940,9 +940,10 @@ def flatnonzero(a): ...@@ -940,9 +940,10 @@ def flatnonzero(a):
nonzero_values : Return the non-zero elements of the input array nonzero_values : Return the non-zero elements of the input array
""" """
if a.ndim == 0: _a = as_tensor_variable(a)
if _a.ndim == 0:
raise ValueError("Nonzero only supports non-scalar arrays.") raise ValueError("Nonzero only supports non-scalar arrays.")
return nonzero(a.flatten(), return_matrix=False)[0] return nonzero(_a.flatten(), return_matrix=False)[0]
def nonzero_values(a): def nonzero_values(a):
...@@ -1324,9 +1325,10 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None): ...@@ -1324,9 +1325,10 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
tensor tensor
tensor the shape of x with ones on main diagonal and zeroes elsewhere of type of dtype. tensor the shape of x with ones on main diagonal and zeroes elsewhere of type of dtype.
""" """
_x = as_tensor_variable(x)
if dtype is None: if dtype is None:
dtype = x.dtype dtype = _x.dtype
return eye(x.shape[0], x.shape[1], k=0, dtype=dtype) return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)
def infer_broadcastable(shape): def infer_broadcastable(shape):
...@@ -2773,8 +2775,9 @@ def tile(x, reps, ndim=None): ...@@ -2773,8 +2775,9 @@ def tile(x, reps, ndim=None):
""" """
from aesara.tensor.math import ge from aesara.tensor.math import ge
if ndim is not None and ndim < x.ndim: _x = as_tensor_variable(x)
raise ValueError("ndim should be equal or larger than x.ndim") if ndim is not None and ndim < _x.ndim:
raise ValueError("ndim should be equal or larger than _x.ndim")
# If reps is a scalar, integer or vector, we convert it to a list. # If reps is a scalar, integer or vector, we convert it to a list.
if not isinstance(reps, (list, tuple)): if not isinstance(reps, (list, tuple)):
...@@ -2799,8 +2802,8 @@ def tile(x, reps, ndim=None): ...@@ -2799,8 +2802,8 @@ def tile(x, reps, ndim=None):
# assert that reps.shape[0] does not exceed ndim # assert that reps.shape[0] does not exceed ndim
offset = assert_op(offset, ge(offset, 0)) offset = assert_op(offset, ge(offset, 0))
# if reps.ndim is less than x.ndim, we pad the reps with # if reps.ndim is less than _x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as x. # "1" so that reps will have the same ndim as _x.
reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)] reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)]
reps = reps_ reps = reps_
...@@ -2817,17 +2820,17 @@ def tile(x, reps, ndim=None): ...@@ -2817,17 +2820,17 @@ def tile(x, reps, ndim=None):
): ):
raise ValueError("elements of reps must be scalars of integer dtype") raise ValueError("elements of reps must be scalars of integer dtype")
# If reps.ndim is less than x.ndim, we pad the reps with # If reps.ndim is less than _x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as x # "1" so that reps will have the same ndim as _x
reps = list(reps) reps = list(reps)
if ndim is None: if ndim is None:
ndim = builtins.max(len(reps), x.ndim) ndim = builtins.max(len(reps), _x.ndim)
if len(reps) < ndim: if len(reps) < ndim:
reps = [1] * (ndim - len(reps)) + reps reps = [1] * (ndim - len(reps)) + reps
_shape = [1] * (ndim - x.ndim) + [x.shape[i] for i in range(x.ndim)] _shape = [1] * (ndim - _x.ndim) + [_x.shape[i] for i in range(_x.ndim)]
alloc_shape = reps + _shape alloc_shape = reps + _shape
y = alloc(x, *alloc_shape) y = alloc(_x, *alloc_shape)
shuffle_ind = np.arange(ndim * 2).reshape(2, ndim) shuffle_ind = np.arange(ndim * 2).reshape(2, ndim)
shuffle_ind = shuffle_ind.transpose().flatten() shuffle_ind = shuffle_ind.transpose().flatten()
y = y.dimshuffle(*shuffle_ind) y = y.dimshuffle(*shuffle_ind)
...@@ -3288,8 +3291,9 @@ def inverse_permutation(perm): ...@@ -3288,8 +3291,9 @@ def inverse_permutation(perm):
Each row of input should contain a permutation of the first integers. Each row of input should contain a permutation of the first integers.
""" """
_perm = as_tensor_variable(perm)
return permute_row_elements( return permute_row_elements(
arange(perm.shape[-1], dtype=perm.dtype), perm, inverse=True arange(_perm.shape[-1], dtype=_perm.dtype), _perm, inverse=True
) )
...@@ -3575,12 +3579,14 @@ def diag(v, k=0): ...@@ -3575,12 +3579,14 @@ def diag(v, k=0):
""" """
if v.ndim == 1: _v = as_tensor_variable(v)
return AllocDiag(k)(v)
elif v.ndim >= 2: if _v.ndim == 1:
return diagonal(v, offset=k) return AllocDiag(k)(_v)
elif _v.ndim >= 2:
return diagonal(_v, offset=k)
else: else:
raise ValueError("Input must has v.ndim >= 1.") raise ValueError("Number of dimensions of `v` must be greater than one.")
def stacklists(arg): def stacklists(arg):
......
...@@ -1023,6 +1023,12 @@ class TestNonzero: ...@@ -1023,6 +1023,12 @@ class TestNonzero:
rand2d[:4] = 0 rand2d[:4] = 0
check(rand2d) check(rand2d)
# Test passing a list
m = [1, 2, 0]
out = flatnonzero(m)
f = function([], out)
assert np.array_equal(f(), np.flatnonzero(m))
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
def test_nonzero_values(self): def test_nonzero_values(self):
def check(m): def check(m):
...@@ -1449,8 +1455,6 @@ class TestJoinAndSplit: ...@@ -1449,8 +1455,6 @@ class TestJoinAndSplit:
assert (out == want).all() assert (out == want).all()
# Pass a list to make sure `a` is converted to a
# TensorVariable by roll
a = [1, 2, 3, 4, 5, 6] a = [1, 2, 3, 4, 5, 6]
b = roll(a, get_shift(2)) b = roll(a, get_shift(2))
want = np.array([5, 6, 1, 2, 3, 4]) want = np.array([5, 6, 1, 2, 3, 4])
...@@ -2221,6 +2225,20 @@ def test_tile(): ...@@ -2221,6 +2225,20 @@ def test_tile():
== np.tile(x_, (2, 3, 4, 6)) == np.tile(x_, (2, 3, 4, 6))
) )
# Test passing a float
x = scalar()
x_val = 1.0
assert np.array_equal(
run_tile(x, x_val, (2,), use_symbolic_reps), np.tile(x_val, (2,))
)
# Test when x is a list
x = matrix()
x_val = [[1.0, 2.0], [3.0, 4.0]]
assert np.array_equal(
run_tile(x, x_val, (2,), use_symbolic_reps), np.tile(x_val, (2,))
)
# Test when reps is integer, scalar or vector. # Test when reps is integer, scalar or vector.
# Test 1,2,3,4-dimensional cases. # Test 1,2,3,4-dimensional cases.
# Test input x has the shape [2], [2, 4], [2, 4, 3], [2, 4, 3, 5]. # Test input x has the shape [2], [2, 4], [2, 4, 3], [2, 4, 3, 5].
...@@ -2794,6 +2812,12 @@ class TestInversePermutation: ...@@ -2794,6 +2812,12 @@ class TestInversePermutation:
assert np.all(p_val[inv_val] == np.arange(10)) assert np.all(p_val[inv_val] == np.arange(10))
assert np.all(inv_val[p_val] == np.arange(10)) assert np.all(inv_val[p_val] == np.arange(10))
# Test passing a list
p = [2, 4, 3, 0, 1]
inv = at.inverse_permutation(p)
f = aesara.function([], inv)
assert np.array_equal(f(), np.array([3, 4, 0, 2, 1]))
def test_dim2(self): def test_dim2(self):
# Test the inversion of several permutations at a time # Test the inversion of several permutations at a time
# Each row of p is a different permutation to inverse # Each row of p is a different permutation to inverse
...@@ -3449,6 +3473,12 @@ class TestDiag: ...@@ -3449,6 +3473,12 @@ class TestDiag:
with pytest.raises(ValueError): with pytest.raises(ValueError):
diag(xx) diag(xx)
# Test passing a list
xx = [[1, 2], [3, 4]]
g = diag(xx)
f = function([], g)
assert np.array_equal(f(), np.diag(xx))
def test_infer_shape(self): def test_infer_shape(self):
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
...@@ -4136,6 +4166,12 @@ def test_identity_like_dtype(): ...@@ -4136,6 +4166,12 @@ def test_identity_like_dtype():
m_out_float = identity_like(m, dtype=np.float64) m_out_float = identity_like(m, dtype=np.float64)
assert m_out_float.dtype == "float64" assert m_out_float.dtype == "float64"
# Test passing list
m = [[0, 1], [1, 3]]
out = at.identity_like(m)
f = aesara.function([], out)
assert np.array_equal(f(), np.eye(2))
def test_atleast_Nd(): def test_atleast_Nd():
ary1 = dscalar() ary1 = dscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论