提交 2fc5cc4e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix passing M=None to function in Eye test

上级 ef5bcb50
......@@ -1453,8 +1453,7 @@ def eye(n, m=None, k=0, dtype=None):
dtype = config.floatX
if m is None:
m = n
localop = Eye(dtype)
return localop(n, m, k)
return Eye(dtype)(n, m, k)
def identity_like(x, dtype: str | np.generic | np.dtype | None = None):
......
......@@ -934,22 +934,18 @@ def test_infer_static_shape():
class TestEye:
# This is slow for the ('int8', 3) version.
def test_basic(self):
def check(dtype, N, M_=None, k=0):
# PyTensor does not accept None as a tensor.
# So we must use a real value.
M = M_
# Currently DebugMode does not support None as inputs even if this is
# allowed.
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
M = N
def check(dtype, N, M=None, k=0):
N_symb = iscalar()
M_symb = iscalar()
k_symb = iscalar()
test_inputs = [N, k] if M is None else [N, M, k]
inputs = [N_symb, k_symb] if M is None else [N_symb, M_symb, k_symb]
f = function(
[N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype)
inputs,
eye(N_symb, None if (M is None) else M_symb, k_symb, dtype=dtype),
)
result = f(N, M, k)
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
result = f(*test_inputs)
assert np.allclose(result, np.eye(N, M, k, dtype=dtype))
assert result.dtype == np.dtype(dtype)
for dtype in ALL_DTYPES:
......@@ -1755,7 +1751,7 @@ class TestJoinAndSplit:
got = f(-2)
assert np.allclose(got, want)
with pytest.raises(ValueError):
with pytest.raises((ValueError, IndexError)):
f(-3)
@pytest.mark.parametrize("py_impl", (False, True))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论