Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
47874eb9
提交
47874eb9
authored
4月 05, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Adapt Numba vectorize iterator for RandomVariables
Co-authored-by:
Jesse Grabowski
<
48652735+jessegrabowski@users.noreply.github.com
>
Co-authored-by:
Adrian Seyboldt
<
aseyboldt@users.noreply.github.com
>
上级
38c04c96
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
304 行增加
和
131 行删除
+304
-131
basic.py
pytensor/link/numba/dispatch/basic.py
+8
-2
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+10
-1
vectorize_codegen.py
pytensor/link/numba/dispatch/vectorize_codegen.py
+286
-128
没有找到文件。
pytensor/link/numba/dispatch/basic.py
浏览文件 @
47874eb9
...
@@ -62,10 +62,16 @@ def numba_njit(*args, **kwargs):
...
@@ -62,10 +62,16 @@ def numba_njit(*args, **kwargs):
kwargs
.
setdefault
(
"no_cpython_wrapper"
,
True
)
kwargs
.
setdefault
(
"no_cpython_wrapper"
,
True
)
kwargs
.
setdefault
(
"no_cfunc_wrapper"
,
True
)
kwargs
.
setdefault
(
"no_cfunc_wrapper"
,
True
)
# Supress caching warnings
# Suppress cache warning for internal functions
# We have to add an ansi escape code for optional bold text by numba
warnings
.
filterwarnings
(
warnings
.
filterwarnings
(
"ignore"
,
"ignore"
,
message
=
'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals'
,
message
=
(
"(
\x1b\\
[1m)*"
# ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs)" '
"as it uses dynamic globals"
),
category
=
NumbaWarning
,
category
=
NumbaWarning
,
)
)
...
...
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
47874eb9
...
@@ -24,6 +24,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import (
...
@@ -24,6 +24,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options
,
_jit_options
,
_vectorized
,
_vectorized
,
encode_literals
,
encode_literals
,
store_core_outputs
,
)
)
from
pytensor.link.utils
import
compile_function_src
,
get_name_for_object
from
pytensor.link.utils
import
compile_function_src
,
get_name_for_object
from
pytensor.scalar.basic
import
(
from
pytensor.scalar.basic
import
(
...
@@ -480,10 +481,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
...
@@ -480,10 +481,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
**
kwargs
,
**
kwargs
,
)
)
nin
=
len
(
node
.
inputs
)
nout
=
len
(
node
.
outputs
)
core_op_fn
=
store_core_outputs
(
scalar_op_fn
,
nin
=
nin
,
nout
=
nout
)
input_bc_patterns
=
tuple
([
inp
.
type
.
broadcastable
for
inp
in
node
.
inputs
])
input_bc_patterns
=
tuple
([
inp
.
type
.
broadcastable
for
inp
in
node
.
inputs
])
output_bc_patterns
=
tuple
([
out
.
type
.
broadcastable
for
out
in
node
.
outputs
])
output_bc_patterns
=
tuple
([
out
.
type
.
broadcastable
for
out
in
node
.
outputs
])
output_dtypes
=
tuple
(
out
.
type
.
dtype
for
out
in
node
.
outputs
)
output_dtypes
=
tuple
(
out
.
type
.
dtype
for
out
in
node
.
outputs
)
inplace_pattern
=
tuple
(
op
.
inplace_pattern
.
items
())
inplace_pattern
=
tuple
(
op
.
inplace_pattern
.
items
())
core_output_shapes
=
tuple
(()
for
_
in
range
(
nout
))
# numba doesn't support nested literals right now...
# numba doesn't support nested literals right now...
input_bc_patterns_enc
=
encode_literals
(
input_bc_patterns
)
input_bc_patterns_enc
=
encode_literals
(
input_bc_patterns
)
...
@@ -493,12 +499,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
...
@@ -493,12 +499,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
def
elemwise_wrapper
(
*
inputs
):
def
elemwise_wrapper
(
*
inputs
):
return
_vectorized
(
return
_vectorized
(
scalar
_op_fn
,
core
_op_fn
,
input_bc_patterns_enc
,
input_bc_patterns_enc
,
output_bc_patterns_enc
,
output_bc_patterns_enc
,
output_dtypes_enc
,
output_dtypes_enc
,
inplace_pattern_enc
,
inplace_pattern_enc
,
(),
# constant_inputs
inputs
,
inputs
,
core_output_shapes
,
# core_shapes
None
,
# size
)
)
# Pure python implementation, that will be used in tests
# Pure python implementation, that will be used in tests
...
...
pytensor/link/numba/dispatch/vectorize_codegen.py
浏览文件 @
47874eb9
...
@@ -2,8 +2,9 @@ from __future__ import annotations
...
@@ -2,8 +2,9 @@ from __future__ import annotations
import
base64
import
base64
import
pickle
import
pickle
from
collections.abc
import
Sequence
from
collections.abc
import
Callable
,
Sequence
from
typing
import
Any
from
textwrap
import
indent
from
typing
import
Any
,
cast
import
numba
import
numba
import
numpy
as
np
import
numpy
as
np
...
@@ -11,13 +12,54 @@ from llvmlite import ir
...
@@ -11,13 +12,54 @@ from llvmlite import ir
from
numba
import
TypingError
,
types
from
numba
import
TypingError
,
types
from
numba.core
import
cgutils
from
numba.core
import
cgutils
from
numba.core.base
import
BaseContext
from
numba.core.base
import
BaseContext
from
numba.core.types.misc
import
NoneType
from
numba.np
import
arrayobj
from
numba.np
import
arrayobj
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.utils
import
compile_function_src
def
encode_literals
(
literals
:
Sequence
)
->
str
:
def
encode_literals
(
literals
:
Sequence
)
->
str
:
return
base64
.
encodebytes
(
pickle
.
dumps
(
literals
))
.
decode
()
return
base64
.
encodebytes
(
pickle
.
dumps
(
literals
))
.
decode
()
def
store_core_outputs
(
core_op_fn
:
Callable
,
nin
:
int
,
nout
:
int
)
->
Callable
:
"""Create a Numba function that wraps a core function and stores its vectorized outputs.
@njit
def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on):
to0, to1, ..., ton = core_op_fn(i0, i1, ..., in)
o0[...] = to0
o1[...] = to1
...
on[...] = ton
"""
inputs
=
[
f
"i{i}"
for
i
in
range
(
nin
)]
outputs
=
[
f
"o{i}"
for
i
in
range
(
nout
)]
inner_outputs
=
[
f
"t{output}"
for
output
in
outputs
]
inp_signature
=
", "
.
join
(
inputs
)
out_signature
=
", "
.
join
(
outputs
)
inner_out_signature
=
", "
.
join
(
inner_outputs
)
store_outputs
=
"
\n
"
.
join
(
[
f
"{output}[...] = {inner_output}"
for
output
,
inner_output
in
zip
(
outputs
,
inner_outputs
)
]
)
func_src
=
f
"""
def store_core_outputs({inp_signature}, {out_signature}):
{inner_out_signature} = core_op_fn({inp_signature})
{indent(store_outputs, " " * 4)}
"""
global_env
=
{
"core_op_fn"
:
core_op_fn
}
func
=
compile_function_src
(
func_src
,
"store_core_outputs"
,
{
**
globals
(),
**
global_env
}
)
return
cast
(
Callable
,
numba_basic
.
numba_njit
(
func
))
_jit_options
=
{
_jit_options
=
{
"fastmath"
:
{
"fastmath"
:
{
"arcp"
,
# Allow Reciprocal
"arcp"
,
# Allow Reciprocal
...
@@ -39,7 +81,10 @@ def _vectorized(
...
@@ -39,7 +81,10 @@ def _vectorized(
output_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
output_dtypes
,
inplace_pattern
,
inplace_pattern
,
inputs
,
constant_inputs_types
,
input_types
,
output_core_shape_types
,
size_type
,
):
):
arg_types
=
[
arg_types
=
[
scalar_func
,
scalar_func
,
...
@@ -47,7 +92,10 @@ def _vectorized(
...
@@ -47,7 +92,10 @@ def _vectorized(
output_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
output_dtypes
,
inplace_pattern
,
inplace_pattern
,
inputs
,
constant_inputs_types
,
input_types
,
output_core_shape_types
,
size_type
,
]
]
if
not
isinstance
(
input_bc_patterns
,
types
.
Literal
):
if
not
isinstance
(
input_bc_patterns
,
types
.
Literal
):
...
@@ -70,34 +118,82 @@ def _vectorized(
...
@@ -70,34 +118,82 @@ def _vectorized(
inplace_pattern
=
inplace_pattern
.
literal_value
inplace_pattern
=
inplace_pattern
.
literal_value
inplace_pattern
=
pickle
.
loads
(
base64
.
decodebytes
(
inplace_pattern
.
encode
()))
inplace_pattern
=
pickle
.
loads
(
base64
.
decodebytes
(
inplace_pattern
.
encode
()))
n_outputs
=
len
(
output_bc_patterns
)
batch_ndim
=
len
(
input_bc_patterns
[
0
])
nin
=
len
(
constant_inputs_types
)
+
len
(
input_types
)
nout
=
len
(
output_bc_patterns
)
if
nin
==
0
:
raise
TypingError
(
"Empty argument list to vectorized op."
)
if
nout
==
0
:
raise
TypingError
(
"Empty list of outputs for vectorized op."
)
if
not
len
(
inputs
)
>
0
:
if
not
all
(
isinstance
(
input
,
types
.
Array
)
for
input
in
input_types
)
:
raise
TypingError
(
"
Empty argument list to elemwise op
."
)
raise
TypingError
(
"
Vectorized inputs must be arrays
."
)
if
not
n_outputs
>
0
:
if
not
all
(
raise
TypingError
(
"Empty list of outputs for elemwise op."
)
len
(
pattern
)
==
batch_ndim
for
pattern
in
input_bc_patterns
+
output_bc_patterns
):
raise
TypingError
(
"Vectorized broadcastable patterns must have the same length."
)
core_input_types
=
[]
for
input_type
,
bc_pattern
in
zip
(
input_types
,
input_bc_patterns
):
core_ndim
=
input_type
.
ndim
-
len
(
bc_pattern
)
# TODO: Reconsider this
if
core_ndim
==
0
:
core_input_type
=
input_type
.
dtype
else
:
core_input_type
=
types
.
Array
(
dtype
=
input_type
.
dtype
,
ndim
=
core_ndim
,
layout
=
input_type
.
layout
)
core_input_types
.
append
(
core_input_type
)
if
not
all
(
isinstance
(
input
,
types
.
Array
)
for
input
in
inputs
):
core_out_types
=
[
raise
TypingError
(
"Inputs to elemwise must be arrays."
)
types
.
Array
(
numba
.
from_dtype
(
np
.
dtype
(
dtype
)),
len
(
output_core_shape
),
"C"
)
ndim
=
inputs
[
0
]
.
ndim
for
dtype
,
output_core_shape
in
zip
(
output_dtypes
,
output_core_shape_types
)
]
if
not
all
(
input
.
ndim
==
ndim
for
input
in
inputs
):
out_types
=
[
raise
TypingError
(
"Inputs to elemwise must have the same rank."
)
types
.
Array
(
numba
.
from_dtype
(
np
.
dtype
(
dtype
)),
batch_ndim
+
len
(
output_core_shape
),
"C"
)
for
dtype
,
output_core_shape
in
zip
(
output_dtypes
,
output_core_shape_types
)
]
if
not
all
(
len
(
pattern
)
==
ndim
for
pattern
in
output_bc_patterns
):
for
output_idx
,
input_idx
in
inplace_pattern
:
raise
TypingError
(
"Invalid output broadcasting pattern."
)
output_type
=
input_types
[
input_idx
]
core_out_types
[
output_idx
]
=
types
.
Array
(
dtype
=
output_type
.
dtype
,
ndim
=
output_type
.
ndim
-
batch_ndim
,
layout
=
input_type
.
layout
,
)
out_types
[
output_idx
]
=
output_type
scalar_signature
=
typingctx
.
resolve_function_type
(
core_signature
=
typingctx
.
resolve_function_type
(
scalar_func
,
[
in_type
.
dtype
for
in_type
in
inputs
],
{}
scalar_func
,
[
*
constant_inputs_types
,
*
core_input_types
,
*
core_out_types
,
],
{},
)
)
ret_type
=
types
.
Tuple
(
out_types
)
if
len
(
output_dtypes
)
==
1
:
ret_type
=
ret_type
.
types
[
0
]
sig
=
ret_type
(
*
arg_types
)
# So we can access the constant values in codegen...
# So we can access the constant values in codegen...
input_bc_patterns_val
=
input_bc_patterns
input_bc_patterns_val
=
input_bc_patterns
output_bc_patterns_val
=
output_bc_patterns
output_bc_patterns_val
=
output_bc_patterns
output_dtypes_val
=
output_dtypes
output_dtypes_val
=
output_dtypes
inplace_pattern_val
=
inplace_pattern
inplace_pattern_val
=
inplace_pattern
input_types
=
inputs
input_types
=
input_types
size_is_none
=
isinstance
(
size_type
,
NoneType
)
def
codegen
(
def
codegen
(
ctx
,
ctx
,
...
@@ -105,8 +201,16 @@ def _vectorized(
...
@@ -105,8 +201,16 @@ def _vectorized(
sig
,
sig
,
args
,
args
,
):
):
[
_
,
_
,
_
,
_
,
_
,
inputs
]
=
args
[
_
,
_
,
_
,
_
,
_
,
constant_inputs
,
inputs
,
output_core_shapes
,
size
]
=
args
constant_inputs
=
cgutils
.
unpack_tuple
(
builder
,
constant_inputs
)
inputs
=
cgutils
.
unpack_tuple
(
builder
,
inputs
)
inputs
=
cgutils
.
unpack_tuple
(
builder
,
inputs
)
output_core_shapes
=
[
cgutils
.
unpack_tuple
(
builder
,
shape
)
for
shape
in
cgutils
.
unpack_tuple
(
builder
,
output_core_shapes
)
]
size
=
None
if
size_is_none
else
cgutils
.
unpack_tuple
(
builder
,
size
)
inputs
=
[
inputs
=
[
arrayobj
.
make_array
(
ty
)(
ctx
,
builder
,
val
)
arrayobj
.
make_array
(
ty
)(
ctx
,
builder
,
val
)
for
ty
,
val
in
zip
(
input_types
,
inputs
)
for
ty
,
val
in
zip
(
input_types
,
inputs
)
...
@@ -118,6 +222,7 @@ def _vectorized(
...
@@ -118,6 +222,7 @@ def _vectorized(
builder
,
builder
,
in_shapes
,
in_shapes
,
input_bc_patterns_val
,
input_bc_patterns_val
,
size
,
)
)
outputs
,
output_types
=
make_outputs
(
outputs
,
output_types
=
make_outputs
(
...
@@ -129,6 +234,7 @@ def _vectorized(
...
@@ -129,6 +234,7 @@ def _vectorized(
inplace_pattern_val
,
inplace_pattern_val
,
inputs
,
inputs
,
input_types
,
input_types
,
output_core_shapes
,
)
)
make_loop_call
(
make_loop_call
(
...
@@ -136,8 +242,9 @@ def _vectorized(
...
@@ -136,8 +242,9 @@ def _vectorized(
ctx
,
ctx
,
builder
,
builder
,
scalar_func
,
scalar_func
,
scalar
_signature
,
core
_signature
,
iter_shape
,
iter_shape
,
constant_inputs
,
inputs
,
inputs
,
outputs
,
outputs
,
input_bc_patterns_val
,
input_bc_patterns_val
,
...
@@ -162,67 +269,92 @@ def _vectorized(
...
@@ -162,67 +269,92 @@ def _vectorized(
builder
,
sig
.
return_type
,
[
out
.
_getvalue
()
for
out
in
outputs
]
builder
,
sig
.
return_type
,
[
out
.
_getvalue
()
for
out
in
outputs
]
)
)
ret_types
=
[
types
.
Array
(
numba
.
from_dtype
(
np
.
dtype
(
dtype
)),
ndim
,
"C"
)
for
dtype
in
output_dtypes
]
for
output_idx
,
input_idx
in
inplace_pattern
:
ret_types
[
output_idx
]
=
input_types
[
input_idx
]
ret_type
=
types
.
Tuple
(
ret_types
)
if
len
(
output_dtypes
)
==
1
:
ret_type
=
ret_type
.
types
[
0
]
sig
=
ret_type
(
*
arg_types
)
return
sig
,
codegen
return
sig
,
codegen
def
compute_itershape
(
def
compute_itershape
(
ctx
:
BaseContext
,
ctx
:
BaseContext
,
builder
:
ir
.
IRBuilder
,
builder
:
ir
.
IRBuilder
,
in_shapes
:
tuple
[
ir
.
Instruction
,
...
],
in_shapes
:
list
[
list
[
ir
.
Instruction
]
],
broadcast_pattern
:
tuple
[
tuple
[
bool
,
...
],
...
],
broadcast_pattern
:
tuple
[
tuple
[
bool
,
...
],
...
],
size
:
list
[
ir
.
Instruction
]
|
None
,
):
):
one
=
ir
.
IntType
(
64
)(
1
)
one
=
ir
.
IntType
(
64
)(
1
)
ndim
=
len
(
in_shapes
[
0
])
batch_ndim
=
len
(
broadcast_pattern
[
0
])
shape
=
[
None
]
*
ndim
shape
=
[
None
]
*
batch_ndim
for
i
in
range
(
ndim
):
if
size
is
not
None
:
shape
=
size
for
i
in
range
(
batch_ndim
):
for
j
,
(
bc
,
in_shape
)
in
enumerate
(
zip
(
broadcast_pattern
,
in_shapes
)):
for
j
,
(
bc
,
in_shape
)
in
enumerate
(
zip
(
broadcast_pattern
,
in_shapes
)):
length
=
in_shape
[
i
]
length
=
in_shape
[
i
]
if
bc
[
i
]:
if
bc
[
i
]:
with
builder
.
if_then
(
with
builder
.
if_then
(
builder
.
icmp_unsigned
(
"!="
,
length
,
one
),
likely
=
False
builder
.
icmp_unsigned
(
"!="
,
length
,
one
),
likely
=
False
):
):
msg
=
f
"Vectorized input {j} is expected to have shape 1 in axis {i}"
ctx
.
call_conv
.
return_user_exc
(
builder
,
ValueError
,
(
msg
,))
else
:
with
builder
.
if_then
(
builder
.
icmp_unsigned
(
"!="
,
length
,
shape
[
i
]),
likely
=
False
):
with
builder
.
if_else
(
builder
.
icmp_unsigned
(
"=="
,
length
,
one
)
)
as
(
then
,
otherwise
,
):
with
then
:
msg
=
(
msg
=
(
f
"Input {j} to elemwise is expected to have shape 1 in axis {i}"
f
"Incompatible vectorized shapes for input {j} and axis {i}. "
f
"Input {j} has shape 1, but is not statically "
"known to have shape 1, and thus not broadcastable."
)
ctx
.
call_conv
.
return_user_exc
(
builder
,
ValueError
,
(
msg
,)
)
with
otherwise
:
msg
=
f
"Vectorized input {j} has an incompatible shape in axis {i}."
ctx
.
call_conv
.
return_user_exc
(
builder
,
ValueError
,
(
msg
,)
)
)
else
:
# Size is implied by the broadcast pattern
for
i
in
range
(
batch_ndim
):
for
j
,
(
bc
,
in_shape
)
in
enumerate
(
zip
(
broadcast_pattern
,
in_shapes
)):
length
=
in_shape
[
i
]
if
bc
[
i
]:
with
builder
.
if_then
(
builder
.
icmp_unsigned
(
"!="
,
length
,
one
),
likely
=
False
):
msg
=
f
"Vectorized input {j} is expected to have shape 1 in axis {i}"
ctx
.
call_conv
.
return_user_exc
(
builder
,
ValueError
,
(
msg
,))
ctx
.
call_conv
.
return_user_exc
(
builder
,
ValueError
,
(
msg
,))
elif
shape
[
i
]
is
not
None
:
elif
shape
[
i
]
is
not
None
:
with
builder
.
if_then
(
with
builder
.
if_then
(
builder
.
icmp_unsigned
(
"!="
,
length
,
shape
[
i
]),
likely
=
False
builder
.
icmp_unsigned
(
"!="
,
length
,
shape
[
i
]),
likely
=
False
):
):
with
builder
.
if_else
(
builder
.
icmp_unsigned
(
"=="
,
length
,
one
))
as
(
with
builder
.
if_else
(
builder
.
icmp_unsigned
(
"=="
,
length
,
one
)
)
as
(
then
,
then
,
otherwise
,
otherwise
,
):
):
with
then
:
with
then
:
msg
=
(
msg
=
(
f
"Incompatible shapes for input {j} and axis {i} of
"
f
"Incompatible vectorized shapes for input {j} and axis {i}.
"
f
"elemwise.
Input {j} has shape 1, but is not statically "
f
"
Input {j} has shape 1, but is not statically "
"known to have shape 1, and thus not broadcastable."
"known to have shape 1, and thus not broadcastable."
)
)
ctx
.
call_conv
.
return_user_exc
(
builder
,
ValueError
,
(
msg
,))
ctx
.
call_conv
.
return_user_exc
(
builder
,
ValueError
,
(
msg
,)
)
with
otherwise
:
with
otherwise
:
msg
=
(
msg
=
f
"Vectorized input {j} has an incompatible shape in axis {i}."
f
"Input {j} to elemwise has an incompatible "
ctx
.
call_conv
.
return_user_exc
(
f
"shape in axis {i}."
builder
,
ValueError
,
(
msg
,)
)
)
ctx
.
call_conv
.
return_user_exc
(
builder
,
ValueError
,
(
msg
,))
else
:
else
:
shape
[
i
]
=
length
shape
[
i
]
=
length
for
i
in
range
(
ndim
):
for
i
in
range
(
batch_
ndim
):
if
shape
[
i
]
is
None
:
if
shape
[
i
]
is
None
:
shape
[
i
]
=
one
shape
[
i
]
=
one
return
shape
return
shape
...
@@ -237,27 +369,32 @@ def make_outputs(
...
@@ -237,27 +369,32 @@ def make_outputs(
inplace
:
tuple
[
tuple
[
int
,
int
],
...
],
inplace
:
tuple
[
tuple
[
int
,
int
],
...
],
inputs
:
tuple
[
Any
,
...
],
inputs
:
tuple
[
Any
,
...
],
input_types
:
tuple
[
Any
,
...
],
input_types
:
tuple
[
Any
,
...
],
):
output_core_shapes
:
tuple
,
arrays
=
[]
)
->
tuple
[
list
[
ir
.
Value
],
list
[
types
.
Array
]]:
ar_types
:
list
[
types
.
Array
]
=
[]
output_arrays
=
[]
output_arry_types
=
[]
one
=
ir
.
IntType
(
64
)(
1
)
one
=
ir
.
IntType
(
64
)(
1
)
inplace_dict
=
dict
(
inplace
)
inplace_dict
=
dict
(
inplace
)
for
i
,
(
bc
,
dtype
)
in
enumerate
(
zip
(
out_bc
,
dtypes
)):
for
i
,
(
core_shape
,
bc
,
dtype
)
in
enumerate
(
zip
(
output_core_shapes
,
out_bc
,
dtypes
)
):
if
i
in
inplace_dict
:
if
i
in
inplace_dict
:
arrays
.
append
(
inputs
[
inplace_dict
[
i
]])
output_
arrays
.
append
(
inputs
[
inplace_dict
[
i
]])
ar
_types
.
append
(
input_types
[
inplace_dict
[
i
]])
output_arry
_types
.
append
(
input_types
[
inplace_dict
[
i
]])
# We need to incref once we return the inplace objects
# We need to incref once we return the inplace objects
continue
continue
dtype
=
numba
.
from_dtype
(
np
.
dtype
(
dtype
))
dtype
=
numba
.
from_dtype
(
np
.
dtype
(
dtype
))
arrtype
=
types
.
Array
(
dtype
,
len
(
iter_shape
),
"C"
)
output_ndim
=
len
(
iter_shape
)
+
len
(
core_shape
)
ar_types
.
append
(
arrtype
)
arrtype
=
types
.
Array
(
dtype
,
output_ndim
,
"C"
)
output_arry_types
.
append
(
arrtype
)
# This is actually an internal numba function, I guess we could
# This is actually an internal numba function, I guess we could
# call `numba.nd.unsafe.ndarray` instead?
# call `numba.nd.unsafe.ndarray` instead?
shape
=
[
batch_
shape
=
[
length
if
not
bc_dim
else
one
for
length
,
bc_dim
in
zip
(
iter_shape
,
bc
)
length
if
not
bc_dim
else
one
for
length
,
bc_dim
in
zip
(
iter_shape
,
bc
)
]
]
shape
=
batch_shape
+
core_shape
array
=
arrayobj
.
_empty_nd_impl
(
ctx
,
builder
,
arrtype
,
shape
)
array
=
arrayobj
.
_empty_nd_impl
(
ctx
,
builder
,
arrtype
,
shape
)
arrays
.
append
(
array
)
output_
arrays
.
append
(
array
)
# If there is no inplace operation, we know that all output arrays
# If there is no inplace operation, we know that all output arrays
# don't alias. Informing llvm can make it easier to vectorize.
# don't alias. Informing llvm can make it easier to vectorize.
...
@@ -265,7 +402,7 @@ def make_outputs(
...
@@ -265,7 +402,7 @@ def make_outputs(
# The first argument is the output pointer
# The first argument is the output pointer
arg
=
builder
.
function
.
args
[
0
]
arg
=
builder
.
function
.
args
[
0
]
arg
.
add_attribute
(
"noalias"
)
arg
.
add_attribute
(
"noalias"
)
return
arrays
,
ar
_types
return
output_arrays
,
output_arry
_types
def
make_loop_call
(
def
make_loop_call
(
...
@@ -275,6 +412,7 @@ def make_loop_call(
...
@@ -275,6 +412,7 @@ def make_loop_call(
scalar_func
:
Any
,
scalar_func
:
Any
,
scalar_signature
:
types
.
FunctionType
,
scalar_signature
:
types
.
FunctionType
,
iter_shape
:
tuple
[
ir
.
Instruction
,
...
],
iter_shape
:
tuple
[
ir
.
Instruction
,
...
],
constant_inputs
:
tuple
[
ir
.
Instruction
,
...
],
inputs
:
tuple
[
ir
.
Instruction
,
...
],
inputs
:
tuple
[
ir
.
Instruction
,
...
],
outputs
:
tuple
[
ir
.
Instruction
,
...
],
outputs
:
tuple
[
ir
.
Instruction
,
...
],
input_bc
:
tuple
[
tuple
[
bool
,
...
],
...
],
input_bc
:
tuple
[
tuple
[
bool
,
...
],
...
],
...
@@ -283,18 +421,8 @@ def make_loop_call(
...
@@ -283,18 +421,8 @@ def make_loop_call(
output_types
:
tuple
[
Any
,
...
],
output_types
:
tuple
[
Any
,
...
],
):
):
safe
=
(
False
,
False
)
safe
=
(
False
,
False
)
n_outputs
=
len
(
outputs
)
# context.printf(builder, "iter shape: " + ', '.join(["%i"] * len(iter_shape)) + "\n", *iter_shape)
# Extract shape and stride information from the array.
n_outputs
=
len
(
outputs
)
# For later use in the loop body to do the indexing
def
extract_array
(
aryty
,
obj
):
shape
=
cgutils
.
unpack_tuple
(
builder
,
obj
.
shape
)
strides
=
cgutils
.
unpack_tuple
(
builder
,
obj
.
strides
)
data
=
obj
.
data
layout
=
aryty
.
layout
return
(
data
,
shape
,
strides
,
layout
)
# TODO I think this is better than the noalias attribute
# TODO I think this is better than the noalias attribute
# for the input, but self_ref isn't supported in a released
# for the input, but self_ref isn't supported in a released
...
@@ -306,12 +434,6 @@ def make_loop_call(
...
@@ -306,12 +434,6 @@ def make_loop_call(
# input_scope_set = mod.add_metadata([input_scope, output_scope])
# input_scope_set = mod.add_metadata([input_scope, output_scope])
# output_scope_set = mod.add_metadata([input_scope, output_scope])
# output_scope_set = mod.add_metadata([input_scope, output_scope])
inputs
=
tuple
(
extract_array
(
aryty
,
ary
)
for
aryty
,
ary
in
zip
(
input_types
,
inputs
))
outputs
=
tuple
(
extract_array
(
aryty
,
ary
)
for
aryty
,
ary
in
zip
(
output_types
,
outputs
)
)
zero
=
ir
.
Constant
(
ir
.
IntType
(
64
),
0
)
zero
=
ir
.
Constant
(
ir
.
IntType
(
64
),
0
)
# Setup loops and initialize accumulators for outputs
# Setup loops and initialize accumulators for outputs
...
@@ -338,69 +460,105 @@ def make_loop_call(
...
@@ -338,69 +460,105 @@ def make_loop_call(
# Load values from input arrays
# Load values from input arrays
input_vals
=
[]
input_vals
=
[]
for
array_info
,
bc
in
zip
(
inputs
,
input_bc
):
for
input
,
input_type
,
bc
in
zip
(
inputs
,
input_types
,
input_bc
):
idxs_bc
=
[
zero
if
bc
else
idx
for
idx
,
bc
in
zip
(
idxs
,
bc
)]
core_ndim
=
input_type
.
ndim
-
len
(
bc
)
ptr
=
cgutils
.
get_item_pointer2
(
context
,
builder
,
*
array_info
,
idxs_bc
,
*
safe
)
idxs_bc
=
[
zero
if
bc
else
idx
for
idx
,
bc
in
zip
(
idxs
,
bc
)]
+
[
zero
]
*
core_ndim
ptr
=
cgutils
.
get_item_pointer2
(
context
,
builder
,
input
.
data
,
cgutils
.
unpack_tuple
(
builder
,
input
.
shape
),
cgutils
.
unpack_tuple
(
builder
,
input
.
strides
),
input_type
.
layout
,
idxs_bc
,
*
safe
,
)
if
core_ndim
==
0
:
# Retrive scalar item at index
val
=
builder
.
load
(
ptr
)
val
=
builder
.
load
(
ptr
)
# val.set_metadata("alias.scope", input_scope_set)
# val.set_metadata("alias.scope", input_scope_set)
# val.set_metadata("noalias", output_scope_set)
# val.set_metadata("noalias", output_scope_set)
else
:
# Retrieve array item at index
# This is a streamlined version of Numba's `GUArrayArg.load`
# TODO check layout arg!
core_arry_type
=
types
.
Array
(
dtype
=
input_type
.
dtype
,
ndim
=
core_ndim
,
layout
=
input_type
.
layout
)
core_array
=
context
.
make_array
(
core_arry_type
)(
context
,
builder
)
core_shape
=
cgutils
.
unpack_tuple
(
builder
,
input
.
shape
)[
-
core_ndim
:]
core_strides
=
cgutils
.
unpack_tuple
(
builder
,
input
.
strides
)[
-
core_ndim
:]
itemsize
=
context
.
get_abi_sizeof
(
context
.
get_data_type
(
input_type
.
dtype
))
context
.
populate_array
(
core_array
,
# TODO whey do we need to bitcast?
data
=
builder
.
bitcast
(
ptr
,
core_array
.
data
.
type
),
shape
=
cgutils
.
pack_array
(
builder
,
core_shape
),
strides
=
cgutils
.
pack_array
(
builder
,
core_strides
),
itemsize
=
context
.
get_constant
(
types
.
intp
,
itemsize
),
# TODO what is meminfo about?
meminfo
=
None
,
)
val
=
core_array
.
_getvalue
()
input_vals
.
append
(
val
)
input_vals
.
append
(
val
)
# Create output slices to pass to inner func
output_slices
=
[]
for
output
,
output_type
,
bc
in
zip
(
outputs
,
output_types
,
output_bc
):
core_ndim
=
output_type
.
ndim
-
len
(
bc
)
size_type
=
output
.
shape
.
type
.
element
# type: ignore
output_shape
=
cgutils
.
unpack_tuple
(
builder
,
output
.
shape
)
# type: ignore
output_strides
=
cgutils
.
unpack_tuple
(
builder
,
output
.
strides
)
# type: ignore
idxs_bc
=
[
zero
if
bc
else
idx
for
idx
,
bc
in
zip
(
idxs
,
bc
)]
+
[
zero
]
*
core_ndim
ptr
=
cgutils
.
get_item_pointer2
(
context
,
builder
,
output
.
data
,
# type:ignore
output_shape
,
output_strides
,
output_type
.
layout
,
idxs_bc
,
*
safe
,
)
# Retrieve array item at index
# This is a streamlined version of Numba's `GUArrayArg.load`
core_arry_type
=
types
.
Array
(
dtype
=
output_type
.
dtype
,
ndim
=
core_ndim
,
layout
=
output_type
.
layout
)
core_array
=
context
.
make_array
(
core_arry_type
)(
context
,
builder
)
core_shape
=
output_shape
[
-
core_ndim
:]
if
core_ndim
>
0
else
[]
core_strides
=
output_strides
[
-
core_ndim
:]
if
core_ndim
>
0
else
[]
itemsize
=
context
.
get_abi_sizeof
(
context
.
get_data_type
(
output_type
.
dtype
))
context
.
populate_array
(
core_array
,
# TODO whey do we need to bitcast?
data
=
builder
.
bitcast
(
ptr
,
core_array
.
data
.
type
),
shape
=
cgutils
.
pack_array
(
builder
,
core_shape
,
ty
=
size_type
),
strides
=
cgutils
.
pack_array
(
builder
,
core_strides
,
ty
=
size_type
),
itemsize
=
context
.
get_constant
(
types
.
intp
,
itemsize
),
# TODO what is meminfo about?
meminfo
=
None
,
)
val
=
core_array
.
_getvalue
()
output_slices
.
append
(
val
)
inner_codegen
=
context
.
get_function
(
scalar_func
,
scalar_signature
)
inner_codegen
=
context
.
get_function
(
scalar_func
,
scalar_signature
)
if
isinstance
(
scalar_signature
.
args
[
0
],
types
.
StarArgTuple
|
types
.
StarArgUniTuple
):
if
isinstance
(
scalar_signature
.
args
[
0
],
types
.
StarArgTuple
|
types
.
StarArgUniTuple
):
input_vals
=
[
context
.
make_tuple
(
builder
,
scalar_signature
.
args
[
0
],
input_vals
)]
input_vals
=
[
context
.
make_tuple
(
builder
,
scalar_signature
.
args
[
0
],
input_vals
)]
output_values
=
inner_codegen
(
builder
,
input_vals
)
if
isinstance
(
scalar_signature
.
return_type
,
types
.
Tuple
|
types
.
UniTuple
):
inner_codegen
(
builder
,
[
*
constant_inputs
,
*
input_vals
,
*
output_slices
])
output_values
=
cgutils
.
unpack_tuple
(
builder
,
output_values
)
func_output_types
=
scalar_signature
.
return_type
.
types
else
:
output_values
=
[
output_values
]
func_output_types
=
[
scalar_signature
.
return_type
]
# Update output value or accumulators respectively
for
i
,
((
accu
,
_
),
value
)
in
enumerate
(
zip
(
output_accumulator
,
output_values
)):
if
accu
is
not
None
:
load
=
builder
.
load
(
accu
)
# load.set_metadata("alias.scope", output_scope_set)
# load.set_metadata("noalias", input_scope_set)
new_value
=
builder
.
fadd
(
load
,
value
)
builder
.
store
(
new_value
,
accu
)
# TODO belongs to noalias scope
# store.set_metadata("alias.scope", output_scope_set)
# store.set_metadata("noalias", input_scope_set)
else
:
idxs_bc
=
[
zero
if
bc
else
idx
for
idx
,
bc
in
zip
(
idxs
,
output_bc
[
i
])]
ptr
=
cgutils
.
get_item_pointer2
(
context
,
builder
,
*
outputs
[
i
],
idxs_bc
)
# store = builder.store(value, ptr)
value
=
context
.
cast
(
builder
,
value
,
func_output_types
[
i
],
output_types
[
i
]
.
dtype
)
arrayobj
.
store_item
(
context
,
builder
,
output_types
[
i
],
value
,
ptr
)
# store.set_metadata("alias.scope", output_scope_set)
# store.set_metadata("noalias", input_scope_set)
# Close the loops
and write accumulator values to the output arrays
# Close the loops
for
depth
,
loop
in
enumerate
(
loop_stack
[::
-
1
]):
for
depth
,
loop
in
enumerate
(
loop_stack
[::
-
1
]):
for
output
,
(
accu
,
accu_depth
)
in
enumerate
(
output_accumulator
):
if
accu_depth
==
depth
:
idxs_bc
=
[
zero
if
bc
else
idx
for
idx
,
bc
in
zip
(
idxs
,
output_bc
[
output
])
]
ptr
=
cgutils
.
get_item_pointer2
(
context
,
builder
,
*
outputs
[
output
],
idxs_bc
)
load
=
builder
.
load
(
accu
)
# load.set_metadata("alias.scope", output_scope_set)
# load.set_metadata("noalias", input_scope_set)
# store = builder.store(load, ptr)
load
=
context
.
cast
(
builder
,
load
,
func_output_types
[
output
],
output_types
[
output
]
.
dtype
)
arrayobj
.
store_item
(
context
,
builder
,
output_types
[
output
],
load
,
ptr
)
# store.set_metadata("alias.scope", output_scope_set)
# store.set_metadata("noalias", input_scope_set)
loop
.
__exit__
(
None
,
None
,
None
)
loop
.
__exit__
(
None
,
None
,
None
)
return
return
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论