Unverified 提交 d2120171 authored 作者: David laid's avatar David laid 提交者: GitHub

Fix incorrect static output shape in MatrixPinv for non-square inputs

上级 c647bb23
...@@ -42,7 +42,11 @@ class MatrixPinv(Op): ...@@ -42,7 +42,11 @@ class MatrixPinv(Op):
out_dtype = "float64" out_dtype = "float64"
else: else:
out_dtype = x.dtype out_dtype = x.dtype
return Apply(self, [x], [matrix(shape=x.type.shape, dtype=out_dtype)]) return Apply(
self,
[x],
[matrix(shape=(x.type.shape[1], x.type.shape[0]), dtype=out_dtype)],
)
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(x,) = inputs (x,) = inputs
......
...@@ -70,6 +70,21 @@ def test_pseudoinverse_grad(): ...@@ -70,6 +70,21 @@ def test_pseudoinverse_grad():
utt.verify_grad(pinv, [r]) utt.verify_grad(pinv, [r])
def test_pseudoinverse_static_shape():
x = matrix(shape=(3, 5))
z = pinv(x)
assert z.type.shape == (5, 3)
g = pytensor.grad(z.sum(), x)
f = function([x], g)
rng = np.random.default_rng(utt.fetch_seed())
r = rng.standard_normal((3, 5)).astype(config.floatX)
assert f(r).shape == (3, 5)
utt.verify_grad(pinv, [r])
class TestMatrixInverse(utt.InferShapeTester): class TestMatrixInverse(utt.InferShapeTester):
def setup_method(self): def setup_method(self):
super().setup_method() super().setup_method()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论