Add rewrite for matmul when only one of the inputs has batched dimensions
This rewrite replaces a batched matmul with a core (2D) matmul: it ravels the batch dimensions of the batched input, performs a single 2D matmul, and then unravels the batch dimensions of the output.
This sidesteps the Blockwise Dot operation and allows specialization into BLAS routines.
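
The ravel / 2D matmul / unravel idea can be sketched in plain NumPy. This is only an illustration of the reshaping trick, not the rewrite itself, which operates on the symbolic graph; the function names `matmul_left_batched` and `matmul_right_batched` are hypothetical and exist only for this example.

```python
import numpy as np


def matmul_left_batched(x, y):
    """Hypothetical helper: x has shape (*batch, m, k), y has shape (k, n)."""
    *batch, m, k = x.shape
    n = y.shape[-1]
    # Ravel the batch dimensions into the row dimension: (prod(batch) * m, k)
    x2d = x.reshape(-1, k)
    # A single 2D matmul, which a BLAS GEMM call can handle
    out2d = x2d @ y
    # Unravel the batch dimensions of the output: (*batch, m, n)
    return out2d.reshape((*batch, m, n))


def matmul_right_batched(x, y):
    """Hypothetical helper: x has shape (m, k), y has shape (*batch, k, n)."""
    *batch, k, n = y.shape
    m = x.shape[0]
    # Move the contracted axis to the front, then ravel: (k, prod(batch) * n)
    y2d = np.moveaxis(y, -2, 0).reshape(k, -1)
    # A single 2D matmul: (m, prod(batch) * n)
    out2d = x @ y2d
    # Unravel to (m, *batch, n), then move m back to the second-to-last axis
    out = out2d.reshape((m, *batch, n))
    return np.moveaxis(out, 0, -2)
```

Both helpers should agree with `np.matmul(x, y)` on the same inputs. The right-batched case needs an extra axis move because the batch dimensions have to be folded into the columns rather than the rows.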
The idea was taken from these two discussions:
https://github.com/numpy/numpy/issues/7569
https://github.com/numpy/numpy/issues/8957