提交 e51e8787 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Thomas Wiecki

Add an empty_like implementation

上级 61c3e37b
......@@ -12,7 +12,7 @@ from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
from numbers import Number
from typing import Dict, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import numpy as np
from numpy.core.multiarray import normalize_axis_index
......@@ -4299,6 +4299,27 @@ def empty(shape, dtype=None):
return AllocEmpty(dtype)(*shape)
def empty_like(
prototype: TensorVariable, dtype: Optional[Union[str, np.generic, np.dtype]] = None
) -> TensorVariable:
"""Return a new array with the same shape and type as a given array.
See ``numpy.empty_like``.
Parameters
----------
prototype
The shape and data-type of `prototype` define these same attributes
of the returned array.
dtype : data-type, optional
Overrides the data type of the result.
"""
if dtype is None:
dtype = prototype.dtype
return empty(shape(prototype), dtype)
def atleast_Nd(
*arys: Union[np.ndarray, TensorVariable], n: int = 1, left: bool = True
) -> TensorVariable:
......@@ -4462,4 +4483,5 @@ __all__ = [
"extract_diag",
"full",
"empty",
"empty_like",
]
......@@ -4104,7 +4104,7 @@ class TestChoose(utt.InferShapeTester):
)
def test_allocempty():
def test_empty():
# Test that we allocated correctly
f = aesara.function([], AllocEmpty("float32")(2, 3))
assert len(f.maker.fgraph.apply_nodes) == 1
......@@ -4122,6 +4122,11 @@ def test_allocempty():
assert res.shape == (2, 3)
assert res.dtype == "int64"
empty_at = aet.empty_like(empty_at)
res = aesara.function([], empty_at)()
assert res.shape == (2, 3)
assert res.dtype == "int64"
def test_atleast_Nd():
ary1 = dscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论