提交 18f245fa authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Expand batched_vector_b_solve_to_matrix rewrite

It now supports an arbitrary number of batched dimensions of b, by raveling them together
上级 2751bcc6
...@@ -138,14 +138,12 @@ def generic_solve_to_solve_triangular(fgraph, node): ...@@ -138,14 +138,12 @@ def generic_solve_to_solve_triangular(fgraph, node):
] ]
@register_stabilize
@register_specialize @register_specialize
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
def batched_vector_b_solve_to_matrix_b_solve(fgraph, node): def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
"""Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T
`a` must have no batched dimensions, while `b` can have arbitrary batched dimensions. `a` must have no batched dimensions, while `b` can have arbitrary batched dimensions.
Only the last two dimensions of `b` and the output are swapped.
""" """
core_op = node.op.core_op core_op = node.op.core_op
...@@ -175,8 +173,17 @@ def batched_vector_b_solve_to_matrix_b_solve(fgraph, node): ...@@ -175,8 +173,17 @@ def batched_vector_b_solve_to_matrix_b_solve(fgraph, node):
new_core_op = type(core_op)(**props) new_core_op = type(core_op)(**props)
matrix_b_solve = Blockwise(new_core_op) matrix_b_solve = Blockwise(new_core_op)
# Ravel any batched dims
original_b_shape = tuple(b.shape)
if len(original_b_shape) > 2:
b = b.reshape((-1, original_b_shape[-1]))
# Apply the rewrite # Apply the rewrite
new_solve = _T(matrix_b_solve(a, _T(b))) new_solve = matrix_b_solve(a, b.T).T
# Unravel any batched dims
if len(original_b_shape) > 2:
new_solve = new_solve.reshape(original_b_shape)
old_solve = node.outputs[0] old_solve = node.outputs[0]
copy_stack_trace(old_solve, new_solve) copy_stack_trace(old_solve, new_solve)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论