Unverified 提交 b0ba476b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Added `aesara.tensor.full_like` equivalent to `np.full_like` (#567)

上级 41253900
......@@ -1613,6 +1613,24 @@ def full(shape, fill_value, dtype=None):
return alloc(fill_value, *shape)
def full_like(
a: TensorVariable,
fill_value: Union[TensorVariable, int, float],
dtype: Union[str, np.generic, np.dtype] = None,
) -> TensorVariable:
"""Equivalent of `numpy.full_like`.
Returns
-------
tensor
tensor the shape of `a` containing `fill_value` of the type of dtype.
"""
fill_value = as_tensor_variable(fill_value)
if dtype is not None:
fill_value = fill_value.astype(dtype)
return fill(a, fill_value)
class MakeVector(COp):
"""Concatenate a number of scalars together into a vector.
......@@ -4482,6 +4500,7 @@ __all__ = [
"as_tensor",
"extract_diag",
"full",
"full_like",
"empty",
"empty_like",
]
......@@ -55,6 +55,7 @@ from aesara.tensor.basic import (
fill,
flatnonzero,
flatten,
full_like,
get_scalar_constant_value,
get_vector_length,
horizontal_stack,
......@@ -4216,3 +4217,20 @@ class TestTakeAlongAxis:
indices = aet.tensor(np.float64, [False] * 2)
with pytest.raises(IndexError):
aet.take_along_axis(arr, indices)
@pytest.mark.parametrize(
"inp, shape",
[(scalar, ()), (vector, 3), (matrix, (3, 4))],
)
def test_full_like(inp, shape):
fill_value = 5
dtype = config.floatX
x = inp("x")
y = full_like(x, fill_value, dtype=dtype)
np.testing.assert_array_equal(
y.eval({x: np.zeros(shape, dtype=dtype)}),
np.full(shape, fill_value, dtype=dtype),
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论