提交 d3bd1f15 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add more specialized static output shape to Eye

Importantly, it now provides broadcastability information which is needed elsewhere
上级 28d9d4dc
......@@ -1273,6 +1273,7 @@ def triu_indices_from(
class Eye(Op):
_output_type_depends_on_input_value = True
__props__ = ("dtype",)
def __init__(self, dtype=None):
......@@ -1287,10 +1288,13 @@ class Eye(Op):
assert n.ndim == 0
assert m.ndim == 0
assert k.ndim == 0
_, static_shape = infer_static_shape((n, m))
return Apply(
self,
[n, m, k],
[TensorType(dtype=self.dtype, shape=(None, None))()],
[TensorType(dtype=self.dtype, shape=static_shape)()],
)
def perform(self, node, inp, out_):
......
......@@ -937,38 +937,46 @@ def test_infer_static_shape():
assert static_shape == (1,)
# This is slow for the ('int8', 3) version.
def test_eye():
def check(dtype, N, M_=None, k=0):
# PyTensor does not accept None as a tensor.
# So we must use a real value.
M = M_
# Currently DebugMode does not support None as inputs even if this is
# allowed.
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
M = N
N_symb = iscalar()
M_symb = iscalar()
k_symb = iscalar()
f = function([N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype))
result = f(N, M, k)
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
assert result.dtype == np.dtype(dtype)
class TestEye:
# This is slow for the ('int8', 3) version.
def test_basic(self):
def check(dtype, N, M_=None, k=0):
# PyTensor does not accept None as a tensor.
# So we must use a real value.
M = M_
# Currently DebugMode does not support None as inputs even if this is
# allowed.
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
M = N
N_symb = iscalar()
M_symb = iscalar()
k_symb = iscalar()
f = function(
[N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype)
)
result = f(N, M, k)
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
assert result.dtype == np.dtype(dtype)
for dtype in ALL_DTYPES:
check(dtype, 3)
# M != N, k = 0
check(dtype, 3, 5)
check(dtype, 5, 3)
# N == M, k != 0
check(dtype, 3, 3, 1)
check(dtype, 3, 3, -1)
# N < M, k != 0
check(dtype, 3, 5, 1)
check(dtype, 3, 5, -1)
# N > M, k != 0
check(dtype, 5, 3, 1)
check(dtype, 5, 3, -1)
for dtype in ALL_DTYPES:
check(dtype, 3)
# M != N, k = 0
check(dtype, 3, 5)
check(dtype, 5, 3)
# N == M, k != 0
check(dtype, 3, 3, 1)
check(dtype, 3, 3, -1)
# N < M, k != 0
check(dtype, 3, 5, 1)
check(dtype, 3, 5, -1)
# N > M, k != 0
check(dtype, 5, 3, 1)
check(dtype, 5, 3, -1)
def test_static_output_type(self):
l = lscalar("l")
assert eye(5, 3, l).type.shape == (5, 3)
assert eye(1, l, 3).type.shape == (1, None)
class TestTriangle:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论