Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9e79f3a0
提交
9e79f3a0
authored
12月 06, 2022
作者:
Adrian Seyboldt
提交者:
Adrian Seyboldt
1月 04, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Initial version of llvm elemwise impl
上级
38dc6c9f
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
481 行增加
和
51 行删除
+481
-51
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+207
-51
elemwise_codegen.py
pytensor/link/numba/dispatch/elemwise_codegen.py
+231
-0
helpers.py
pytensor/link/numba/dispatch/helpers.py
+43
-0
没有找到文件。
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
9e79f3a0
import
inspect
from
functools
import
singledispatch
from
functools
import
singledispatch
from
numbers
import
Number
from
numbers
import
Number
import
pickle
from
textwrap
import
indent
from
textwrap
import
indent
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Literal
,
Optional
,
Union
import
base64
import
numba
import
numba
import
numpy
as
np
import
numpy
as
np
from
llvmlite
import
ir
from
numba
import
TypingError
,
literal_unroll
,
types
,
literally
from
numba.core
import
cgutils
from
numba.cpython.unsafe.tuple
import
tuple_setitem
from
numba.np
import
arrayobj
from
numpy.core.numeric
import
normalize_axis_index
,
normalize_axis_tuple
from
numpy.core.numeric
import
normalize_axis_index
,
normalize_axis_tuple
from
pytensor
import
config
from
pytensor
import
config
...
@@ -16,13 +22,12 @@ from pytensor.link.numba.dispatch.basic import (
...
@@ -16,13 +22,12 @@ from pytensor.link.numba.dispatch.basic import (
create_numba_signature
,
create_numba_signature
,
create_tuple_creator
,
create_tuple_creator
,
numba_funcify
,
numba_funcify
,
numba_njit
,
use_optimized_cheap_pass
,
use_optimized_cheap_pass
,
)
)
from
pytensor.link.utils
import
(
from
pytensor.link.numba.dispatch.helpers
import
check_broadcasting
,
tuple_mapper
compile_function_src
,
from
pytensor.link.numba.dispatch
import
elemwise_codegen
get_name_for_object
,
from
pytensor.link.utils
import
compile_function_src
,
get_name_for_object
unique_name_generator
,
)
from
pytensor.scalar.basic
import
(
from
pytensor.scalar.basic
import
(
AND
,
AND
,
OR
,
OR
,
...
@@ -431,6 +436,170 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
...
@@ -431,6 +436,170 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
return
axis_apply_fn
return
axis_apply_fn
_jit_options
=
{
"fastmath"
:
{
"arcp"
,
# Allow Reciprocal
"contract"
,
# Allow floating-point contraction
"afn"
,
# Approximate functions
"reassoc"
,
"nsz"
,
# TODO Do we want this one?
}
}
def _decode_literal_arg(arg_type, name):
    """Decode a literal string argument back into the Python object it encodes.

    Nested tuples cannot be passed to numba as literals, so the caller pickles
    them and base64-encodes the result into a string literal; this reverses
    that encoding.

    NOTE: ``pickle.loads`` must never see untrusted data. The payloads decoded
    here are expected to be the strings produced with ``pickle.dumps`` by the
    Elemwise dispatch code in this module.

    Raises
    ------
    TypingError
        If ``arg_type`` is not a numba literal type.
    """
    if not isinstance(arg_type, types.Literal):
        raise TypingError(f"{name} must be literal.")
    return pickle.loads(base64.decodebytes(arg_type.literal_value.encode()))


@numba.extending.intrinsic(jit_options=_jit_options, prefer_literal=True)
def _vectorized(
    typingctx,
    scalar_func,
    input_bc_patterns,
    output_bc_patterns,
    output_dtypes,
    inplace_pattern,
    inputs,
):
    """Intrinsic that applies ``scalar_func`` elementwise over ``inputs``.

    Parameters
    ----------
    typingctx
        The numba typing context.
    scalar_func
        Numba type of the scalar function applied to each element.
    input_bc_patterns, output_bc_patterns, output_dtypes, inplace_pattern
        String literals holding base64-encoded pickles of, respectively, the
        per-input broadcast flags, the per-output broadcast flags, the output
        dtypes and the ``(output index, input index)`` in-place pairs.
    inputs
        Tuple type of the input arrays; all must be arrays of the same rank.

    Returns
    -------
    (signature, codegen)
        The call signature and the function emitting the LLVM IR. The result
        is a single array for one output and a tuple of arrays otherwise.

    Raises
    ------
    TypingError
        If an encoded argument is not a literal, if there are no inputs or
        outputs, if the inputs are not arrays of equal rank, or if a broadcast
        pattern does not match the inputs. ``TypingError`` is used for every
        check so that numba reports them as ordinary typing failures.
    """
    arg_types = [
        scalar_func,
        input_bc_patterns,
        output_bc_patterns,
        output_dtypes,
        inplace_pattern,
        inputs,
    ]

    input_bc_patterns = _decode_literal_arg(input_bc_patterns, "input_bc_patterns")
    output_bc_patterns = _decode_literal_arg(
        output_bc_patterns, "output_bc_patterns"
    )
    output_dtypes = _decode_literal_arg(output_dtypes, "output_dtypes")
    inplace_pattern = _decode_literal_arg(inplace_pattern, "inplace_pattern")

    n_inputs = len(inputs)
    n_outputs = len(output_bc_patterns)

    if not n_inputs > 0:
        raise TypingError("Empty argument list to elemwise op.")

    if not n_outputs > 0:
        raise TypingError("Empty list of outputs for elemwise op.")

    if not all(isinstance(in_type, types.Array) for in_type in inputs):
        raise TypingError("Inputs to elemwise must be arrays.")
    ndim = inputs[0].ndim

    if not all(in_type.ndim == ndim for in_type in inputs):
        raise TypingError("Inputs to elemwise must have the same rank.")

    # Mirror the output-pattern check below: the code generator indexes one
    # pattern per input and one flag per dimension.
    if len(input_bc_patterns) != n_inputs or not all(
        len(pattern) == ndim for pattern in input_bc_patterns
    ):
        raise TypingError("Invalid input broadcasting pattern.")

    if not all(len(pattern) == ndim for pattern in output_bc_patterns):
        raise TypingError("Invalid output broadcasting pattern.")

    scalar_signature = typingctx.resolve_function_type(
        scalar_func, [in_type.dtype for in_type in inputs], {}
    )

    # ``inputs`` is rebound inside ``codegen``; keep the array types reachable.
    input_types = inputs

    def codegen(ctx, builder, sig, args):
        # Only the tuple of input arrays carries a runtime value; the other
        # five arguments were consumed as constants during typing.
        _, _, _, _, _, packed_inputs = args
        input_values = cgutils.unpack_tuple(builder, packed_inputs)
        input_arrays = [
            arrayobj.make_array(ty)(ctx, builder, val)
            for ty, val in zip(input_types, input_values)
        ]
        in_shapes = [
            cgutils.unpack_tuple(builder, ary.shape) for ary in input_arrays
        ]

        iter_shape = elemwise_codegen.compute_itershape(
            ctx,
            builder,
            in_shapes,
            input_bc_patterns,
        )

        outputs, output_types = elemwise_codegen.make_outputs(
            ctx,
            builder,
            iter_shape,
            output_bc_patterns,
            output_dtypes,
            inplace_pattern,
            input_arrays,
            input_types,
        )

        # TODO: emit runtime checks that the input shapes agree with
        # ``iter_shape`` and the broadcast patterns.

        elemwise_codegen.make_loop_call(
            typingctx,
            ctx,
            builder,
            scalar_func,
            scalar_signature,
            iter_shape,
            input_arrays,
            outputs,
            input_bc_patterns,
            output_bc_patterns,
            input_types,
            output_types,
        )

        if len(outputs) == 1:
            if inplace_pattern:
                assert inplace_pattern[0][0] == 0
                # The returned array is one of the inputs, so the caller
                # receives an additional reference to it.
                ctx.nrt.incref(builder, sig.return_type, outputs[0]._getvalue())
            return outputs[0]._getvalue()

        # Keys of the in-place mapping are output indices.
        for inplace_idx in dict(inplace_pattern):
            ctx.nrt.incref(
                builder,
                sig.return_type.types[inplace_idx],
                outputs[inplace_idx]._getvalue(),
            )
        return ctx.make_tuple(
            builder, sig.return_type, [out._getvalue() for out in outputs]
        )

    # TODO check inplace_pattern: every output is declared as a C-contiguous
    # array here, including outputs that reuse an input array.
    ret_type = types.Tuple(
        [
            types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
            for dtype in output_dtypes
        ]
    )
    if len(output_dtypes) == 1:
        ret_type = ret_type.types[0]
    sig = ret_type(*arg_types)

    return sig, codegen
@numba_funcify.register
(
Elemwise
)
@numba_funcify.register
(
Elemwise
)
def
numba_funcify_Elemwise
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Elemwise
(
op
,
node
,
**
kwargs
):
# Creating a new scalar node is more involved and unnecessary
# Creating a new scalar node is more involved and unnecessary
...
@@ -441,55 +610,42 @@ def numba_funcify_Elemwise(op, node, **kwargs):
...
@@ -441,55 +610,42 @@ def numba_funcify_Elemwise(op, node, **kwargs):
scalar_inputs
=
[
scalar
(
dtype
=
input
.
dtype
)
for
input
in
node
.
inputs
]
scalar_inputs
=
[
scalar
(
dtype
=
input
.
dtype
)
for
input
in
node
.
inputs
]
scalar_node
=
op
.
scalar_op
.
make_node
(
*
scalar_inputs
)
scalar_node
=
op
.
scalar_op
.
make_node
(
*
scalar_inputs
)
flags
=
{
"arcp"
,
# Allow Reciprocal
"contract"
,
# Allow floating-point contraction
"afn"
,
# Approximate functions
"reassoc"
,
"nsz"
,
# TODO Do we want this one?
}
scalar_op_fn
=
numba_funcify
(
scalar_op_fn
=
numba_funcify
(
op
.
scalar_op
,
node
=
scalar_node
,
parent_node
=
node
,
inline
=
"always"
,
**
kwargs
op
.
scalar_op
,
node
=
scalar_node
,
parent_node
=
node
,
fastmath
=
flags
,
**
kwargs
)
)
elemwise_fn
=
create_vectorize_func
(
scalar_op_fn
,
node
,
use_signature
=
False
)
elemwise_fn_name
=
elemwise_fn
.
__name__
if
op
.
inplace_pattern
:
input_idx
=
op
.
inplace_pattern
[
0
]
sign_obj
=
inspect
.
signature
(
elemwise_fn
.
py_scalar_func
)
input_names
=
list
(
sign_obj
.
parameters
.
keys
())
unique_names
=
unique_name_generator
([
elemwise_fn_name
,
"np"
],
suffix_sep
=
"_"
)
input_names
=
[
unique_names
(
i
,
force_unique
=
True
)
for
i
in
input_names
]
updated_input_name
=
input_names
[
input_idx
]
ndim
=
node
.
outputs
[
0
]
.
ndim
output_bc_patterns
=
tuple
([(
False
,)
*
ndim
for
_
in
node
.
outputs
])
inplace_global_env
=
{
elemwise_fn_name
:
elemwise_fn
,
"np"
:
np
}
input_bc_patterns
=
tuple
([
input_var
.
broadcastable
for
input_var
in
node
.
inputs
])
output_dtypes
=
tuple
(
variable
.
dtype
for
variable
in
node
.
outputs
)
inplace_elemwise_fn_name
=
f
"{elemwise_fn_name}_inplace"
inplace_pattern
=
tuple
(
op
.
inplace_pattern
.
items
())
input_signature_str
=
", "
.
join
(
input_names
)
# numba doesn't support nested literals right now...
input_bc_patterns
=
base64
.
encodebytes
(
pickle
.
dumps
(
input_bc_patterns
))
.
decode
()
if
node
.
inputs
[
input_idx
]
.
ndim
>
0
:
output_bc_patterns
=
base64
.
encodebytes
(
pickle
.
dumps
(
output_bc_patterns
))
.
decode
()
inplace_elemwise_src
=
f
"""
output_dtypes
=
base64
.
encodebytes
(
pickle
.
dumps
(
output_dtypes
))
.
decode
()
def {inplace_elemwise_fn_name}({input_signature_str}):
inplace_pattern
=
base64
.
encodebytes
(
pickle
.
dumps
(
inplace_pattern
))
.
decode
()
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
"""
@numba_njit
else
:
def
elemwise_wrapper
(
*
inputs
):
# We can't perform in-place updates on Numba scalars, so we need to
return
_vectorized
(
# convert them to NumPy scalars.
scalar_op_fn
,
# TODO: We should really prevent the rewrites from creating
input_bc_patterns
,
# in-place updates on scalars when the Numba mode is selected (or
output_bc_patterns
,
# in general?).
output_dtypes
,
inplace_elemwise_src
=
f
"""
inplace_pattern
,
def {inplace_elemwise_fn_name}({input_signature_str}):
inputs
,
{updated_input_name}_scalar = np.asarray({updated_input_name})
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
"""
inplace_elemwise_fn
=
compile_function_src
(
inplace_elemwise_src
,
inplace_elemwise_fn_name
,
{
**
globals
(),
**
inplace_global_env
},
)
return
numba_basic
.
numba_njit
(
inline
=
"always"
,
fastmath
=
config
.
numba__fastmath
)(
inplace_elemwise_fn
)
)
return
elemwise_
fn
return
elemwise_
wrapper
@numba_funcify.register
(
CAReduce
)
@numba_funcify.register
(
CAReduce
)
...
...
pytensor/link/numba/dispatch/elemwise_codegen.py
0 → 100644
浏览文件 @
9e79f3a0
from
llvmlite
import
ir
from
numba
import
types
from
numba.np
import
arrayobj
from
numba.core
import
cgutils
import
numba
import
numpy
as
np
def compute_itershape(
    ctx,
    builder: ir.IRBuilder,
    in_shapes,
    broadcast_pattern,
):
    """Return the per-dimension lengths of the iteration space.

    ``in_shapes`` holds one sequence of dimension lengths per input and
    ``broadcast_pattern`` one sequence of broadcast flags per input. For each
    dimension the length is taken from the last input that is not flagged as
    broadcast along it; if every input is broadcast there, the 64-bit integer
    constant one is used.

    No consistency checks are emitted yet (TODO): inputs flagged as broadcast
    are not verified to have length one, and disagreeing non-broadcast lengths
    are not detected. The all-zero-shape case is also still open.
    """
    rank = len(in_shapes[0])
    unit = ir.IntType(64)(1)

    lengths = []
    for axis in range(rank):
        chosen = None
        for flags, dims in zip(broadcast_pattern, in_shapes):
            if not flags[axis]:
                chosen = dims[axis]
        lengths.append(unit if chosen is None else chosen)
    return lengths
def make_outputs(
    ctx,
    builder: ir.IRBuilder,
    iter_shape,
    out_bc,
    dtypes,
    inplace,
    inputs,
    input_types,
):
    """Create (or reuse) the output arrays of an elementwise operation.

    Parameters
    ----------
    ctx, builder
        Numba context and LLVM IR builder used to emit the allocations.
    iter_shape
        Per-dimension lengths of the iteration space.
    out_bc
        One sequence of broadcast flags per output.
    dtypes
        One dtype per output, anything accepted by ``np.dtype``.
    inplace
        Pairs ``(output index, input index)``; such outputs reuse the given
        input array instead of allocating a new one.
    inputs, input_types
        The input array objects and their numba array types.

    Returns
    -------
    (arrays, types)
        The output array objects and their numba array types, in output order.
    """
    unit = ir.IntType(64)(1)
    reused = dict(inplace)

    out_arrays = []
    out_types: list[types.Array] = []

    for out_idx, (bc_flags, dtype) in enumerate(zip(out_bc, dtypes)):
        if out_idx in reused:
            source = reused[out_idx]
            out_arrays.append(inputs[source])
            out_types.append(input_types[source])
            # The reference count is bumped later, when the in-place array is
            # returned to the caller.
            continue

        array_type = types.Array(
            numba.from_dtype(np.dtype(dtype)), len(iter_shape), "C"
        )
        out_types.append(array_type)

        # Broadcast dimensions of the output have length one.
        dims = [
            unit if is_bc else length for length, is_bc in zip(iter_shape, bc_flags)
        ]
        # ``_empty_nd_impl`` is an internal numba function; a public
        # alternative may exist under ``numba.np.unsafe.ndarray``.
        out_arrays.append(arrayobj._empty_nd_impl(ctx, builder, array_type, dims))

    # Without in-place outputs no output array aliases anything else; telling
    # LLVM so can make vectorization easier.
    if not reused:
        # Per the original author's note, the first argument of the function
        # being built is the output pointer.
        builder.function.args[0].add_attribute("noalias")

    return out_arrays, out_types
def make_loop_call(
    typingctx,
    context: numba.core.base.BaseContext,
    builder: ir.IRBuilder,
    scalar_func,
    scalar_signature,
    iter_shape,
    inputs,
    outputs,
    input_bc,
    output_bc,
    input_types,
    output_types,
):
    """Emit the nested loops that apply the scalar function to every element.

    One loop is opened per entry of ``iter_shape``. In the innermost body the
    current element of every input is loaded (index zero is used along
    dimensions flagged in ``input_bc``), the compiled scalar function is
    called, and each result is either stored into its output array or added to
    an accumulator.

    Parameters
    ----------
    typingctx
        Not used in this function.
    context, builder
        Numba target context and LLVM IR builder used to emit the code.
    scalar_func
        Object whose ``dispatcher`` attribute is compiled as the loop body.
    scalar_signature
        Signature of the scalar function; a ``types.Tuple`` return type means
        several outputs.
    iter_shape
        Per-dimension loop lengths.
    inputs, outputs
        Array objects exposing ``shape``, ``strides`` and ``data``.
    input_bc, output_bc
        One sequence of broadcast flags per input / output.
    input_types, output_types
        Array types matching ``inputs`` / ``outputs``.

    Returns
    -------
    None
    """
    # Passed positionally after the indices to ``get_item_pointer2``;
    # presumably the wraparound and bounds-check flags -- TODO confirm.
    safe = (False, False)

    n_outputs = len(outputs)

    #context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)

    # Lower the code of the scalar function so that we can use it in the inner loop
    # Caching is set to false to avoid a numba bug TODO ref?
    inner_func = context.compile_subroutine(
        builder,
        # I don't quite understand why we need to access `dispatcher` here.
        # The object does seem to be a dispatcher already? But it is missing
        # attributes...
        scalar_func.dispatcher,
        scalar_signature,
        caching=False,
    )
    inner = inner_func.fndesc

    # Extract shape and stride information from the array.
    # For later use in the loop body to do the indexing
    def extract_array(aryty, obj):
        shape = cgutils.unpack_tuple(builder, obj.shape)
        strides = cgutils.unpack_tuple(builder, obj.strides)
        data = obj.data
        layout = aryty.layout
        # Tuple order matches how it is splatted into ``get_item_pointer2``.
        return (data, shape, strides, layout)

    # TODO I think this is better than the noalias attribute
    # for the input, but self_ref isn't supported in a released
    # llvmlite version yet
    # mod = builder.module
    # domain = mod.add_metadata([], self_ref=True)
    # input_scope = mod.add_metadata([domain], self_ref=True)
    # output_scope = mod.add_metadata([domain], self_ref=True)
    # input_scope_set = mod.add_metadata([input_scope, output_scope])
    # output_scope_set = mod.add_metadata([input_scope, output_scope])

    # From here on ``inputs`` / ``outputs`` hold (data, shape, strides, layout)
    # tuples instead of array objects.
    inputs = [
        extract_array(aryty, ary)
        for aryty, ary in zip(input_types, inputs, strict=True)
    ]

    outputs = [
        extract_array(aryty, ary)
        for aryty, ary in zip(output_types, outputs, strict=True)
    ]

    zero = ir.Constant(ir.IntType(64), 0)

    # Setup loops and initialize accumulators for outputs
    # This part corresponds to opening the loops
    loop_stack = []
    loops = []
    # One (accumulator pointer, dimension at which it was created) pair per
    # output; (None, None) means the output is written directly.
    output_accumulator = [(None, None)] * n_outputs
    for dim, length in enumerate(iter_shape):
        # Find outputs that only have accumulations left
        for output in range(n_outputs):
            if output_accumulator[output][0] is not None:
                continue
            if all(output_bc[output][dim:]):
                # Zero constant of the output's element type, placed in a
                # stack slot that serves as the accumulator.
                value = outputs[output][0].type.pointee(0)
                accu = cgutils.alloca_once_value(builder, value)
                output_accumulator[output] = (accu, dim)

        # The loop context managers are entered by hand so that a variable
        # number of loops can be nested; they are exited at the end.
        loop = cgutils.for_range(builder, length)
        loop_stack.append(loop)
        loops.append(loop.__enter__())

    # Code in the inner most loop...
    idxs = [loopval.index for loopval in loops]

    # Load values from input arrays
    input_vals = []
    for array_info, bc in zip(inputs, input_bc, strict=True):
        # The comprehension's ``bc`` shadows the outer ``bc``; the outer value
        # is still the one passed to ``zip`` because the first iterable of a
        # comprehension is evaluated in the enclosing scope.
        idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)]
        ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
        val = builder.load(ptr)
        # val.set_metadata("alias.scope", input_scope_set)
        # val.set_metadata("noalias", output_scope_set)
        input_vals.append(val)

    # Call scalar function
    output_values = context.call_internal(
        builder,
        inner,
        scalar_signature,
        input_vals,
    )
    if isinstance(scalar_signature.return_type, types.Tuple):
        output_values = cgutils.unpack_tuple(builder, output_values)
    else:
        output_values = [output_values]

    # Update output value or accumulators respectively
    for i, ((accu, _), value) in enumerate(
        zip(output_accumulator, output_values, strict=True)
    ):
        if accu is not None:
            load = builder.load(accu)
            # load.set_metadata("alias.scope", output_scope_set)
            # load.set_metadata("noalias", input_scope_set)
            # NOTE(review): ``fadd`` is a floating-point add, so this path
            # looks valid only for floating-point outputs -- confirm.
            new_value = builder.fadd(load, value)
            builder.store(new_value, accu)
            # TODO belongs to noalias scope
            # store.set_metadata("alias.scope", output_scope_set)
            # store.set_metadata("noalias", input_scope_set)
        else:
            idxs_bc = [
                zero if bc else idx
                for idx, bc in zip(idxs, output_bc[i], strict=True)
            ]
            ptr = cgutils.get_item_pointer2(context, builder, *outputs[i], idxs_bc)
            # store = builder.store(value, ptr)
            arrayobj.store_item(context, builder, output_types[i], value, ptr)
            # store.set_metadata("alias.scope", output_scope_set)
            # store.set_metadata("noalias", input_scope_set)

    # Close the loops and write accumulator values to the output arrays
    # NOTE(review): ``depth`` counts from the innermost loop (the stack is
    # reversed) while ``accu_depth`` is the dimension index recorded when the
    # accumulator was created, counted from the outermost loop. The store is
    # also emitted before the loop is exited. Confirm both are intended.
    for depth, loop in enumerate(loop_stack[::-1]):
        for output, (accu, accu_depth) in enumerate(output_accumulator):
            if accu_depth == depth:
                idxs_bc = [
                    zero if bc else idx
                    for idx, bc in zip(idxs, output_bc[output], strict=True)
                ]
                ptr = cgutils.get_item_pointer2(
                    context, builder, *outputs[output], idxs_bc
                )
                load = builder.load(accu)
                # load.set_metadata("alias.scope", output_scope_set)
                # load.set_metadata("noalias", input_scope_set)
                # store = builder.store(load, ptr)
                arrayobj.store_item(context, builder, output_types[output], load, ptr)
                # store.set_metadata("alias.scope", output_scope_set)
                # store.set_metadata("noalias", input_scope_set)
        loop.__exit__(None, None, None)

    return
