提交 d09e222b authored 作者: Tommy Guy's avatar Tommy Guy 提交者: Ricardo Vieira

Add dtype option to identity_like

Resolves #816
上级 7393b744
...@@ -1422,8 +1422,22 @@ def eye(n, m=None, k=0, dtype=None): ...@@ -1422,8 +1422,22 @@ def eye(n, m=None, k=0, dtype=None):
return localop(n, m, k) return localop(n, m, k)
def identity_like(x): def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
return eye(x.shape[0], x.shape[1], k=0, dtype=x.dtype) """Create a tensor with ones on main diagonal and zeroes elsewhere.
Parameters
----------
x : tensor
dtype : data-type, optional
Returns
-------
tensor
tensor the shape of x with ones on main diagonal and zeroes elsewhere of type of dtype.
"""
if dtype is None:
dtype = x.dtype
return eye(x.shape[0], x.shape[1], k=0, dtype=dtype)
def infer_broadcastable(shape): def infer_broadcastable(shape):
......
...@@ -777,12 +777,13 @@ Creating Tensors ...@@ -777,12 +777,13 @@ Creating Tensors
:returns: An array where all elements are equal to zero, except for the `k`-th :returns: An array where all elements are equal to zero, except for the `k`-th
diagonal, whose values are equal to one. diagonal, whose values are equal to one.
.. function:: identity_like(x) .. function:: identity_like(x, dtype=None)
:param x: tensor :param x: tensor
:param dtype: The dtype of the returned tensor. If `None`, default to dtype of `x`
:returns: A tensor of same shape as `x` that is filled with zeros everywhere :returns: A tensor of same shape as `x` that is filled with zeros everywhere
except for the main diagonal, whose values are equal to one. The output except for the main diagonal, whose values are equal to one. The output
will have same dtype as `x`. will have same dtype as `x` unless overridden in `dtype`.
.. function:: stack(tensors, axis=0) .. function:: stack(tensors, axis=0)
......
...@@ -59,6 +59,7 @@ from aesara.tensor.basic import ( ...@@ -59,6 +59,7 @@ from aesara.tensor.basic import (
get_scalar_constant_value, get_scalar_constant_value,
get_vector_length, get_vector_length,
horizontal_stack, horizontal_stack,
identity_like,
infer_broadcastable, infer_broadcastable,
inverse_permutation, inverse_permutation,
join, join,
...@@ -4392,6 +4393,15 @@ def test_empty(): ...@@ -4392,6 +4393,15 @@ def test_empty():
assert res.dtype == "int64" assert res.dtype == "int64"
def test_identity_like_dtype():
# Test that we allocate eye correctly via identity_like
m = matrix(dtype="int64")
m_out = identity_like(m)
assert m_out.dtype == m.dtype
m_out_float = identity_like(m, dtype=np.float64)
assert m_out_float.dtype == "float64"
def test_atleast_Nd(): def test_atleast_Nd():
ary1 = dscalar() ary1 = dscalar()
res_ary1 = atleast_Nd(ary1, n=1) res_ary1 = atleast_Nd(ary1, n=1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论