提交 44e20d4a 作者: jessegrabowski 提交者: Ricardo Vieira

Cache tridiagonal solve in numba mode

上级 f9f2080e
...@@ -9,8 +9,10 @@ from scipy import linalg ...@@ -9,8 +9,10 @@ from scipy import linalg
from pytensor import config from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify from pytensor.link.numba.dispatch.basic import (
from pytensor.link.numba.dispatch.basic import generate_fallback_impl generate_fallback_impl,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.linalg._LAPACK import ( from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK, _LAPACK,
int_ptr_to_val, int_ptr_to_val,
...@@ -346,7 +348,7 @@ def _tridiagonal_solve_impl( ...@@ -346,7 +348,7 @@ def _tridiagonal_solve_impl(
return impl return impl
@numba_funcify.register(LUFactorTridiagonal) @register_funcify_default_op_cache_key(LUFactorTridiagonal)
def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs): def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
if any(i.type.numpy_dtype.kind == "c" for i in node.inputs): if any(i.type.numpy_dtype.kind == "c" for i in node.inputs):
return generate_fallback_impl(op, node=node) return generate_fallback_impl(op, node=node)
...@@ -389,10 +391,11 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs): ...@@ -389,10 +391,11 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
) )
return dl, d, du, du2, ipiv return dl, d, du, du2, ipiv
return lu_factor_tridiagonal cache_key = 1
return lu_factor_tridiagonal, cache_key
@numba_funcify.register(SolveLUFactorTridiagonal) @register_funcify_default_op_cache_key(SolveLUFactorTridiagonal)
def numba_funcify_SolveLUFactorTridiagonal( def numba_funcify_SolveLUFactorTridiagonal(
op: SolveLUFactorTridiagonal, node, **kwargs op: SolveLUFactorTridiagonal, node, **kwargs
): ):
...@@ -443,4 +446,5 @@ def numba_funcify_SolveLUFactorTridiagonal( ...@@ -443,4 +446,5 @@ def numba_funcify_SolveLUFactorTridiagonal(
) )
return x return x
return solve_lu_factor_tridiagonal cache_key = 1
return solve_lu_factor_tridiagonal, cache_key
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论