pytensor/link/numba/dispatch/helpers.py
0 → 100644
浏览文件 @
9e79f3a0
from
numba
import
njit
,
types
from
numba.core
import
cgutils
from
numba.extending
import
intrinsic
def tuple_mapper(item_map_func):
    """Build an intrinsic that maps ``item_map_func`` over tuples item-wise.

    The returned intrinsic takes any number of tuples of equal length and
    returns a tuple whose k-th item is ``item_map_func`` called with the k-th
    item of every input tuple. A separate signature is resolved for each
    position, so the items may have different types.
    """

    @intrinsic
    def map_tuple(typingctx, *input_tuples):
        # Item types grouped by position: one argument list per output item.
        item_types = [tuple_type.types for tuple_type in input_tuples]
        signatures = [
            typingctx.resolve_function_type(item_map_func, args, {})
            for args in zip(*item_types, strict=True)
        ]

        output_type = types.Tuple([item_sig.return_type for item_sig in signatures])
        signature = output_type(types.StarArgTuple(input_tuples))

        def codegen(context, builder, signature, args):
            # The star-args arrive packed into a single tuple value.
            (packed,) = args
            unpacked = [
                cgutils.unpack_tuple(builder, tuple_value)
                for tuple_value in cgutils.unpack_tuple(builder, packed)
            ]

            results = []
            for item_values, item_sig in zip(
                zip(*unpacked), signatures, strict=True
            ):
                compiled = context.compile_subroutine(
                    builder, item_map_func, item_sig
                )
                results.append(
                    context.call_internal(
                        builder, compiled.fndesc, item_sig, item_values
                    )
                )

            return context.make_tuple(builder, output_type, results)

        return signature, codegen

    return map_tuple
@njit
def check_broadcasting(array, bcs, shape):
    """Assert that ``array`` matches ``shape`` under the broadcast flags ``bcs``.

    The array must have as many dimensions as ``shape`` has entries. Along a
    dimension flagged in ``bcs`` its length must be one; along every other
    dimension it must equal the corresponding entry of ``shape``.
    """
    assert array.ndim == len(shape)

    for is_bc, actual, expected in zip(bcs, array.shape, shape):
        if not is_bc:
            assert actual == expected
        else:
            assert actual == 1
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论