Commit 2653ddea authored by Ricardo Vieira, committed by Ricardo Vieira

Linalg Ops: Align output dtypes with those of numpy/scipy

Parent 5f6c0103
...@@ -18,8 +18,9 @@ from pytensor.tensor.basic import as_tensor_variable, diagonal ...@@ -18,8 +18,9 @@ from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import ( from pytensor.tensor.type import (
Variable, Variable,
dmatrix,
dvector, dvector,
lscalar, iscalar,
matrix, matrix,
scalar, scalar,
tensor, tensor,
...@@ -37,12 +38,16 @@ class MatrixPinv(Op): ...@@ -37,12 +38,16 @@ class MatrixPinv(Op):
def make_node(self, x):
    """Build the Apply node for the pseudo-inverse of a 2-D tensor.

    Inexact (float/complex) input dtypes are preserved; integer, boolean
    and unsigned inputs produce a float64 output, matching the dtype
    ``numpy.linalg.pinv`` returns for such inputs.
    """
    x = as_tensor_variable(x)
    assert x.ndim == 2
    # "ibu" = signed int / bool / unsigned int dtype kinds — promote to float64.
    out_dtype = "float64" if x.type.numpy_dtype.kind in "ibu" else x.dtype
    return Apply(self, [x], [matrix(shape=x.type.shape, dtype=out_dtype)])
def perform(self, node, inputs, outputs):
    """Compute the (possibly Hermitian) pseudo-inverse via numpy."""
    [x] = inputs
    [z] = outputs
    # Output dtype is whatever np.linalg.pinv returns; make_node already
    # declared the matching promoted dtype, so no cast is needed here.
    z[0] = np.linalg.pinv(x, hermitian=self.hermitian)
def L_op(self, inputs, outputs, g_outputs): def L_op(self, inputs, outputs, g_outputs):
r"""The gradient function should return r"""The gradient function should return
...@@ -117,12 +122,16 @@ class MatrixInverse(Op): ...@@ -117,12 +122,16 @@ class MatrixInverse(Op):
def make_node(self, x):
    """Build the Apply node for the inverse of a 2-D tensor.

    Float/complex dtypes pass through unchanged; integer, boolean and
    unsigned inputs yield a float64 output, mirroring what
    ``numpy.linalg.inv`` returns for those inputs.
    """
    x = as_tensor_variable(x)
    assert x.ndim == 2
    # Promote non-inexact ("ibu"-kind) dtypes to float64.
    out_dtype = "float64" if x.type.numpy_dtype.kind in "ibu" else x.dtype
    return Apply(self, [x], [matrix(shape=x.type.shape, dtype=out_dtype)])
def perform(self, node, inputs, outputs):
    """Invert the input matrix, storing the result in the output cell."""
    [x] = inputs
    [z] = outputs
    # No explicit cast: the node's output dtype already matches numpy's.
    z[0] = np.linalg.inv(x)
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
r"""The gradient function should return r"""The gradient function should return
...@@ -216,14 +225,18 @@ class Det(Op): ...@@ -216,14 +225,18 @@ class Det(Op):
raise ValueError( raise ValueError(
f"Determinant not defined for non-square matrix inputs. Shape received is {x.type.shape}" f"Determinant not defined for non-square matrix inputs. Shape received is {x.type.shape}"
) )
o = scalar(dtype=x.dtype) if x.type.numpy_dtype.kind in "ibu":
out_dtype = "float64"
else:
out_dtype = x.dtype
o = scalar(dtype=out_dtype)
return Apply(self, [x], [o]) return Apply(self, [x], [o])
def perform(self, node, inputs, outputs):
    """Compute the determinant, wrapping numpy failures in ValueError."""
    [x] = inputs
    [z] = outputs
    try:
        result = np.linalg.det(x)
    except Exception as err:
        raise ValueError("Failed to compute determinant", x) from err
    # asarray: store a 0-d ndarray of numpy's native result dtype.
    z[0] = np.asarray(result)
...@@ -254,15 +267,19 @@ class SLogDet(Op): ...@@ -254,15 +267,19 @@ class SLogDet(Op):
def make_node(self, x):
    """Build an Apply node producing (sign, log-determinant) scalars.

    Both outputs share one dtype: the input dtype for inexact inputs,
    float64 for integer/boolean/unsigned inputs — the same promotion
    ``numpy.linalg.slogdet`` performs.
    """
    x = as_tensor_variable(x)
    assert x.ndim == 2
    out_dtype = "float64" if x.type.numpy_dtype.kind in "ibu" else x.dtype
    sign = scalar(dtype=out_dtype)
    det = scalar(dtype=out_dtype)
    return Apply(self, [x], [sign, det])
def perform(self, node, inputs, outputs):
    """Compute sign and log|det|, wrapping numpy failures in ValueError."""
    [x] = inputs
    sign, det = outputs
    try:
        sign_val, logdet_val = np.linalg.slogdet(x)
    except Exception as err:
        raise ValueError("Failed to compute determinant", x) from err
    # Wrap as 0-d arrays, keeping numpy's native result dtypes.
    sign[0] = np.array(sign_val)
    det[0] = np.array(logdet_val)
...@@ -735,9 +752,9 @@ class Lstsq(Op): ...@@ -735,9 +752,9 @@ class Lstsq(Op):
self, self,
[x, y, rcond], [x, y, rcond],
[ [
matrix(), dmatrix(),
dvector(), dvector(),
lscalar(), iscalar(),
dvector(), dvector(),
], ],
) )
...@@ -746,7 +763,7 @@ class Lstsq(Op): ...@@ -746,7 +763,7 @@ class Lstsq(Op):
zz = np.linalg.lstsq(inputs[0], inputs[1], inputs[2]) zz = np.linalg.lstsq(inputs[0], inputs[1], inputs[2])
outputs[0][0] = zz[0] outputs[0][0] = zz[0]
outputs[1][0] = zz[1] outputs[1][0] = zz[1]
outputs[2][0] = np.array(zz[2]) outputs[2][0] = np.asarray(zz[2])
outputs[3][0] = zz[3] outputs[3][0] = zz[3]
......
...@@ -491,20 +491,24 @@ class LU(Op): ...@@ -491,20 +491,24 @@ class LU(Op):
f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input" f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
) )
real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d" if x.type.numpy_dtype.kind in "ibu":
p_dtype = "int32" if self.p_indices else np.dtype(real_dtype) if x.type.numpy_dtype.itemsize <= 2:
out_dtype = "float32"
L = tensor(shape=x.type.shape, dtype=x.type.dtype) else:
U = tensor(shape=x.type.shape, dtype=x.type.dtype) out_dtype = "float64"
else:
out_dtype = x.type.dtype
L = tensor(shape=x.type.shape, dtype=out_dtype)
U = tensor(shape=x.type.shape, dtype=out_dtype)
if self.permute_l: if self.permute_l:
# In this case, L is actually P @ L # In this case, L is actually P @ L
return Apply(self, inputs=[x], outputs=[L, U]) return Apply(self, inputs=[x], outputs=[L, U])
if self.p_indices: if self.p_indices:
p_indices = tensor(shape=(x.type.shape[0],), dtype=p_dtype) p_indices = tensor(shape=(x.type.shape[0],), dtype="int32")
return Apply(self, inputs=[x], outputs=[p_indices, L, U]) return Apply(self, inputs=[x], outputs=[p_indices, L, U])
P = tensor(shape=x.type.shape, dtype=p_dtype) P = tensor(shape=x.type.shape, dtype=out_dtype)
return Apply(self, inputs=[x], outputs=[P, L, U]) return Apply(self, inputs=[x], outputs=[P, L, U])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
......
...@@ -502,12 +502,13 @@ class TestLstsq: ...@@ -502,12 +502,13 @@ class TestLstsq:
z = lscalar() z = lscalar()
b = lstsq(x, y, z) b = lstsq(x, y, z)
f = function([x, y, z], b) f = function([x, y, z], b)
TestMatrix1 = np.asarray([[2, 1], [3, 4]]) TestMatrix1 = np.asarray([[2, 1], [3, 4]])
TestMatrix2 = np.asarray([[17, 20], [43, 50]]) TestMatrix2 = np.asarray([[17, 20], [43, 50]])
TestScalar = np.asarray(1) TestScalar = np.asarray(1)
f = function([x, y, z], b) m0, _, rank, _ = f(TestMatrix1, TestMatrix2, TestScalar)
m = f(TestMatrix1, TestMatrix2, TestScalar) assert rank.dtype == "int32"
assert np.allclose(TestMatrix2, np.dot(TestMatrix1, m[0])) assert np.allclose(TestMatrix2, np.dot(TestMatrix1, m0))
def test_wrong_coefficient_matrix(self): def test_wrong_coefficient_matrix(self):
x = vector() x = vector()
......
Markdown formatting supported
0%
You added 0 people to this discussion. Please proceed carefully.
Please finish editing this comment first!
Register or sign in to comment