提交 d3bd1f15 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add more specialized static output shape to Eye

Importantly, it now provides broadcastability information which is needed elsewhere
上级 28d9d4dc
...@@ -1273,6 +1273,7 @@ def triu_indices_from( ...@@ -1273,6 +1273,7 @@ def triu_indices_from(
class Eye(Op): class Eye(Op):
_output_type_depends_on_input_value = True
__props__ = ("dtype",) __props__ = ("dtype",)
def __init__(self, dtype=None): def __init__(self, dtype=None):
...@@ -1287,10 +1288,13 @@ class Eye(Op): ...@@ -1287,10 +1288,13 @@ class Eye(Op):
assert n.ndim == 0 assert n.ndim == 0
assert m.ndim == 0 assert m.ndim == 0
assert k.ndim == 0 assert k.ndim == 0
_, static_shape = infer_static_shape((n, m))
return Apply( return Apply(
self, self,
[n, m, k], [n, m, k],
[TensorType(dtype=self.dtype, shape=(None, None))()], [TensorType(dtype=self.dtype, shape=static_shape)()],
) )
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
......
...@@ -937,38 +937,46 @@ def test_infer_static_shape(): ...@@ -937,38 +937,46 @@ def test_infer_static_shape():
assert static_shape == (1,) assert static_shape == (1,)
# This is slow for the ('int8', 3) version. class TestEye:
def test_eye(): # This is slow for the ('int8', 3) version.
def check(dtype, N, M_=None, k=0): def test_basic(self):
# PyTensor does not accept None as a tensor. def check(dtype, N, M_=None, k=0):
# So we must use a real value. # PyTensor does not accept None as a tensor.
M = M_ # So we must use a real value.
# Currently DebugMode does not support None as inputs even if this is M = M_
# allowed. # Currently DebugMode does not support None as inputs even if this is
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]: # allowed.
M = N if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
N_symb = iscalar() M = N
M_symb = iscalar() N_symb = iscalar()
k_symb = iscalar() M_symb = iscalar()
f = function([N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype)) k_symb = iscalar()
result = f(N, M, k) f = function(
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype)) [N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype)
assert result.dtype == np.dtype(dtype) )
result = f(N, M, k)
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
assert result.dtype == np.dtype(dtype)
for dtype in ALL_DTYPES: for dtype in ALL_DTYPES:
check(dtype, 3) check(dtype, 3)
# M != N, k = 0 # M != N, k = 0
check(dtype, 3, 5) check(dtype, 3, 5)
check(dtype, 5, 3) check(dtype, 5, 3)
# N == M, k != 0 # N == M, k != 0
check(dtype, 3, 3, 1) check(dtype, 3, 3, 1)
check(dtype, 3, 3, -1) check(dtype, 3, 3, -1)
# N < M, k != 0 # N < M, k != 0
check(dtype, 3, 5, 1) check(dtype, 3, 5, 1)
check(dtype, 3, 5, -1) check(dtype, 3, 5, -1)
# N > M, k != 0 # N > M, k != 0
check(dtype, 5, 3, 1) check(dtype, 5, 3, 1)
check(dtype, 5, 3, -1) check(dtype, 5, 3, -1)
def test_static_output_type(self):
l = lscalar("l")
assert eye(5, 3, l).type.shape == (5, 3)
assert eye(1, l, 3).type.shape == (1, None)
class TestTriangle: class TestTriangle:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论