提交 2653ddea authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Linalg Ops: Align output dtypes with those of numpy/scipy

上级 5f6c0103
......@@ -18,8 +18,9 @@ from pytensor.tensor.basic import as_tensor_variable, diagonal
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.type import (
Variable,
dmatrix,
dvector,
lscalar,
iscalar,
matrix,
scalar,
tensor,
......@@ -37,12 +38,16 @@ class MatrixPinv(Op):
def make_node(self, x):
    """Build the symbolic node computing the pseudo-inverse of a matrix.

    Parameters
    ----------
    x
        Anything accepted by ``as_tensor_variable``; must be 2-dimensional.

    Returns
    -------
    Apply
        Node with ``x`` as its only input and a single matrix output.
    """
    x = as_tensor_variable(x)
    assert x.ndim == 2
    # Integer, boolean and unsigned inputs produce a float64 result; any
    # other dtype is carried over to the output unchanged.
    if x.type.numpy_dtype.kind in "ibu":
        out_dtype = "float64"
    else:
        out_dtype = x.dtype
    # The pseudo-inverse of an (m, n) matrix is (n, m), so the static shape
    # of the output is the input's static shape reversed.
    out_shape = x.type.shape[::-1]
    return Apply(self, [x], [matrix(shape=out_shape, dtype=out_dtype)])
def perform(self, node, inputs, outputs):
    """Compute the pseudo-inverse numerically with ``np.linalg.pinv``.

    Parameters
    ----------
    node
        The Apply node being evaluated (unused here).
    inputs
        One-element sequence holding the input matrix.
    outputs
        One-element sequence holding the output storage cell; the result
        is written to index 0 of that cell.
    """
    (x,) = inputs
    (z,) = outputs
    # The result is stored with whatever dtype NumPy returns; it is not cast
    # back to the input dtype, so integer inputs yield a floating result.
    z[0] = np.linalg.pinv(x, hermitian=self.hermitian)
def L_op(self, inputs, outputs, g_outputs):
r"""The gradient function should return
......@@ -117,12 +122,16 @@ class MatrixInverse(Op):
def make_node(self, x):
    """Build the symbolic node computing the inverse of a matrix.

    Parameters
    ----------
    x
        Anything accepted by ``as_tensor_variable``; must be 2-dimensional.

    Returns
    -------
    Apply
        Node with ``x`` as its only input and a single matrix output that
        has the same static shape as ``x``.
    """
    x = as_tensor_variable(x)
    assert x.ndim == 2
    # Integer, boolean and unsigned inputs produce a float64 result; any
    # other dtype is carried over to the output unchanged.
    if x.type.numpy_dtype.kind in "ibu":
        out_dtype = "float64"
    else:
        out_dtype = x.dtype
    return Apply(self, [x], [matrix(shape=x.type.shape, dtype=out_dtype)])
def perform(self, node, inputs, outputs):
    """Compute the matrix inverse numerically with ``np.linalg.inv``.

    Parameters
    ----------
    node
        The Apply node being evaluated (unused here).
    inputs
        One-element sequence holding the input matrix.
    outputs
        One-element sequence holding the output storage cell; the result
        is written to index 0 of that cell.
    """
    (x,) = inputs
    (z,) = outputs
    # The result is stored with whatever dtype NumPy returns; it is not cast
    # back to the input dtype, so integer inputs yield a floating result.
    z[0] = np.linalg.inv(x)
def grad(self, inputs, g_outputs):
r"""The gradient function should return
......@@ -216,14 +225,18 @@ class Det(Op):
raise ValueError(
f"Determinant not defined for non-square matrix inputs. Shape received is {x.type.shape}"
)
o = scalar(dtype=x.dtype)
if x.type.numpy_dtype.kind in "ibu":
out_dtype = "float64"
else:
out_dtype = x.dtype
o = scalar(dtype=out_dtype)
return Apply(self, [x], [o])
def perform(self, node, inputs, outputs):
    """Compute the determinant numerically with ``np.linalg.det``.

    Parameters
    ----------
    node
        The Apply node being evaluated (unused here).
    inputs
        One-element sequence holding the input matrix.
    outputs
        One-element sequence holding the output storage cell; the result
        is written to index 0 of that cell as an array.

    Raises
    ------
    ValueError
        If the underlying NumPy call fails for any reason; the original
        exception is chained as the cause.
    """
    (x,) = inputs
    (z,) = outputs
    try:
        # Keep NumPy's result dtype instead of casting back to the input
        # dtype, so integer inputs yield a floating determinant.
        z[0] = np.asarray(np.linalg.det(x))
    except Exception as e:
        raise ValueError("Failed to compute determinant", x) from e
......@@ -254,15 +267,19 @@ class SLogDet(Op):
def make_node(self, x):
    """Build the symbolic node computing sign and log-determinant of a matrix.

    Parameters
    ----------
    x
        Anything accepted by ``as_tensor_variable``; must be 2-dimensional.

    Returns
    -------
    Apply
        Node with ``x`` as its only input and two scalar outputs, in the
        order ``(sign, det)``, sharing one dtype.
    """
    x = as_tensor_variable(x)
    assert x.ndim == 2
    # Integer, boolean and unsigned inputs produce float64 results; any
    # other dtype is carried over to both outputs unchanged.
    if x.type.numpy_dtype.kind in "ibu":
        out_dtype = "float64"
    else:
        out_dtype = x.dtype
    sign = scalar(dtype=out_dtype)
    det = scalar(dtype=out_dtype)
    return Apply(self, [x], [sign, det])
def perform(self, node, inputs, outputs):
    """Compute sign and log-determinant numerically with ``np.linalg.slogdet``.

    Parameters
    ----------
    node
        The Apply node being evaluated (unused here).
    inputs
        One-element sequence holding the input matrix.
    outputs
        Two-element sequence of output storage cells; the first and second
        values returned by ``np.linalg.slogdet`` are written, as arrays, to
        index 0 of the first and second cell respectively.

    Raises
    ------
    ValueError
        If the underlying NumPy call fails for any reason; the original
        exception is chained as the cause.
    """
    (x,) = inputs
    (sign, det) = outputs
    try:
        # Keep NumPy's result dtypes instead of casting back to the input
        # dtype, so integer inputs yield floating results.
        sign[0], det[0] = (np.array(z) for z in np.linalg.slogdet(x))
    except Exception as e:
        raise ValueError("Failed to compute determinant", x) from e
......@@ -735,9 +752,9 @@ class Lstsq(Op):
self,
[x, y, rcond],
[
matrix(),
dmatrix(),
dvector(),
lscalar(),
iscalar(),
dvector(),
],
)
......@@ -746,7 +763,7 @@ class Lstsq(Op):
zz = np.linalg.lstsq(inputs[0], inputs[1], inputs[2])
outputs[0][0] = zz[0]
outputs[1][0] = zz[1]
outputs[2][0] = np.array(zz[2])
outputs[2][0] = np.asarray(zz[2])
outputs[3][0] = zz[3]
......
......@@ -491,20 +491,24 @@ class LU(Op):
f"LU only allowed on matrix (2-D) inputs, got {x.type.ndim}-D input"
)
real_dtype = "f" if np.dtype(x.type.dtype).char in "fF" else "d"
p_dtype = "int32" if self.p_indices else np.dtype(real_dtype)
L = tensor(shape=x.type.shape, dtype=x.type.dtype)
U = tensor(shape=x.type.shape, dtype=x.type.dtype)
if x.type.numpy_dtype.kind in "ibu":
if x.type.numpy_dtype.itemsize <= 2:
out_dtype = "float32"
else:
out_dtype = "float64"
else:
out_dtype = x.type.dtype
L = tensor(shape=x.type.shape, dtype=out_dtype)
U = tensor(shape=x.type.shape, dtype=out_dtype)
if self.permute_l:
# In this case, L is actually P @ L
return Apply(self, inputs=[x], outputs=[L, U])
if self.p_indices:
p_indices = tensor(shape=(x.type.shape[0],), dtype=p_dtype)
p_indices = tensor(shape=(x.type.shape[0],), dtype="int32")
return Apply(self, inputs=[x], outputs=[p_indices, L, U])
P = tensor(shape=x.type.shape, dtype=p_dtype)
P = tensor(shape=x.type.shape, dtype=out_dtype)
return Apply(self, inputs=[x], outputs=[P, L, U])
def perform(self, node, inputs, outputs):
......
......@@ -502,12 +502,13 @@ class TestLstsq:
z = lscalar()
b = lstsq(x, y, z)
f = function([x, y, z], b)
TestMatrix1 = np.asarray([[2, 1], [3, 4]])
TestMatrix2 = np.asarray([[17, 20], [43, 50]])
TestScalar = np.asarray(1)
f = function([x, y, z], b)
m = f(TestMatrix1, TestMatrix2, TestScalar)
assert np.allclose(TestMatrix2, np.dot(TestMatrix1, m[0]))
m0, _, rank, _ = f(TestMatrix1, TestMatrix2, TestScalar)
assert rank.dtype == "int32"
assert np.allclose(TestMatrix2, np.dot(TestMatrix1, m0))
def test_wrong_coefficient_matrix(self):
x = vector()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论