提交 131982de authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove duplicated Solve dispatch

上级 789c0d17
......@@ -8,8 +8,6 @@ from textwrap import dedent
import numba
import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np
import scipy
import scipy.special
from llvmlite import ir
from numba import types
from numba.core.errors import NumbaWarning, TypingError
......@@ -36,7 +34,6 @@ from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.slinalg import Solve
from pytensor.tensor.sort import ArgSortOp, SortOp
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import MakeSlice, NoneConst
......@@ -626,51 +623,6 @@ def numba_funcify_Dot(op, node, **kwargs):
return dot_with_cast
@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a
# check_finite = op.check_finite
if assume_a != "gen":
lower = op.lower
warnings.warn(
(
"Numba will use object mode to allow the "
"`compute_uv` argument to `numpy.linalg.svd`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba_njit
def solve(a, b):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.solve_triangular(
a,
b,
lower=lower,
# check_finite=check_finite
)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba_njit
def solve(a, b):
return np.linalg.solve(
inputs_cast(a),
inputs_cast(b),
# assume_a=assume_a,
# check_finite=check_finite,
).astype(out_dtype)
return solve
@numba_funcify.register(BatchedDot)
def numba_funcify_BatchedDot(op, node, **kwargs):
dtype = node.outputs[0].type.numpy_dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论