Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c0a4276d
提交
c0a4276d
authored
4月 23, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move vectorize wrapper to vectorize_codegen
上级
89e9bd6b
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
171 行增加
和
183 行删除
+171
-183
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+7
-182
vectorize_codegen.py
pytensor/link/numba/dispatch/vectorize_codegen.py
+164
-1
没有找到文件。
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
c0a4276d
...
...
@@ -8,17 +8,13 @@ from typing import Any
import
numba
import
numpy
as
np
from
numba
import
TypingError
,
types
from
numba.core
import
cgutils
from
numba.core.extending
import
overload
from
numba.np
import
arrayobj
from
numpy.core.numeric
import
normalize_axis_index
,
normalize_axis_tuple
from
pytensor
import
config
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
vectorize_codegen
from
pytensor.link.numba.dispatch.basic
import
(
create_numba_signature
,
create_tuple_creator
,
...
...
@@ -26,6 +22,7 @@ from pytensor.link.numba.dispatch.basic import (
numba_njit
,
use_optimized_cheap_pass
,
)
from
pytensor.link.numba.dispatch.vectorize_codegen
import
_jit_options
,
_vectorized
from
pytensor.link.utils
import
compile_function_src
,
get_name_for_object
from
pytensor.scalar.basic
import
(
AND
,
...
...
@@ -463,167 +460,6 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
return
axis_apply_fn
_jit_options
=
{
"fastmath"
:
{
"arcp"
,
# Allow Reciprocal
"contract"
,
# Allow floating-point contraction
"afn"
,
# Approximate functions
"reassoc"
,
"nsz"
,
# TODO Do we want this one?
},
"no_cpython_wrapper"
:
True
,
"no_cfunc_wrapper"
:
True
,
}
@numba.extending.intrinsic
(
jit_options
=
_jit_options
,
prefer_literal
=
True
)
def
_vectorized
(
typingctx
,
scalar_func
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
inplace_pattern
,
inputs
,
):
arg_types
=
[
scalar_func
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
inplace_pattern
,
inputs
,
]
if
not
isinstance
(
input_bc_patterns
,
types
.
Literal
):
raise
TypingError
(
"input_bc_patterns must be literal."
)
input_bc_patterns
=
input_bc_patterns
.
literal_value
input_bc_patterns
=
pickle
.
loads
(
base64
.
decodebytes
(
input_bc_patterns
.
encode
()))
if
not
isinstance
(
output_bc_patterns
,
types
.
Literal
):
raise
TypeError
(
"output_bc_patterns must be literal."
)
output_bc_patterns
=
output_bc_patterns
.
literal_value
output_bc_patterns
=
pickle
.
loads
(
base64
.
decodebytes
(
output_bc_patterns
.
encode
()))
if
not
isinstance
(
output_dtypes
,
types
.
Literal
):
raise
TypeError
(
"output_dtypes must be literal."
)
output_dtypes
=
output_dtypes
.
literal_value
output_dtypes
=
pickle
.
loads
(
base64
.
decodebytes
(
output_dtypes
.
encode
()))
if
not
isinstance
(
inplace_pattern
,
types
.
Literal
):
raise
TypeError
(
"inplace_pattern must be literal."
)
inplace_pattern
=
inplace_pattern
.
literal_value
inplace_pattern
=
pickle
.
loads
(
base64
.
decodebytes
(
inplace_pattern
.
encode
()))
n_outputs
=
len
(
output_bc_patterns
)
if
not
len
(
inputs
)
>
0
:
raise
TypingError
(
"Empty argument list to elemwise op."
)
if
not
n_outputs
>
0
:
raise
TypingError
(
"Empty list of outputs for elemwise op."
)
if
not
all
(
isinstance
(
input
,
types
.
Array
)
for
input
in
inputs
):
raise
TypingError
(
"Inputs to elemwise must be arrays."
)
ndim
=
inputs
[
0
]
.
ndim
if
not
all
(
input
.
ndim
==
ndim
for
input
in
inputs
):
raise
TypingError
(
"Inputs to elemwise must have the same rank."
)
if
not
all
(
len
(
pattern
)
==
ndim
for
pattern
in
output_bc_patterns
):
raise
TypingError
(
"Invalid output broadcasting pattern."
)
scalar_signature
=
typingctx
.
resolve_function_type
(
scalar_func
,
[
in_type
.
dtype
for
in_type
in
inputs
],
{}
)
# So we can access the constant values in codegen...
input_bc_patterns_val
=
input_bc_patterns
output_bc_patterns_val
=
output_bc_patterns
output_dtypes_val
=
output_dtypes
inplace_pattern_val
=
inplace_pattern
input_types
=
inputs
def
codegen
(
ctx
,
builder
,
sig
,
args
,
):
[
_
,
_
,
_
,
_
,
_
,
inputs
]
=
args
inputs
=
cgutils
.
unpack_tuple
(
builder
,
inputs
)
inputs
=
[
arrayobj
.
make_array
(
ty
)(
ctx
,
builder
,
val
)
for
ty
,
val
in
zip
(
input_types
,
inputs
)
]
in_shapes
=
[
cgutils
.
unpack_tuple
(
builder
,
obj
.
shape
)
for
obj
in
inputs
]
iter_shape
=
vectorize_codegen
.
compute_itershape
(
ctx
,
builder
,
in_shapes
,
input_bc_patterns_val
,
)
outputs
,
output_types
=
vectorize_codegen
.
make_outputs
(
ctx
,
builder
,
iter_shape
,
output_bc_patterns_val
,
output_dtypes_val
,
inplace_pattern_val
,
inputs
,
input_types
,
)
vectorize_codegen
.
make_loop_call
(
typingctx
,
ctx
,
builder
,
scalar_func
,
scalar_signature
,
iter_shape
,
inputs
,
outputs
,
input_bc_patterns_val
,
output_bc_patterns_val
,
input_types
,
output_types
,
)
if
len
(
outputs
)
==
1
:
if
inplace_pattern
:
assert
inplace_pattern
[
0
][
0
]
==
0
ctx
.
nrt
.
incref
(
builder
,
sig
.
return_type
,
outputs
[
0
]
.
_getvalue
())
return
outputs
[
0
]
.
_getvalue
()
for
inplace_idx
in
dict
(
inplace_pattern
):
ctx
.
nrt
.
incref
(
builder
,
sig
.
return_type
.
types
[
inplace_idx
],
outputs
[
inplace_idx
]
.
_get_value
(),
)
return
ctx
.
make_tuple
(
builder
,
sig
.
return_type
,
[
out
.
_getvalue
()
for
out
in
outputs
]
)
ret_types
=
[
types
.
Array
(
numba
.
from_dtype
(
np
.
dtype
(
dtype
)),
ndim
,
"C"
)
for
dtype
in
output_dtypes
]
for
output_idx
,
input_idx
in
inplace_pattern
:
ret_types
[
output_idx
]
=
input_types
[
input_idx
]
ret_type
=
types
.
Tuple
(
ret_types
)
if
len
(
output_dtypes
)
==
1
:
ret_type
=
ret_type
.
types
[
0
]
sig
=
ret_type
(
*
arg_types
)
return
sig
,
codegen
@numba_funcify.register
(
Elemwise
)
def
numba_funcify_Elemwise
(
op
,
node
,
**
kwargs
):
# Creating a new scalar node is more involved and unnecessary
...
...
@@ -634,16 +470,12 @@ def numba_funcify_Elemwise(op, node, **kwargs):
scalar_inputs
=
[
scalar
(
dtype
=
input
.
dtype
)
for
input
in
node
.
inputs
]
scalar_node
=
op
.
scalar_op
.
make_node
(
*
scalar_inputs
)
flags
=
{
"arcp"
,
# Allow Reciprocal
"contract"
,
# Allow floating-point contraction
"afn"
,
# Approximate functions
"reassoc"
,
"nsz"
,
# TODO Do we want this one?
}
scalar_op_fn
=
numba_funcify
(
op
.
scalar_op
,
node
=
scalar_node
,
parent_node
=
node
,
fastmath
=
flags
,
**
kwargs
op
.
scalar_op
,
node
=
scalar_node
,
parent_node
=
node
,
fastmath
=
_jit_options
[
"fastmath"
],
**
kwargs
,
)
ndim
=
node
.
outputs
[
0
]
.
ndim
...
...
@@ -700,14 +532,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
return
tuple
(
outputs_summed
)
return
outputs_summed
[
0
]
@overload
(
elemwise
,
jit_options
=
{
"fastmath"
:
flags
,
"no_cpython_wrapper"
:
True
,
"no_cfunc_wrapper"
:
True
,
},
)
@overload
(
elemwise
,
jit_options
=
_jit_options
)
def
ov_elemwise
(
*
inputs
):
return
elemwise_wrapper
...
...
pytensor/link/numba/dispatch/vectorize_codegen.py
浏览文件 @
c0a4276d
from
__future__
import
annotations
import
base64
import
pickle
from
typing
import
Any
import
numba
import
numpy
as
np
from
llvmlite
import
ir
from
numba
import
types
from
numba
import
TypingError
,
types
from
numba.core
import
cgutils
from
numba.core.base
import
BaseContext
from
numba.np
import
arrayobj
_jit_options
=
{
"fastmath"
:
{
"arcp"
,
# Allow Reciprocal
"contract"
,
# Allow floating-point contraction
"afn"
,
# Approximate functions
"reassoc"
,
"nsz"
,
# TODO Do we want this one?
},
"no_cpython_wrapper"
:
True
,
"no_cfunc_wrapper"
:
True
,
}
@numba.extending.intrinsic
(
jit_options
=
_jit_options
,
prefer_literal
=
True
)
def
_vectorized
(
typingctx
,
scalar_func
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
inplace_pattern
,
inputs
,
):
arg_types
=
[
scalar_func
,
input_bc_patterns
,
output_bc_patterns
,
output_dtypes
,
inplace_pattern
,
inputs
,
]
if
not
isinstance
(
input_bc_patterns
,
types
.
Literal
):
raise
TypingError
(
"input_bc_patterns must be literal."
)
input_bc_patterns
=
input_bc_patterns
.
literal_value
input_bc_patterns
=
pickle
.
loads
(
base64
.
decodebytes
(
input_bc_patterns
.
encode
()))
if
not
isinstance
(
output_bc_patterns
,
types
.
Literal
):
raise
TypeError
(
"output_bc_patterns must be literal."
)
output_bc_patterns
=
output_bc_patterns
.
literal_value
output_bc_patterns
=
pickle
.
loads
(
base64
.
decodebytes
(
output_bc_patterns
.
encode
()))
if
not
isinstance
(
output_dtypes
,
types
.
Literal
):
raise
TypeError
(
"output_dtypes must be literal."
)
output_dtypes
=
output_dtypes
.
literal_value
output_dtypes
=
pickle
.
loads
(
base64
.
decodebytes
(
output_dtypes
.
encode
()))
if
not
isinstance
(
inplace_pattern
,
types
.
Literal
):
raise
TypeError
(
"inplace_pattern must be literal."
)
inplace_pattern
=
inplace_pattern
.
literal_value
inplace_pattern
=
pickle
.
loads
(
base64
.
decodebytes
(
inplace_pattern
.
encode
()))
n_outputs
=
len
(
output_bc_patterns
)
if
not
len
(
inputs
)
>
0
:
raise
TypingError
(
"Empty argument list to elemwise op."
)
if
not
n_outputs
>
0
:
raise
TypingError
(
"Empty list of outputs for elemwise op."
)
if
not
all
(
isinstance
(
input
,
types
.
Array
)
for
input
in
inputs
):
raise
TypingError
(
"Inputs to elemwise must be arrays."
)
ndim
=
inputs
[
0
]
.
ndim
if
not
all
(
input
.
ndim
==
ndim
for
input
in
inputs
):
raise
TypingError
(
"Inputs to elemwise must have the same rank."
)
if
not
all
(
len
(
pattern
)
==
ndim
for
pattern
in
output_bc_patterns
):
raise
TypingError
(
"Invalid output broadcasting pattern."
)
scalar_signature
=
typingctx
.
resolve_function_type
(
scalar_func
,
[
in_type
.
dtype
for
in_type
in
inputs
],
{}
)
# So we can access the constant values in codegen...
input_bc_patterns_val
=
input_bc_patterns
output_bc_patterns_val
=
output_bc_patterns
output_dtypes_val
=
output_dtypes
inplace_pattern_val
=
inplace_pattern
input_types
=
inputs
def
codegen
(
ctx
,
builder
,
sig
,
args
,
):
[
_
,
_
,
_
,
_
,
_
,
inputs
]
=
args
inputs
=
cgutils
.
unpack_tuple
(
builder
,
inputs
)
inputs
=
[
arrayobj
.
make_array
(
ty
)(
ctx
,
builder
,
val
)
for
ty
,
val
in
zip
(
input_types
,
inputs
)
]
in_shapes
=
[
cgutils
.
unpack_tuple
(
builder
,
obj
.
shape
)
for
obj
in
inputs
]
iter_shape
=
compute_itershape
(
ctx
,
builder
,
in_shapes
,
input_bc_patterns_val
,
)
outputs
,
output_types
=
make_outputs
(
ctx
,
builder
,
iter_shape
,
output_bc_patterns_val
,
output_dtypes_val
,
inplace_pattern_val
,
inputs
,
input_types
,
)
make_loop_call
(
typingctx
,
ctx
,
builder
,
scalar_func
,
scalar_signature
,
iter_shape
,
inputs
,
outputs
,
input_bc_patterns_val
,
output_bc_patterns_val
,
input_types
,
output_types
,
)
if
len
(
outputs
)
==
1
:
if
inplace_pattern
:
assert
inplace_pattern
[
0
][
0
]
==
0
ctx
.
nrt
.
incref
(
builder
,
sig
.
return_type
,
outputs
[
0
]
.
_getvalue
())
return
outputs
[
0
]
.
_getvalue
()
for
inplace_idx
in
dict
(
inplace_pattern
):
ctx
.
nrt
.
incref
(
builder
,
sig
.
return_type
.
types
[
inplace_idx
],
outputs
[
inplace_idx
]
.
_get_value
(),
)
return
ctx
.
make_tuple
(
builder
,
sig
.
return_type
,
[
out
.
_getvalue
()
for
out
in
outputs
]
)
ret_types
=
[
types
.
Array
(
numba
.
from_dtype
(
np
.
dtype
(
dtype
)),
ndim
,
"C"
)
for
dtype
in
output_dtypes
]
for
output_idx
,
input_idx
in
inplace_pattern
:
ret_types
[
output_idx
]
=
input_types
[
input_idx
]
ret_type
=
types
.
Tuple
(
ret_types
)
if
len
(
output_dtypes
)
==
1
:
ret_type
=
ret_type
.
types
[
0
]
sig
=
ret_type
(
*
arg_types
)
return
sig
,
codegen
def
compute_itershape
(
ctx
:
BaseContext
,
builder
:
ir
.
IRBuilder
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论