提交 3c1a6c33 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add missing NumPy interface functions `numpy.full` and `numpy.empty`

上级 7a81d896
...@@ -1581,6 +1581,28 @@ alloc = Alloc() ...@@ -1581,6 +1581,28 @@ alloc = Alloc()
pprint.assign(alloc, printing.FunctionPrinter("alloc")) pprint.assign(alloc, printing.FunctionPrinter("alloc"))
def full(shape, fill_value, dtype=None):
"""Return a new array of given shape and type, filled with `fill_value`.
See ``numpy.full``.
Parameters
----------
shape : int or sequence of ints
Shape of the new array, e.g., ``(2, 3)`` or ``2``.
fill_value : scalar or array_like
Fill value.
dtype : data-type, optional
The desired data-type for the array The default, None, means
`np.array(fill_value).dtype`.
"""
fill_value = as_tensor_variable(fill_value)
if dtype:
fill_value = fill_value.astype(dtype)
return alloc(fill_value, *shape)
class MakeVector(COp): class MakeVector(COp):
"""Concatenate a number of scalars together into a vector. """Concatenate a number of scalars together into a vector.
...@@ -4249,6 +4271,24 @@ class AllocEmpty(COp): ...@@ -4249,6 +4271,24 @@ class AllocEmpty(COp):
return [zeros(inputs, self.dtype)] return [zeros(inputs, self.dtype)]
def empty(shape, dtype=None):
"""Return a new array of given shape and type, without initializing entries.
See ``numpy.empty``.
Parameters
----------
shape : int or tuple of int
Shape of the empty array, e.g., ``(2, 3)`` or ``2``.
dtype : data-type, optional
Desired output data-type for the array, e.g, `numpy.int8`. Default is
`numpy.float64`.
"""
if dtype is None:
dtype = config.floatX
return AllocEmpty(dtype)(*shape)
__all__ = [ __all__ = [
"choose", "choose",
"swapaxes", "swapaxes",
...@@ -4305,4 +4345,6 @@ __all__ = [ ...@@ -4305,4 +4345,6 @@ __all__ = [
"as_tensor_variable", "as_tensor_variable",
"as_tensor", "as_tensor",
"extract_diag", "extract_diag",
"full",
"empty",
] ]
...@@ -643,6 +643,11 @@ class TestAlloc: ...@@ -643,6 +643,11 @@ class TestAlloc:
inp = np.zeros(shp, dtype=config.floatX) inp = np.zeros(shp, dtype=config.floatX)
assert np.allclose(zeros_tensor(inp), np.zeros(shp)) assert np.allclose(zeros_tensor(inp), np.zeros(shp))
def test_full(self):
full_at = aet.full((2, 3), 3, dtype="int64")
res = aesara.function([], full_at, mode=self.mode)()
assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))
# This is slow for the ('int8', 3) version. # This is slow for the ('int8', 3) version.
def test_eye(): def test_eye():
...@@ -4104,3 +4109,12 @@ def test_allocempty(): ...@@ -4104,3 +4109,12 @@ def test_allocempty():
assert out.shape == (2, 3) assert out.shape == (2, 3)
assert out.dtype == "float32" assert out.dtype == "float32"
empty_at = aet.empty((2, 3), dtype=None)
res = aesara.function([], empty_at)()
assert res.shape == (2, 3)
empty_at = aet.empty((2, 3), dtype="int64")
res = aesara.function([], empty_at)()
assert res.shape == (2, 3)
assert res.dtype == "int64"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论