Unverified 提交 b9fc4f8e authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Preserve static shape information in `block_diag` (#1529)

上级 68d8dc72
...@@ -1651,7 +1651,18 @@ class BlockDiagonal(BaseBlockDiagonal): ...@@ -1651,7 +1651,18 @@ class BlockDiagonal(BaseBlockDiagonal):
def make_node(self, *matrices): def make_node(self, *matrices):
matrices = self._validate_and_prepare_inputs(matrices, pt.as_tensor) matrices = self._validate_and_prepare_inputs(matrices, pt.as_tensor)
dtype = _largest_common_dtype(matrices) dtype = _largest_common_dtype(matrices)
out_type = pytensor.tensor.matrix(dtype=dtype)
shapes_by_dim = tuple(zip(*(m.type.shape for m in matrices)))
out_shape = tuple(
[
sum(dim_shapes)
if not any(shape is None for shape in dim_shapes)
else None
for dim_shapes in shapes_by_dim
]
)
out_type = pytensor.tensor.matrix(shape=out_shape, dtype=dtype)
return Apply(self, matrices, [out_type]) return Apply(self, matrices, [out_type])
def perform(self, node, inputs, output_storage, params=None): def perform(self, node, inputs, output_storage, params=None):
......
...@@ -1040,11 +1040,28 @@ def test_block_diagonal(): ...@@ -1040,11 +1040,28 @@ def test_block_diagonal():
A = np.array([[1.0, 2.0], [3.0, 4.0]]) A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[5.0, 6.0], [7.0, 8.0]]) B = np.array([[5.0, 6.0], [7.0, 8.0]])
result = block_diag(A, B) result = block_diag(A, B)
assert result.type.shape == (4, 4)
assert result.owner.op.core_op._props_dict() == {"n_inputs": 2} assert result.owner.op.core_op._props_dict() == {"n_inputs": 2}
np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(A, B)) np.testing.assert_allclose(result.eval(), scipy.linalg.block_diag(A, B))
def test_block_diagonal_static_shape():
A = pt.dmatrix("A", shape=(5, 5))
B = pt.dmatrix("B", shape=(3, 10))
result = block_diag(A, B)
assert result.type.shape == (8, 15)
A = pt.dmatrix("A", shape=(5, 5))
B = pt.dmatrix("B", shape=(3, None))
result = block_diag(A, B)
assert result.type.shape == (8, None)
A = pt.dmatrix("A", shape=(None, 5))
result = block_diag(A, B)
assert result.type.shape == (None, None)
def test_block_diagonal_grad(): def test_block_diagonal_grad():
A = np.array([[1.0, 2.0], [3.0, 4.0]]) A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[5.0, 6.0], [7.0, 8.0]]) B = np.array([[5.0, 6.0], [7.0, 8.0]])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论