Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d5a054d1
提交
d5a054d1
authored
1月 20, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 09, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Generalize lift of Subtensor over Elemwise
Split off Subtensor of Unbroadcast into its own rewrite
上级
f1db1bd6
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
148 行增加
和
169 行删除
+148
-169
subtensor_lift.py
pytensor/tensor/rewriting/subtensor_lift.py
+63
-57
test_subtensor_lift.py
tests/tensor/rewriting/test_subtensor_lift.py
+85
-112
没有找到文件。
pytensor/tensor/rewriting/subtensor_lift.py
浏览文件 @
d5a054d1
...
...
@@ -108,73 +108,79 @@ def local_subtensor_of_dot(fgraph, node):
return
[
r
]
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Subtensor])
def local_subtensor_of_elemwise(fgraph, node):
    """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.

    Examples
    --------
    exp(x)[:, 0] -> exp(x[:, 0])
    add(x, y)[0] -> add(x[0], y[0])
    add(x[None], y)[2] -> add(x, y[2])

    Parameters
    ----------
    fgraph : FunctionGraph
        Graph being rewritten; used to inspect the clients of the Elemwise output.
    node : Apply
        A ``Subtensor`` node whose indexed input may be an Elemwise output.

    Returns
    -------
    list of Variable or None
        The replacement output(s), or ``None`` when the rewrite does not apply.
    """
    elem, *idx = node.inputs

    if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
        return None

    if len(fgraph.clients[elem]) > 1:
        # Elemwise output is used beyond the Subtensor.
        # Get out to avoid repeated computations
        return None

    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    elem_inputs = elem.owner.inputs
    elem_bcast = elem.type.broadcastable
    if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
        # No need to worry about implicit broadcasting.
        indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
    else:
        # The original indices may not make sense on some of the broadcasted dimensions
        new_idxs = [list(idx_tuple) for _ in elem_inputs]
        for dim, (dim_idx, dim_bcast_out, *dim_bcast_inputs) in enumerate(
            zip(
                idx_tuple,
                elem_bcast,
                *(inp.type.broadcastable for inp in elem_inputs),
                # Indices can be shorter than input ndims
                strict=False,
            )
        ):
            if is_full_slice(dim_idx):
                # Full slice can be safely applied to all inputs
                continue

            # NOTE(review): the scraped diff compared each input's per-dim flag
            # against the whole `elem_bcast` tuple, which can never be True;
            # the per-dimension output flag `dim_bcast_out` is the intended
            # comparison.
            if all(
                dim_bcast_inp == dim_bcast_out for dim_bcast_inp in dim_bcast_inputs
            ):
                # This dim is not broadcasted for any of the inputs,
                # original index can be applied to all inputs
                continue

            # Some dims are broadcasted, so we need to adapt their indices.
            # Slice indexing keeps the dimension, so we use a full slice for
            # broadcasted inputs. Integer indexing drops the dimension, so we
            # index by zero for the broadcasted inputs.
            safe_bcast_dim_idx = slice(None) if isinstance(dim_idx, slice) else 0
            for inp_idx, dim_bcast_inp in zip(
                new_idxs, dim_bcast_inputs, strict=True
            ):
                if dim_bcast_inp:
                    inp_idx[dim] = safe_bcast_dim_idx

        indexed_inputs = [
            inp[tuple(new_idx)]
            for inp, new_idx in zip(elem_inputs, new_idxs, strict=True)
        ]

    [old_out] = node.outputs

    # Copy stack trace to new inputs (plain loop: the comprehension form only
    # ran for its side effect and built a throwaway list)
    for new_inp in indexed_inputs:
        copy_stack_trace(old_out, new_inp)

    # Define elemwise operation on indexed inputs
    new_out = elem.owner.op(*indexed_inputs)

    # Copy stack trace to new output
    copy_stack_trace([old_out, *node.inputs], new_out)

    return [new_out]
