提交 f4fb4833 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

LU Op: Fix P out dtype for complex input

上级 d016564b
......@@ -508,7 +508,12 @@ class LU(Op):
p_indices = tensor(shape=(x.type.shape[0],), dtype="int32")
return Apply(self, inputs=[x], outputs=[p_indices, L, U])
P = tensor(shape=x.type.shape, dtype=out_dtype)
if out_dtype.startswith("complex"):
P_dtype = "float64" if out_dtype == "complex128" else "float32"
else:
P_dtype = out_dtype
P = tensor(shape=x.type.shape, dtype=P_dtype)
return Apply(self, inputs=[x], outputs=[P, L, U])
def perform(self, node, inputs, outputs):
......
......@@ -648,9 +648,9 @@ def test_lu_decomposition(
dtype = config.floatX if not complex else f"complex{int(config.floatX[-2:]) * 2}"
A = tensor("A", shape=shape, dtype=dtype)
out = lu(A, permute_l=permute_l, p_indices=p_indices)
pt_out = lu(A, permute_l=permute_l, p_indices=p_indices)
f = function([A], out)
f = function([A], pt_out)
rng = np.random.default_rng(utt.fetch_seed())
x = rng.normal(size=shape).astype(config.floatX)
......@@ -658,6 +658,8 @@ def test_lu_decomposition(
x = x + 1j * rng.normal(size=shape).astype(config.floatX)
out = f(x)
for numerical_out, symbolic_out in zip(out, pt_out):
assert numerical_out.dtype == symbolic_out.type.dtype
if permute_l:
PL, U = out
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论