Unverified 提交 f4196f9c authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

block_diag of one matrix is Identity (#1865)

上级 557307a6
...@@ -1305,8 +1305,8 @@ class BaseBlockDiagonal(Op): ...@@ -1305,8 +1305,8 @@ class BaseBlockDiagonal(Op):
input_sig = ",".join(f"(m{i},n{i})" for i in range(n_inputs)) input_sig = ",".join(f"(m{i},n{i})" for i in range(n_inputs))
self.gufunc_signature = f"{input_sig}->(m,n)" self.gufunc_signature = f"{input_sig}->(m,n)"
if n_inputs == 0: if n_inputs <= 1:
raise ValueError("n_inputs must be greater than 0") raise ValueError("n_inputs must be greater than 1")
self.n_inputs = n_inputs self.n_inputs = n_inputs
def grad(self, inputs, gout): def grad(self, inputs, gout):
...@@ -1402,6 +1402,9 @@ def block_diag(*matrices: TensorVariable): ...@@ -1402,6 +1402,9 @@ def block_diag(*matrices: TensorVariable):
[0, 0, 5, 6], [0, 0, 5, 6],
[0, 0, 7, 8]]) [0, 0, 7, 8]])
""" """
if len(matrices) == 1:
return matrices[0]
_block_diagonal_matrix = Blockwise(BlockDiagonal(n_inputs=len(matrices))) _block_diagonal_matrix = Blockwise(BlockDiagonal(n_inputs=len(matrices)))
return _block_diagonal_matrix(*matrices) return _block_diagonal_matrix(*matrices)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论