提交 4e617870 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix mistaken output types in aesara.tensor.nlinalg

上级 13bb75c1
...@@ -270,7 +270,7 @@ class Eigh(Eig): ...@@ -270,7 +270,7 @@ class Eigh(Eig):
# input. # input.
w_dtype = self._numop([[np.dtype(x.dtype).type()]])[0].dtype.name w_dtype = self._numop([[np.dtype(x.dtype).type()]])[0].dtype.name
w = vector(dtype=w_dtype) w = vector(dtype=w_dtype)
v = matrix(dtype=x.dtype) v = matrix(dtype=w_dtype)
return Apply(self, [x], [w, v]) return Apply(self, [x], [w, v])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
...@@ -407,15 +407,18 @@ class QRFull(Op): ...@@ -407,15 +407,18 @@ class QRFull(Op):
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
assert x.ndim == 2, "The input of qr function should be a matrix." assert x.ndim == 2, "The input of qr function should be a matrix."
in_dtype = x.type.numpy_dtype in_dtype = x.type.numpy_dtype
out_dtype = np.dtype(f"f{in_dtype.itemsize}") out_dtype = np.dtype(f"f{in_dtype.itemsize}")
q = matrix(dtype=out_dtype)
if self.mode != "raw": if self.mode != "raw":
r = matrix(dtype=x.dtype) r = matrix(dtype=out_dtype)
else: else:
r = vector(dtype=x.dtype) r = vector(dtype=out_dtype)
if self.mode != "r": if self.mode != "r":
q = matrix(dtype=out_dtype) q = matrix(dtype=out_dtype)
...@@ -507,10 +510,15 @@ class SVD(Op): ...@@ -507,10 +510,15 @@ class SVD(Op):
def make_node(self, x): def make_node(self, x):
x = as_tensor_variable(x) x = as_tensor_variable(x)
assert x.ndim == 2, "The input of svd function should be a matrix." assert x.ndim == 2, "The input of svd function should be a matrix."
s = vector(dtype=x.dtype)
in_dtype = x.type.numpy_dtype
out_dtype = np.dtype(f"f{in_dtype.itemsize}")
s = vector(dtype=out_dtype)
if self.compute_uv: if self.compute_uv:
u = matrix(dtype=x.dtype) u = matrix(dtype=out_dtype)
vt = matrix(dtype=x.dtype) vt = matrix(dtype=out_dtype)
return Apply(self, [x], [u, s, vt]) return Apply(self, [x], [u, s, vt])
else: else:
return Apply(self, [x], [s]) return Apply(self, [x], [s])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论