提交 d09e222b authored 作者: Tommy Guy's avatar Tommy Guy 提交者: Ricardo Vieira

Add dtype option to identity_like

Resolves #816
上级 7393b744
......@@ -1422,8 +1422,22 @@ def eye(n, m=None, k=0, dtype=None):
return localop(n, m, k)
def identity_like(x):
return eye(x.shape[0], x.shape[1], k=0, dtype=x.dtype)
def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
"""Create a tensor with ones on main diagonal and zeroes elsewhere.
Parameters
----------
x : tensor
dtype : data-type, optional
Returns
-------
tensor
tensor the shape of x with ones on main diagonal and zeroes elsewhere of type of dtype.
"""
if dtype is None:
dtype = x.dtype
return eye(x.shape[0], x.shape[1], k=0, dtype=dtype)
def infer_broadcastable(shape):
......
......@@ -777,12 +777,13 @@ Creating Tensors
:returns: An array where all elements are equal to zero, except for the `k`-th
diagonal, whose values are equal to one.
.. function:: identity_like(x)
.. function:: identity_like(x, dtype=None)
:param x: tensor
:param dtype: The dtype of the returned tensor. If `None`, default to dtype of `x`
:returns: A tensor of same shape as `x` that is filled with zeros everywhere
except for the main diagonal, whose values are equal to one. The output
will have same dtype as `x`.
will have same dtype as `x` unless overridden in `dtype`.
.. function:: stack(tensors, axis=0)
......
......@@ -59,6 +59,7 @@ from aesara.tensor.basic import (
get_scalar_constant_value,
get_vector_length,
horizontal_stack,
identity_like,
infer_broadcastable,
inverse_permutation,
join,
......@@ -4392,6 +4393,15 @@ def test_empty():
assert res.dtype == "int64"
def test_identity_like_dtype():
# Test that we allocate eye correctly via identity_like
m = matrix(dtype="int64")
m_out = identity_like(m)
assert m_out.dtype == m.dtype
m_out_float = identity_like(m, dtype=np.float64)
assert m_out_float.dtype == "float64"
def test_atleast_Nd():
ary1 = dscalar()
res_ary1 = atleast_Nd(ary1, n=1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论