提交 e51e8787 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Thomas Wiecki

Add an empty_like implementation

上级 61c3e37b
...@@ -12,7 +12,7 @@ from collections import OrderedDict ...@@ -12,7 +12,7 @@ from collections import OrderedDict
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from typing import Dict, Tuple, Union from typing import Dict, Optional, Tuple, Union
import numpy as np import numpy as np
from numpy.core.multiarray import normalize_axis_index from numpy.core.multiarray import normalize_axis_index
...@@ -4299,6 +4299,27 @@ def empty(shape, dtype=None): ...@@ -4299,6 +4299,27 @@ def empty(shape, dtype=None):
return AllocEmpty(dtype)(*shape) return AllocEmpty(dtype)(*shape)
def empty_like(
prototype: TensorVariable, dtype: Optional[Union[str, np.generic, np.dtype]] = None
) -> TensorVariable:
"""Return a new array with the same shape and type as a given array.
See ``numpy.empty_like``.
Parameters
----------
prototype
The shape and data-type of `prototype` define these same attributes
of the returned array.
dtype : data-type, optional
Overrides the data type of the result.
"""
if dtype is None:
dtype = prototype.dtype
return empty(shape(prototype), dtype)
def atleast_Nd( def atleast_Nd(
*arys: Union[np.ndarray, TensorVariable], n: int = 1, left: bool = True *arys: Union[np.ndarray, TensorVariable], n: int = 1, left: bool = True
) -> TensorVariable: ) -> TensorVariable:
...@@ -4462,4 +4483,5 @@ __all__ = [ ...@@ -4462,4 +4483,5 @@ __all__ = [
"extract_diag", "extract_diag",
"full", "full",
"empty", "empty",
"empty_like",
] ]
...@@ -4104,7 +4104,7 @@ class TestChoose(utt.InferShapeTester): ...@@ -4104,7 +4104,7 @@ class TestChoose(utt.InferShapeTester):
) )
def test_allocempty(): def test_empty():
# Test that we allocated correctly # Test that we allocated correctly
f = aesara.function([], AllocEmpty("float32")(2, 3)) f = aesara.function([], AllocEmpty("float32")(2, 3))
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
...@@ -4122,6 +4122,11 @@ def test_allocempty(): ...@@ -4122,6 +4122,11 @@ def test_allocempty():
assert res.shape == (2, 3) assert res.shape == (2, 3)
assert res.dtype == "int64" assert res.dtype == "int64"
empty_at = aet.empty_like(empty_at)
res = aesara.function([], empty_at)()
assert res.shape == (2, 3)
assert res.dtype == "int64"
def test_atleast_Nd(): def test_atleast_Nd():
ary1 = dscalar() ary1 = dscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论