@register_canonicalize
(
"shape_unsafe"
)
...
...
tests/tensor/rewriting/test_subtensor_lift.py
浏览文件 @
d5a054d1
import
numpy
as
np
import
pytest
import
unittest_tools
as
utt
from
pytensor
import
(
Mode
,
...
...
@@ -25,13 +24,11 @@ from pytensor.printing import debugprint
from
pytensor.tensor
import
(
add
,
exp
,
inplace
,
iscalar
,
iscalars
,
lscalar
,
lscalars
,
matrix
,
scalar
,
shape
,
slicetype
,
specify_shape
,
...
...
@@ -43,6 +40,7 @@ from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.rewriting.subtensor_lift
import
(
local_subtensor_make_vector
,
local_subtensor_of_elemwise
,
local_subtensor_shape_constant
,
)
from
pytensor.tensor.shape
import
SpecifyShape
,
_shape
...
...
@@ -58,22 +56,8 @@ mode_opt = get_mode(mode_opt)
# Pure-Python evaluation mode with all rewrites disabled, used to evaluate
# graphs as-written when comparing rewritten vs. reference computations.
NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
class
TestLocalSubtensorLift
:
def
test_basic
(
self
):
# basic test that the Op works
x
=
matrix
(
"x"
)
f
=
function
([
x
],
exp
(
x
)[
0
],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
"all"
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
# first subtensor
assert
prog
[
1
]
.
op
==
exp
assert
len
(
prog
)
==
2
f
([[
0
,
1
],
[
2
,
3
]])
# let debugmode test something
def
test_basic_1
(
self
):
class
TestLocalSubtensorOfElemwise
:
def
test_unary_multiple_clients
(
self
):
# as test0, but we reuse the output of the elemwise
# So we should not lift the subtensor
x
=
matrix
(
"x"
)
...
...
@@ -87,85 +71,16 @@ class TestLocalSubtensorLift:
assert
isinstance
(
prog
[
1
]
.
op
,
Subtensor
)
# first subtensor
assert
isinstance
(
prog
[
2
]
.
op
,
DeepCopyOp
)
assert
len
(
prog
)
==
3
f
([[
0
,
1
],
[
2
,
3
]])
# let debugmode test something
def
test_basic_2
(
self
):
# basic test that the optimization work with scalar broadcasted
x
=
matrix
(
"x"
)
y
=
scalar
(
"y"
)
z
=
matrix
(
"z"
)
f
=
function
([
x
,
y
,
z
],
exp
(
x
+
y
+
z
)[
0
],
mode
=
mode_opt
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
1
]
.
op
,
DimShuffle
)
assert
isinstance
(
prog
[
2
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
3
]
.
op
.
scalar_op
,
ps
.
Composite
)
# Composite{add,add}
assert
len
(
prog
)
==
4
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
[
Subtensor
])
# let debugmode test something
f
([[
0
,
1
],
[
2
,
3
]],
4
,
[[
4
,
5
],
[
6
,
7
]])
def
test_basic_3
(
self
):
# as 1, but take a slice
x
=
matrix
(
"x"
)
y
=
scalar
(
"y"
)
z
=
matrix
(
"z"
)
f
=
function
([
x
,
y
,
z
],
exp
(
x
+
y
+
z
)[
0
:
2
],
mode
=
mode_opt
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
1
]
.
op
,
DimShuffle
)
assert
isinstance
(
prog
[
2
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
3
]
.
op
.
scalar_op
,
ps
.
Composite
)
# Composite{add,add}
assert
len
(
prog
)
==
4
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
[
Subtensor
])
# let debugmode test something
f
([[
0
,
1
],
[
2
,
3
]],
4
,
[[
4
,
5
],
[
6
,
7
]])
def
test_basic_4
(
self
):
# basic test that the optimization does work with broadcasting
# for unary elemwise.
y
=
vector
(
"y"
)
f
=
function
([
y
],
exp
(
y
.
dimshuffle
(
0
,
"x"
))[
0
],
mode
=
mode_opt
)
# Check stacktrace was copied over correctly after opt was applied
assert
check_stack_trace
(
f
,
ops_to_check
=
"all"
)
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
Subtensor
)
assert
isinstance
(
prog
[
1
]
.
op
,
DimShuffle
)
assert
prog
[
2
]
.
op
==
exp
assert
len
(
prog
)
==
3
f
([
4
,
5
])
# let debugmode test something
@utt.assertFailure_fast
def
test_basic_5
(
self
):
# basic test that the optimization doesn't work with broadcasting
# ... It *could* be extended to,
# ... but right now it doesn't, so it shouldn't try.
x
=
matrix
(
"x"
)
y
=
vector
(
"y"
)
f
=
function
([
x
,
y
],
exp
(
x
+
y
)[
0
],
mode
=
mode_opt
)
# Opt doesn't apply, so no need for check_stack_trace
# assert check_stack_trace(f, ops_to_check='all')
prog
=
f
.
maker
.
fgraph
.
toposort
()
assert
isinstance
(
prog
[
0
]
.
op
,
DimShuffle
)
assert
prog
[
1
]
.
op
==
add
assert
isinstance
(
prog
[
2
]
.
op
,
Subtensor
)
# first subtensor
assert
prog
[
3
]
.
op
==
inplace
.
exp_inplace
assert
len
(
prog
)
==
4
f
([[
0
,
1
],
[
2
,
3
]],
[
4
,
5
])
# let debugmode test something
x_test
=
[[
0
,
1
],
[
2
,
3
]]
res1
,
res2
=
f
(
x_test
)
np
.
testing
.
assert_allclose
(
res1
,
np
.
exp
(
x_test
)[
0
],
)
np
.
testing
.
assert_allclose
(
res2
,
np
.
exp
(
x_test
))
def
test_
basic_6
(
self
):
def
test_
multinary_multiple_clients
(
self
):
# test that we don't lift when we reuse the output of the
# elemwise for other computation.
x
=
matrix
(
"x"
)
...
...
@@ -181,26 +96,84 @@ class TestLocalSubtensorLift:
# first subtensor
assert
isinstance
(
prog
[
2
]
.
op
,
Subtensor
)
assert
len
(
prog
)
==
3
f
([[
0
,
1
],
[
2
,
3
]],
[
4
,
5
])
# let debugmode test something
def
test_basic_7
(
self
):
# basic test that the optimization works with a scalar as input,
# and a scalar as output (no broadcasting of the scalar needed).
# The optimization used to fail and display an ERROR message.
x_test
=
np
.
array
([[
0
,
1
],
[
2
,
3
]])
.
astype
(
x
.
dtype
)
y_test
=
np
.
array
([
4
,
5
])
.
astype
(
y
.
dtype
)
res1
,
res2
=
f
(
x_test
,
y_test
)
np
.
testing
.
assert_allclose
(
res1
,
np
.
exp
(
x_test
+
y_test
)[
0
],
)
np
.
testing
.
assert_allclose
(
res2
,
np
.
exp
(
x_test
+
y_test
)
+
x_test
,
)
@pytest.mark.parametrize(
    "original_fn, expected_fn",
    [
        # Unary integer indexing
        (lambda x, y: exp(x)[0], lambda x, y: exp(x[0])),
        # Unary integer with expand_dims
        (lambda x, y: exp(x[:, None])[0], lambda x, y: exp(x[0][None])),
        # Integer indexing on non-broadcastable dimension
        (lambda x, y: add(x, y)[0], lambda x, y: add(x[0], y[0])),
        # Slice indexing on non-broadcastable dimension
        (lambda x, y: add(x, y)[1:], lambda x, y: add(x[1:], y[1:])),
        # Integer indexing on broadcastable dimension
        (lambda x, y: add(x[None], y[None])[0], lambda x, y: add(x, y)),
        (lambda x, y: add(x[None], y[None])[0, 1], lambda x, y: add(x[1], y[1])),
        (
            lambda x, y: add(x[None, :], y[:, None])[2],
            lambda x, y: add(x, y[2][None]),
        ),
        (
            lambda x, y: add(x[:, None], y[None, :])[:, 2],
            lambda x, y: add(x, y[2][None]),
        ),
        # Slice indexing on broadcastable dimension
        (
            lambda x, y: add(x[None], y[None])[1:],
            lambda x, y: add(x[None][1:], y[None][1:]),
        ),
        (
            lambda x, y: add(x[None, :], y[:, None])[1:],
            lambda x, y: add(x[None, :], y[1:][:, None]),
        ),
    ],
)
def test_local_subtensor_of_elemwise(self, original_fn, expected_fn):
    """Check the lifted graph matches the expected one and computes the same values."""
    rng = np.random.default_rng(257)
    x = pt.matrix("x", shape=(5, 3))
    y = pt.matrix("y", shape=(5, 3))
    x_test = rng.normal(size=x.type.shape).astype(x.dtype)
    y_test = rng.normal(size=y.type.shape).astype(y.dtype)

    out = original_fn(x, y)
    expected_opt_out = expected_fn(x, y)
    opt_out = rewrite_graph(out)
    # On failure, print both graphs (with types) to ease debugging
    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
        [expected_opt_out, opt_out], print_type=True
    )
    # Evaluate without rewrites so the comparison is against the raw graphs
    eval_kwargs = dict(mode=NO_OPTIMIZATION_MODE, on_unused_input="ignore")
    np.testing.assert_allclose(
        opt_out.eval({x: x_test, y: y_test}, **eval_kwargs),
        out.eval({x: x_test, y: y_test}, **eval_kwargs),
    )
x
=
vector
(
"x"
)
y
=
scalar
(
"y"
)
f
=
function
([
x
,
y
],
exp
(
x
+
y
)[
0
],
mode
=
mode_opt
)
def test_local_subtensor_of_elemwise_multiple_clients(self):
    """The rewrite must bail out while the Elemwise output has other clients."""
    x = pt.matrix("x", shape=(5, 3))
    y = pt.matrix("y", shape=(5, 3))
    out1 = add(x, y)
    out2 = out1[0]

    # Rewrite should fail when another node uses out1 directly
    # (in this case it's an extra output)
    fgraph = FunctionGraph([x, y], [out1, out2], clone=False)
    assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is None

    # Otherwise it should work
    fgraph.remove_output(0)
    assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
@pytest.mark.parametrize
(
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论