testgroup / pytensor · Commits

Commit c7152952
Authored Oct 04, 2024 by Ricardo Vieira
Committed by Ricardo Vieira on Dec 03, 2024
Add Numba implementation of Blockwise
Parent: 18ba52cd

Showing 9 changed files with 272 additions and 7 deletions
pytensor/link/numba/dispatch/__init__.py   +5 -4
pytensor/link/numba/dispatch/blockwise.py  +88 -0
pytensor/link/numba/dispatch/random.py     +1 -1
pytensor/tensor/blockwise.py               +8 -0
pytensor/tensor/random/rewriting/numba.py  +1 -1
pytensor/tensor/rewriting/__init__.py      +1 -0
pytensor/tensor/rewriting/numba.py         +108 -0
tests/link/numba/test_basic.py             +1 -1
tests/link/numba/test_blockwise.py         +59 -0
pytensor/link/numba/dispatch/__init__.py

```diff
@@ -2,15 +2,16 @@
 from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify

 # Load dispatch specializations
-import pytensor.link.numba.dispatch.scalar
-import pytensor.link.numba.dispatch.tensor_basic
+import pytensor.link.numba.dispatch.blockwise
+import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.extra_ops
 import pytensor.link.numba.dispatch.nlinalg
 import pytensor.link.numba.dispatch.random
-import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.scan
-import pytensor.link.numba.dispatch.sparse
+import pytensor.link.numba.dispatch.scalar
 import pytensor.link.numba.dispatch.slinalg
+import pytensor.link.numba.dispatch.sparse
 import pytensor.link.numba.dispatch.subtensor
+import pytensor.link.numba.dispatch.tensor_basic

 # isort: on
```
pytensor/link/numba/dispatch/blockwise.py (new file)

```python
from typing import cast

from numba.core.extending import overload
from numba.np.unsafe.ndarray import to_fixed_tuple

from pytensor.link.numba.dispatch.basic import numba_funcify, numba_njit
from pytensor.link.numba.dispatch.vectorize_codegen import (
    _jit_options,
    _vectorized,
    encode_literals,
    store_core_outputs,
)
from pytensor.link.utils import compile_function_src
from pytensor.tensor import TensorVariable, get_vector_length
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape


@numba_funcify.register
def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
    [blockwise_node] = op.fgraph.apply_nodes
    blockwise_op: Blockwise = blockwise_node.op
    core_op = blockwise_op.core_op
    nin = len(blockwise_node.inputs)
    nout = len(blockwise_node.outputs)
    core_shapes_len = tuple(get_vector_length(sh) for sh in node.inputs[nin:])

    core_node = blockwise_op._create_dummy_core_node(
        cast(tuple[TensorVariable], blockwise_node.inputs)
    )
    core_op_fn = numba_funcify(
        core_op,
        node=core_node,
        parent_node=node,
        fastmath=_jit_options["fastmath"],
        **kwargs,
    )
    core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)

    batch_ndim = blockwise_op.batch_ndim(node)

    # numba doesn't support nested literals right now...
    input_bc_patterns = encode_literals(
        tuple(inp.type.broadcastable[:batch_ndim] for inp in node.inputs[:nin])
    )
    output_bc_patterns = encode_literals(
        tuple(out.type.broadcastable[:batch_ndim] for out in node.outputs)
    )
    output_dtypes = encode_literals(tuple(out.type.dtype for out in node.outputs))
    inplace_pattern = encode_literals(())

    # Numba does not allow a tuple generator in the Jitted function so we have to compile a helper to convert core_shapes into tuples
    # Alternatively, add an Op that converts shape vectors into tuples, like we did for JAX
    src = "def to_tuple(core_shapes): return ("
    for i in range(nout):
        src += f"to_fixed_tuple(core_shapes[{i}], {core_shapes_len[i]}),"
    src += ")"

    to_tuple = numba_njit(
        compile_function_src(
            src,
            "to_tuple",
            global_env={"to_fixed_tuple": to_fixed_tuple},
        )
    )

    def blockwise_wrapper(*inputs_and_core_shapes):
        inputs, core_shapes = (
            inputs_and_core_shapes[:nin],
            inputs_and_core_shapes[nin:],
        )
        tuple_core_shapes = to_tuple(core_shapes)
        return _vectorized(
            core_op_fn,
            input_bc_patterns,
            output_bc_patterns,
            output_dtypes,
            inplace_pattern,
            (),  # constant_inputs
            inputs,
            tuple_core_shapes,
            None,  # size
        )

    def blockwise(*inputs_and_core_shapes):
        raise NotImplementedError("Non-jitted BlockwiseWithCoreShape not implemented")

    @overload(blockwise, jit_options=_jit_options)
    def ov_blockwise(*inputs_and_core_shapes):
        return blockwise_wrapper

    return blockwise
```
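The string codegen above is deterministic, so the compiled helper is easy to picture. For a hypothetical op with `nout = 2` and `core_shapes_len = (2, 1)` (illustrative values, not from the diff), the generated `src` would be this one-liner, where `to_fixed_tuple` converts each core-shape vector into a tuple whose length is known at compile time:

```python
def to_tuple(core_shapes): return (to_fixed_tuple(core_shapes[0], 2),to_fixed_tuple(core_shapes[1], 1),)
```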
pytensor/link/numba/dispatch/random.py

```diff
@@ -388,7 +388,7 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
         return rng, draws

     def random(core_shape, rng, size, *dist_params):
-        pass
+        raise NotImplementedError("Non-jitted random variable not implemented")

     @overload(random, jit_options=_jit_options)
     def ov_random(core_shape, rng, size, *dist_params):
```
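Both this file and blockwise.py above rely on the same Numba extension pattern: a plain-Python stub that raises, plus an `@overload` that supplies the implementation used inside jitted code. A self-contained sketch of that pattern with an illustrative function (not from this diff):

```python
from numba import njit
from numba.core.extending import overload


def add_one(x):
    # Only reachable from non-jitted code; jitted callers get the overload below.
    raise NotImplementedError("Non-jitted version not implemented")


@overload(add_one)
def ov_add_one(x):
    # Numba calls this at compile time and compiles the returned implementation.
    def impl(x):
        return x + 1

    return impl


@njit
def caller(x):
    return add_one(x)


assert caller(41) == 42
```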
pytensor/tensor/blockwise.py

```diff
@@ -442,3 +442,11 @@ _vectorize_node.register(Blockwise, _vectorize_not_needed)

 class OpWithCoreShape(OpFromGraph):
     """Generalizes an `Op` to include core shape as an additional input."""
+
+
+class BlockwiseWithCoreShape(OpWithCoreShape):
+    """Generalizes a Blockwise `Op` to include a core shape parameter."""
+
+    def __str__(self):
+        [blockwise_node] = self.fgraph.apply_nodes
+        return f"[{blockwise_node.op!s}]"
```
pytensor/tensor/random/rewriting/numba.py

```diff
@@ -15,7 +15,7 @@ def introduce_explicit_core_shape_rv(fgraph, node):
     This core_shape is used by the numba backend to pre-allocate the output array.

     If available, the core shape is extracted from the shape feature of the graph,
-    which has a higher change of having been simplified, optimized, constant-folded.
+    which has a higher chance of having been simplified, optimized, constant-folded.
     If missing, we fall back to the op._supp_shape_from_params method.

     This rewrite is required for the numba backend implementation of RandomVariable.
```
pytensor/tensor/rewriting/__init__.py

```diff
@@ -9,6 +9,7 @@ import pytensor.tensor.rewriting.extra_ops
 import pytensor.tensor.rewriting.jax
 import pytensor.tensor.rewriting.linalg
 import pytensor.tensor.rewriting.math
+import pytensor.tensor.rewriting.numba
 import pytensor.tensor.rewriting.ofg
 import pytensor.tensor.rewriting.shape
 import pytensor.tensor.rewriting.special
```
pytensor/tensor/rewriting/numba.py (new file)

```python
from pytensor.compile import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.basic import applys_between
from pytensor.graph.rewriting.basic import out2in
from pytensor.tensor.basic import as_tensor, constant
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.rewriting.shape import ShapeFeature


@node_rewriter([Blockwise])
def introduce_explicit_core_shape_blockwise(fgraph, node):
    """Introduce the core shape of a Blockwise.

    We wrap Blockwise graphs into a BlockwiseWithCoreShape OpFromGraph
    that has an extra "non-functional" input that represents the core shape of the Blockwise variable.
    This core_shape is used by the numba backend to pre-allocate the output array.

    If available, the core shape is extracted from the shape feature of the graph,
    which has a higher chance of having been simplified, optimized, constant-folded.
    If missing, we fall back to the op._supp_shape_from_params method.

    This rewrite is required for the numba backend implementation of Blockwise.

    Example
    -------

    .. code-block:: python

        import pytensor
        import pytensor.tensor as pt

        x = pt.tensor("x", shape=(5, None, None))
        outs = pt.linalg.svd(x, compute_uv=True)
        pytensor.dprint(outs)
        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.0 [id A]
        # └─ x [id B]
        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.1 [id A]
        # └─ ···
        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.2 [id A]
        # └─ ···

        # After the rewrite, note the new 3 core shape inputs
        fn = pytensor.function([x], outs, mode="NUMBA")
        fn.dprint(print_type=False)
        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].0 [id A] 6
        #  ├─ x [id B]
        #  ├─ MakeVector{dtype='int64'} [id C] 5
        #  │  ├─ Shape_i{1} [id D] 2
        #  │  │  └─ x [id B]
        #  │  └─ Shape_i{1} [id D] 2
        #  │     └─ ···
        #  ├─ MakeVector{dtype='int64'} [id E] 4
        #  │  └─ Minimum [id F] 3
        #  │     ├─ Shape_i{1} [id D] 2
        #  │     │  └─ ···
        #  │     └─ Shape_i{2} [id G] 0
        #  │        └─ x [id B]
        #  └─ MakeVector{dtype='int64'} [id H] 1
        #     ├─ Shape_i{2} [id G] 0
        #     │  └─ ···
        #     └─ Shape_i{2} [id G] 0
        #        └─ ···
        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].1 [id A] 6
        #  └─ ···
        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6
        #  └─ ···
    """
    op: Blockwise = node.op  # type: ignore[annotation-unchecked]
    batch_ndim = op.batch_ndim(node)

    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
    if shape_feature:
        core_shapes = [
            [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)]
            for out in node.outputs
        ]
    else:
        input_shapes = [tuple(inp.shape) for inp in node.inputs]
        core_shapes = [
            out_shape[batch_ndim:]
            for out_shape in op.infer_shape(None, node, input_shapes)
        ]

    core_shapes = [
        as_tensor(core_shape) if len(core_shape) else constant([], dtype="int64")
        for core_shape in core_shapes
    ]

    if any(
        isinstance(node.op, Blockwise)
        for node in applys_between(node.inputs, core_shapes)
    ):
        # If Blockwise shows up in the shape graph we can't introduce the core shape
        return None

    return BlockwiseWithCoreShape(
        [*node.inputs, *core_shapes],
        node.outputs,
        destroy_map=op.destroy_map,
    )(*node.inputs, *core_shapes, return_list=True)


optdb.register(
    introduce_explicit_core_shape_blockwise.__name__,
    out2in(introduce_explicit_core_shape_blockwise),
    "numba",
    position=100,
)
```
tests/link/numba/test_basic.py

```diff
@@ -244,7 +244,7 @@ def compare_numba_and_py(
     Parameters
     ----------
     fgraph
-        `FunctionGraph` or inputs to compare.
+        `FunctionGraph` or tuple(inputs, outputs) to compare.
     inputs
         Numeric inputs to be passed to the compiled graphs.
     assert_fn
```
tests/link/numba/test_blockwise.py (new file)

```python
import numpy as np
import pytest

from pytensor import function
from pytensor.tensor import tensor
from pytensor.tensor.basic import ARange
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import SVD, Det
from pytensor.tensor.slinalg import Cholesky, cholesky
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode


# Fails if object mode warning is issued when not expected
pytestmark = pytest.mark.filterwarnings("error")


@pytest.mark.parametrize("shape_opt", [True, False], ids=str)
@pytest.mark.parametrize("core_op", [Det(), Cholesky(), SVD(compute_uv=True)], ids=str)
def test_blockwise(core_op, shape_opt):
    x = tensor(shape=(5, None, None))
    outs = Blockwise(core_op=core_op)(x, return_list=True)

    mode = (
        numba_mode.including("ShapeOpt")
        if shape_opt
        else numba_mode.excluding("ShapeOpt")
    )
    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
    compare_numba_and_py(
        ([x], outs),
        [x_test],
        numba_mode=mode,
        eval_obj_mode=False,
    )


def test_non_square_blockwise():
    """Test that Op that cannot always be blockwised at runtime fails gracefully."""
    x = tensor(shape=(3,), dtype="int64")
    out = Blockwise(core_op=ARange(dtype="int64"), signature="(),(),()->(a)")(0, x, 1)

    with pytest.warns(UserWarning, match="Numba will use object mode"):
        fn = function([x], out, mode="NUMBA")

    np.testing.assert_allclose(fn([5, 5, 5]), np.broadcast_to(np.arange(5), (3, 5)))

    with pytest.raises(ValueError):
        fn([3, 4, 5])


def test_blockwise_benchmark(benchmark):
    x = tensor(shape=(5, 3, 3))
    out = cholesky(x)
    assert isinstance(out.owner.op, Blockwise)

    fn = function([x], out, mode="NUMBA")
    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
    fn(x_test)  # JIT compile
    benchmark(fn, x_test)
```