提交 44e20d4a authored 作者: jessegrabowski's avatar jessegrabowski 提交者: Ricardo Vieira

Cache tridiagonal solve in numba mode

上级 f9f2080e
......@@ -9,8 +9,10 @@ from scipy import linalg
from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import generate_fallback_impl
from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.linalg._LAPACK import (
_LAPACK,
int_ptr_to_val,
......@@ -346,7 +348,7 @@ def _tridiagonal_solve_impl(
return impl
@numba_funcify.register(LUFactorTridiagonal)
@register_funcify_default_op_cache_key(LUFactorTridiagonal)
def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
if any(i.type.numpy_dtype.kind == "c" for i in node.inputs):
return generate_fallback_impl(op, node=node)
......@@ -389,10 +391,11 @@ def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
)
return dl, d, du, du2, ipiv
return lu_factor_tridiagonal
cache_key = 1
return lu_factor_tridiagonal, cache_key
@numba_funcify.register(SolveLUFactorTridiagonal)
@register_funcify_default_op_cache_key(SolveLUFactorTridiagonal)
def numba_funcify_SolveLUFactorTridiagonal(
op: SolveLUFactorTridiagonal, node, **kwargs
):
......@@ -443,4 +446,5 @@ def numba_funcify_SolveLUFactorTridiagonal(
)
return x
return solve_lu_factor_tridiagonal
cache_key = 1
return solve_lu_factor_tridiagonal, cache_key
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论