提交 47bd6fb5 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba overloads: Boolean is not Number

上级 a7827536
...@@ -25,7 +25,7 @@ def numba_deepcopy(x): ...@@ -25,7 +25,7 @@ def numba_deepcopy(x):
@numba.extending.overload(numba_deepcopy) @numba.extending.overload(numba_deepcopy)
def numba_deepcopy_tensor(x): def numba_deepcopy_tensor(x):
if isinstance(x, numba.types.Number): if isinstance(x, numba.types.Number | numba.types.Boolean):
def number_deepcopy(x): def number_deepcopy(x):
return x return x
......
import numba import numba
import numpy as np import numpy as np
from numba.types import Array, Boolean, List, Number
import pytensor.link.numba.dispatch.basic as numba_basic import pytensor.link.numba.dispatch.basic as numba_basic
from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key
...@@ -37,7 +38,7 @@ def numba_all_equal(x, y): ...@@ -37,7 +38,7 @@ def numba_all_equal(x, y):
def list_all_equal(x, y): def list_all_equal(x, y):
all_equal = None all_equal = None
if isinstance(x, numba.types.List) and isinstance(y, numba.types.List): if isinstance(x, List) and isinstance(y, List):
def all_equal(x, y): def all_equal(x, y):
if len(x) != len(y): if len(x) != len(y):
...@@ -47,12 +48,12 @@ def list_all_equal(x, y): ...@@ -47,12 +48,12 @@ def list_all_equal(x, y):
return False return False
return True return True
if isinstance(x, numba.types.Array) and isinstance(y, numba.types.Array): if isinstance(x, Array) and isinstance(y, Array):
def all_equal(x, y): def all_equal(x, y):
return (x == y).all() return (x == y).all()
if isinstance(x, numba.types.Number) and isinstance(y.numba.types.Number): if isinstance(x, Number | Boolean) and isinstance(y, Number | Boolean):
def all_equal(x, y): def all_equal(x, y):
return x == y return x == y
...@@ -62,7 +63,7 @@ def list_all_equal(x, y): ...@@ -62,7 +63,7 @@ def list_all_equal(x, y):
@numba.extending.overload(numba_deepcopy) @numba.extending.overload(numba_deepcopy)
def numba_deepcopy_list(x): def numba_deepcopy_list(x):
if isinstance(x, numba.types.List): if isinstance(x, List):
def deepcopy_list(x): def deepcopy_list(x):
return [numba_deepcopy(xi) for xi in x] return [numba_deepcopy(xi) for xi in x]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论