提交 131982de authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove duplicated Solve dispatch

上级 789c0d17
...@@ -8,8 +8,6 @@ from textwrap import dedent ...@@ -8,8 +8,6 @@ from textwrap import dedent
import numba import numba
import numba.np.unsafe.ndarray as numba_ndarray import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np import numpy as np
import scipy
import scipy.special
from llvmlite import ir from llvmlite import ir
from numba import types from numba import types
from numba.core.errors import NumbaWarning, TypingError from numba.core.errors import NumbaWarning, TypingError
...@@ -36,7 +34,6 @@ from pytensor.tensor.basic import Nonzero ...@@ -36,7 +34,6 @@ from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.slinalg import Solve
from pytensor.tensor.sort import ArgSortOp, SortOp from pytensor.tensor.sort import ArgSortOp, SortOp
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import MakeSlice, NoneConst from pytensor.tensor.type_other import MakeSlice, NoneConst
...@@ -626,51 +623,6 @@ def numba_funcify_Dot(op, node, **kwargs): ...@@ -626,51 +623,6 @@ def numba_funcify_Dot(op, node, **kwargs):
return dot_with_cast return dot_with_cast
@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
assume_a = op.assume_a
# check_finite = op.check_finite
if assume_a != "gen":
lower = op.lower
warnings.warn(
(
"Numba will use object mode to allow the "
"`compute_uv` argument to `numpy.linalg.svd`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba_njit
def solve(a, b):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.solve_triangular(
a,
b,
lower=lower,
# check_finite=check_finite
)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba_njit
def solve(a, b):
return np.linalg.solve(
inputs_cast(a),
inputs_cast(b),
# assume_a=assume_a,
# check_finite=check_finite,
).astype(out_dtype)
return solve
@numba_funcify.register(BatchedDot) @numba_funcify.register(BatchedDot)
def numba_funcify_BatchedDot(op, node, **kwargs): def numba_funcify_BatchedDot(op, node, **kwargs):
dtype = node.outputs[0].type.numpy_dtype dtype = node.outputs[0].type.numpy_dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论