Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a62e785d
提交
a62e785d
authored
7月 05, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
7月 08, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow Blockwise to create dummy core nodes with outer inputs, if these are unbatched
上级
efc9d693
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
105 行增加
和
56 行删除
+105
-56
blockwise.py
pytensor/link/jax/dispatch/blockwise.py
+5
-13
blockwise.py
pytensor/link/numba/dispatch/blockwise.py
+3
-2
blockwise.py
pytensor/tensor/blockwise.py
+96
-29
blockwise.py
pytensor/tensor/rewriting/blockwise.py
+1
-12
没有找到文件。
pytensor/link/jax/dispatch/blockwise.py
浏览文件 @
a62e785d
import
jax.numpy
as
jnp
from
pytensor.graph
import
FunctionGraph
from
pytensor.link.jax.dispatch
import
jax_funcify
from
pytensor.tensor.blockwise
import
Blockwise
@jax_funcify.register
(
Blockwise
)
def
funcify_Blockwise
(
op
:
Blockwise
,
node
,
*
args
,
**
kwargs
):
def
jax_funcify_Blockwise
(
op
:
Blockwise
,
node
,
**
kwargs
):
signature
=
op
.
signature
core_node
=
op
.
_create_dummy_core_node
(
node
.
inputs
)
core_fgraph
=
FunctionGraph
(
inputs
=
core_node
.
inputs
,
outputs
=
core_node
.
outputs
)
tuple_core_fn
=
jax_funcify
(
core_fgraph
)
if
len
(
node
.
outputs
)
==
1
:
def
core_fn
(
*
inputs
):
return
tuple_core_fn
(
*
inputs
)[
0
]
else
:
core_fn
=
tuple_core_fn
core_node
=
op
.
_create_dummy_core_node
(
node
.
inputs
,
propagate_unbatched_core_inputs
=
True
)
core_fn
=
jax_funcify
(
core_node
.
op
,
node
=
core_node
,
**
kwargs
)
vect_fn
=
jnp
.
vectorize
(
core_fn
,
signature
=
signature
)
...
...
pytensor/link/numba/dispatch/blockwise.py
浏览文件 @
a62e785d
...
...
@@ -16,7 +16,7 @@ from pytensor.tensor import TensorVariable, get_vector_length
from
pytensor.tensor.blockwise
import
Blockwise
,
BlockwiseWithCoreShape
@numba_funcify.register
@numba_funcify.register
(
BlockwiseWithCoreShape
)
def
numba_funcify_Blockwise
(
op
:
BlockwiseWithCoreShape
,
node
,
**
kwargs
):
[
blockwise_node
]
=
op
.
fgraph
.
apply_nodes
blockwise_op
:
Blockwise
=
blockwise_node
.
op
...
...
@@ -26,7 +26,8 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
core_shapes_len
=
tuple
(
get_vector_length
(
sh
)
for
sh
in
node
.
inputs
[
nin
:])
core_node
=
blockwise_op
.
_create_dummy_core_node
(
cast
(
tuple
[
TensorVariable
],
blockwise_node
.
inputs
)
cast
(
tuple
[
TensorVariable
],
node
.
inputs
[:
nin
]),
propagate_unbatched_core_inputs
=
True
,
)
core_op_fn
=
numba_funcify
(
core_op
,
...
...
pytensor/tensor/blockwise.py
浏览文件 @
a62e785d
from
collections.abc
import
Callable
,
Sequence
from
typing
import
Any
,
cast
from
typing
import
Any
,
Literal
,
cast
,
overload
import
numpy
as
np
from
numpy
import
broadcast_shapes
,
empty
...
...
@@ -32,6 +32,17 @@ from pytensor.tensor.utils import (
from
pytensor.tensor.variable
import
TensorVariable
def
_squeeze_left
(
x
,
stop_at_dim
:
int
|
None
=
None
):
"""Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
x_dims
=
x
.
type
.
broadcastable
squeeze_ndim
=
len
(
x_dims
)
if
all
(
x_dims
)
else
x_dims
.
index
(
False
)
if
stop_at_dim
is
not
None
:
squeeze_ndim
=
min
(
squeeze_ndim
,
stop_at_dim
)
if
squeeze_ndim
==
0
:
return
x
return
x
.
squeeze
(
axis
=
tuple
(
range
(
squeeze_ndim
)))
def
_vectorize_node_perform
(
core_node
:
Apply
,
batch_bcast_patterns
:
Sequence
[
tuple
[
bool
,
...
]],
...
...
@@ -143,8 +154,6 @@ def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_
class
Blockwise
(
COp
):
"""Generalizes a core `Op` to work with batched dimensions.
TODO: Dispatch JAX (should be easy with the vectorize macro)
TODO: Dispatch Numba
TODO: C implementation?
TODO: Fuse Blockwise?
"""
...
...
@@ -202,21 +211,52 @@ class Blockwise(COp):
super
()
.
__init__
(
**
kwargs
)
def
_create_dummy_core_node
(
self
,
inputs
:
Sequence
[
TensorVariable
])
->
Apply
:
core_input_types
=
[]
@overload
def
_create_dummy_core_node
(
self
,
inputs
:
Sequence
[
TensorVariable
],
*
,
propagate_unbatched_core_inputs
:
bool
=
False
,
return_dummy_inputs
:
Literal
[
False
]
=
...
,
)
->
Apply
:
...
@overload
def
_create_dummy_core_node
(
self
,
inputs
:
Sequence
[
TensorVariable
],
*
,
propagate_unbatched_core_inputs
:
bool
=
False
,
return_dummy_inputs
:
Literal
[
True
]
=
...
,
)
->
tuple
[
Apply
,
list
[
TensorVariable
]]:
...
def
_create_dummy_core_node
(
self
,
inputs
:
Sequence
[
TensorVariable
],
*
,
propagate_unbatched_core_inputs
:
bool
=
False
,
return_dummy_inputs
:
bool
=
False
,
)
->
Apply
|
tuple
[
Apply
,
list
[
TensorVariable
]]:
core_inputs
=
[]
core_dummy_inputs
=
[]
for
i
,
(
inp
,
sig
)
in
enumerate
(
zip
(
inputs
,
self
.
inputs_sig
,
strict
=
True
)):
if
inp
.
type
.
ndim
<
len
(
sig
):
raise
ValueError
(
f
"Input {i} {inp} has insufficient core dimensions for signature {self.signature}"
)
# ndim_supp = 0 case
if
not
sig
:
core_shape
=
()
inp_ndim
=
inp
.
type
.
ndim
batch_ndim
=
inp_ndim
-
len
(
sig
)
core_shape
=
inp
.
type
.
shape
[
batch_ndim
:]
if
propagate_unbatched_core_inputs
and
all
(
inp
.
type
.
broadcastable
[:
batch_ndim
]
):
core_inputs
.
append
(
_squeeze_left
(
inp
,
batch_ndim
))
else
:
core_shape
=
inp
.
type
.
shape
[
-
len
(
sig
)
:]
core_input_types
.
append
(
tensor
(
dtype
=
inp
.
type
.
dtype
,
shape
=
core_shape
))
dummy_inp
=
tensor
(
dtype
=
inp
.
type
.
dtype
,
shape
=
core_shape
)
core_inputs
.
append
(
dummy_inp
)
core_dummy_inputs
.
append
(
dummy_inp
)
core_node
=
self
.
core_op
.
make_node
(
*
core_input
_type
s
)
core_node
=
self
.
core_op
.
make_node
(
*
core_inputs
)
if
len
(
core_node
.
outputs
)
!=
len
(
self
.
outputs_sig
):
raise
ValueError
(
...
...
@@ -230,6 +270,9 @@ class Blockwise(COp):
f
"Output {i} of {self.core_op} has wrong number of core dimensions for signature {self.signature}: {core_out.type.ndim}"
)
if
return_dummy_inputs
:
return
core_node
,
core_dummy_inputs
return
core_node
def
make_node
(
self
,
*
inputs
):
...
...
@@ -298,11 +341,17 @@ class Blockwise(COp):
batch_shape
=
broadcast_shape
(
*
batch_shapes
,
arrays_are_shapes
=
True
)
# Try to extract the core shapes from the core_op
core_op_infer_shape
=
getattr
(
self
.
core_op
,
"infer_shape"
,
None
)
if
core_op_infer_shape
is
not
None
:
dummy_core_node
=
self
.
_create_dummy_core_node
(
node
.
inputs
)
dummy_core_inputs
=
tuple
(
explicit_graph_inputs
(
dummy_core_node
.
inputs
))
def
extract_core_shape_from_infer_shape
():
# Try to extract the core shapes from the core_op
core_op_infer_shape
=
getattr
(
self
.
core_op
,
"infer_shape"
,
None
)
if
core_op_infer_shape
is
None
:
return
[[
None
]
*
out
.
ndim
for
out
in
node
.
outputs
]
dummy_core_node
,
dummy_core_inputs
=
self
.
_create_dummy_core_node
(
node
.
inputs
,
return_dummy_inputs
=
True
,
propagate_unbatched_core_inputs
=
True
,
)
dummy_fgraph
=
FunctionGraph
(
outputs
=
dummy_core_node
.
outputs
,
clone
=
False
)
core_input_shapes
=
[
input_shape
[
batch_ndims
:]
for
input_shape
in
input_shapes
...
...
@@ -311,6 +360,25 @@ class Blockwise(COp):
dummy_fgraph
,
dummy_core_node
,
core_input_shapes
)
# Set to None those core_shapes that depend on dummy_core_inputs,
# meaning their value may not be constant across batch dims of the Blockwise
if
not
dummy_core_inputs
:
# All inputs are unbatched, so the core_shape can be used as is
return
core_output_shapes
else
:
set_dummy_core_inputs
=
set
(
dummy_core_inputs
)
safe_core_output_shapes
=
[
list
(
shape
)
for
shape
in
core_output_shapes
]
for
core_out_shape
in
safe_core_output_shapes
:
for
o
,
core_out_dim
in
enumerate
(
core_out_shape
):
if
set_dummy_core_inputs
&
set
(
explicit_graph_inputs
([
core_out_dim
])
):
core_out_shape
[
o
]
=
None
return
safe_core_output_shapes
safe_core_out_shape
=
None
out_shapes
=
[]
for
o
,
(
output
,
sig
)
in
enumerate
(
zip
(
node
.
outputs
,
self
.
outputs_sig
,
strict
=
True
)
...
...
@@ -321,19 +389,15 @@ class Blockwise(COp):
if
dim_name
in
core_dims
:
core_out_shape
.
append
(
core_dims
[
dim_name
])
else
:
if
core_op_infer_shape
is
not
None
:
# If the input values are needed to compute the dimension length, we can't use the infer_shape
# of the core_node as the value is not constant across batch dims of the Blockwise
core_out_dim
=
core_output_shapes
[
o
][
i
]
if
not
(
set
(
dummy_core_inputs
)
&
set
(
explicit_graph_inputs
([
core_out_dim
]))
):
core_out_shape
.
append
(
core_out_dim
)
continue
# Fallback shape requires evaluating the Blockwise Op
core_out_shape
.
append
(
Shape_i
(
batch_ndims
+
i
)(
output
))
if
safe_core_out_shape
is
None
:
# Extract the core shape from the core_op infer_shape on demand
# For many Ops we never need to do this, because all info is in their signature
safe_core_out_shape
=
extract_core_shape_from_infer_shape
()
if
(
core_out_dim
:
=
safe_core_out_shape
[
o
][
i
])
is
not
None
:
core_out_shape
.
append
(
core_out_dim
)
else
:
# Fallback shape requires evaluating the Blockwise Op
core_out_shape
.
append
(
Shape_i
(
batch_ndims
+
i
)(
output
))
out_shapes
.
append
((
*
batch_shape
,
*
core_out_shape
))
return
out_shapes
...
...
@@ -448,7 +512,10 @@ class Blockwise(COp):
)
return
core_func
(
*
inputs
)
else
:
core_node
=
self
.
_create_dummy_core_node
(
node
.
inputs
)
# type: ignore
core_node
=
self
.
_create_dummy_core_node
(
cast
(
list
[
TensorVariable
],
node
.
inputs
),
propagate_unbatched_core_inputs
=
True
,
)
gufunc
=
_vectorize_node_perform
(
core_node
,
batch_bcast_patterns
=
batch_bcast_patterns
,
...
...
pytensor/tensor/rewriting/blockwise.py
浏览文件 @
a62e785d
...
...
@@ -4,7 +4,7 @@ from pytensor.graph.destroyhandler import inplace_candidates
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
out2in
from
pytensor.tensor.basic
import
Alloc
,
ARange
,
alloc
,
shape_padleft
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
,
_squeeze_left
from
pytensor.tensor.math
import
Dot
from
pytensor.tensor.rewriting.basic
import
(
register_canonicalize
,
...
...
@@ -90,17 +90,6 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
return
local_useless_unbatched_blockwise
.
fn
(
fgraph
,
node
)
def
_squeeze_left
(
x
,
stop_at_dim
:
int
|
None
=
None
):
"""Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
x_dims
=
x
.
type
.
broadcastable
squeeze_ndim
=
len
(
x_dims
)
if
all
(
x_dims
)
else
x_dims
.
index
(
False
)
if
stop_at_dim
is
not
None
:
squeeze_ndim
=
min
(
squeeze_ndim
,
stop_at_dim
)
if
squeeze_ndim
==
0
:
return
x
return
x
.
squeeze
(
axis
=
tuple
(
range
(
squeeze_ndim
)))
@register_specialize
(
"shape_unsafe"
)
@node_rewriter
([
Blockwise
])
def
local_blockwise_alloc
(
fgraph
,
node
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论