提交 4e617870 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix mistaken output types in aesara.tensor.nlinalg

上级 13bb75c1
......@@ -270,7 +270,7 @@ class Eigh(Eig):
# input.
w_dtype = self._numop([[np.dtype(x.dtype).type()]])[0].dtype.name
w = vector(dtype=w_dtype)
v = matrix(dtype=x.dtype)
v = matrix(dtype=w_dtype)
return Apply(self, [x], [w, v])
def perform(self, node, inputs, outputs):
......@@ -407,15 +407,18 @@ class QRFull(Op):
def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2, "The input of qr function should be a matrix."
in_dtype = x.type.numpy_dtype
out_dtype = np.dtype(f"f{in_dtype.itemsize}")
q = matrix(dtype=out_dtype)
if self.mode != "raw":
r = matrix(dtype=x.dtype)
r = matrix(dtype=out_dtype)
else:
r = vector(dtype=x.dtype)
r = vector(dtype=out_dtype)
if self.mode != "r":
q = matrix(dtype=out_dtype)
......@@ -507,10 +510,15 @@ class SVD(Op):
def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2, "The input of svd function should be a matrix."
s = vector(dtype=x.dtype)
in_dtype = x.type.numpy_dtype
out_dtype = np.dtype(f"f{in_dtype.itemsize}")
s = vector(dtype=out_dtype)
if self.compute_uv:
u = matrix(dtype=x.dtype)
vt = matrix(dtype=x.dtype)
u = matrix(dtype=out_dtype)
vt = matrix(dtype=out_dtype)
return Apply(self, [x], [u, s, vt])
else:
return Apply(self, [x], [s])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论