Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2e2c871e
提交
2e2c871e
authored
12月 05, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
12月 10, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add rewrite for Blockwise with Alloc inputs
Also prevent Alloc from constant_folding when it's used by Elemwise and Blockwise to avoid creating useless large arrays
上级
fe06ee32
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
220 行增加
和
8 行删除
+220
-8
basic.py
pytensor/graph/basic.py
+4
-0
basic.py
pytensor/tensor/basic.py
+11
-4
blockwise.py
pytensor/tensor/rewriting/blockwise.py
+121
-2
test_blockwise.py
tests/tensor/rewriting/test_blockwise.py
+84
-2
没有找到文件。
pytensor/graph/basic.py
浏览文件 @
2e2c871e
...
...
@@ -1777,6 +1777,7 @@ def equal_computations(
ys
:
list
[
Union
[
np
.
ndarray
,
Variable
]],
in_xs
:
Optional
[
list
[
Variable
]]
=
None
,
in_ys
:
Optional
[
list
[
Variable
]]
=
None
,
strict_dtype
=
True
,
)
->
bool
:
"""Checks if PyTensor graphs represent the same computations.
...
...
@@ -1908,6 +1909,9 @@ def equal_computations(
if
dx
!=
dy
:
if
isinstance
(
dx
,
Constant
)
and
isinstance
(
dy
,
Constant
):
if
not
dx
.
equals
(
dy
):
if
strict_dtype
:
return
False
elif
not
np
.
array_equal
(
dx
.
data
,
dy
.
data
):
return
False
else
:
return
False
...
...
pytensor/tensor/basic.py
浏览文件 @
2e2c871e
...
...
@@ -42,6 +42,7 @@ from pytensor.tensor import (
as_tensor_variable
,
get_vector_length
,
)
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
,
scalar_elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.shape
import
(
...
...
@@ -1658,16 +1659,22 @@ class Alloc(COp):
if
not
clients
:
return
False
for
client
in
clients
:
if
client
[
0
]
==
"output"
:
for
client
,
idx
in
clients
:
if
client
==
"output"
:
# If the output is a constant, it will have to be deepcopied
# each time the function is called. So we do not fold.
return
False
# Allow alloc to be lifted out of Elemwise before constant folding it
elif
isinstance
(
client
.
op
,
Elemwise
):
return
None
# Same for Blockwise, unless it has no batch_dims
elif
isinstance
(
client
.
op
,
Blockwise
)
and
client
.
op
.
batch_ndim
(
client
):
return
None
elif
(
# The following ops work inplace of their input id 0.
client
[
1
]
==
0
idx
==
0
and
isinstance
(
client
[
0
]
.
op
,
client
.
op
,
(
# Ops that will work inplace on the Alloc. So if they
# get constant_folded, they would copy the
...
...
pytensor/tensor/rewriting/blockwise.py
浏览文件 @
2e2c871e
from
typing
import
Optional
from
pytensor.compile.mode
import
optdb
from
pytensor.graph
import
node_rewriter
from
pytensor.graph
import
Constant
,
node_rewriter
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
out2in
from
pytensor.tensor.basic
import
Alloc
,
ARange
,
shape_padleft
from
pytensor.tensor.basic
import
Alloc
,
ARange
,
alloc
,
shape_padleft
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.math
import
Dot
from
pytensor.tensor.rewriting.basic
import
(
...
...
@@ -80,3 +82,120 @@ def local_eager_useless_unbatched_blockwise(fgraph, node):
),
):
return
local_useless_unbatched_blockwise
.
fn
(
fgraph
,
node
)
def
_squeeze_left
(
x
,
stop_at_dim
:
Optional
[
int
]
=
None
):
"""Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached."""
x_dims
=
x
.
type
.
broadcastable
squeeze_ndim
=
len
(
x_dims
)
if
all
(
x_dims
)
else
x_dims
.
index
(
False
)
if
stop_at_dim
is
not
None
:
squeeze_ndim
=
min
(
squeeze_ndim
,
stop_at_dim
)
if
squeeze_ndim
==
0
:
return
x
return
x
.
squeeze
(
axis
=
tuple
(
range
(
squeeze_ndim
)))
@register_specialize("shape_unsafe")
@node_rewriter([Blockwise])
def local_blockwise_alloc(fgraph, node):
    """Push Allocs from the inputs to the output of Blockwise Ops.

    Examples (BOp = Blockwise(Op, signature="(x),(x)->(x)")):

    BOp(vector, alloc(vector, 10, 5)) -> alloc(BOp)(vector, vector), 10, 5)
    BOp(vector, alloc(scalar, 10, 5)) -> alloc(BOp)(vector, alloc(scalar, 5), 10, 5)
    BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector)

    Returns the replacement outputs, or None when nothing can be rewritten.
    Registered under "shape_unsafe" because it relies on the Alloc shapes
    and input broadcast patterns being mutually consistent.
    """
    # Bail out early unless at least one input is produced by an Alloc.
    if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner):
        return None

    op: Blockwise = node.op  # type: ignore

    # Only batch dims can be pushed out; with no batch dims there is nothing to do.
    batch_ndim = op.batch_ndim(node)
    if not batch_ndim:
        return None

    new_inputs = []
    # One tuple of batch-dim lengths per Alloc input we manage to push.
    batch_shapes = []
    can_push_any_alloc = False
    for inp, inp_sig in zip(node.inputs, op.inputs_sig):
        if inp.owner and isinstance(inp.owner.op, Alloc):
            # Push batch dims from Alloc
            value, *shape = inp.owner.inputs

            # Check what to do with the value of the Alloc
            squeezed_value = _squeeze_left(value, batch_ndim)
            # Dims the Alloc added on the left of `value`.
            missing_ndim = len(shape) - value.type.ndim
            # Compare core-dim broadcast patterns: if the Alloc changes the
            # core dims (not just the batch dims), a core Alloc must remain.
            if (
                ((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:]
            ) != inp.type.broadcastable[batch_ndim:]:
                # We still need an Alloc for the core dims
                core_shape = shape[batch_ndim:]
                # And the batch dims of the squeezed value
                squeezed_value_batch_ndim = squeezed_value.type.ndim - len(core_shape)
                batch_shape = [
                    1 if broadcastable else dim
                    for broadcastable, dim in zip(
                        squeezed_value.type.broadcastable[:squeezed_value_batch_ndim],
                        tuple(squeezed_value.shape)[:squeezed_value_batch_ndim],
                    )
                ]
                squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape)
                if squeezed_value.type.broadcastable == inp.type.broadcastable:
                    # We can't change anything about this Alloc input
                    new_inputs.append(inp)
                    continue

            # We can push batch dims of this Alloc input
            batch_shapes.append(
                tuple(
                    1 if broadcastable else dim
                    for broadcastable, dim in zip(
                        inp.type.broadcastable, shape[:batch_ndim]
                    )
                )
            )
            new_inputs.append(squeezed_value)
            can_push_any_alloc = True
        else:
            # Nothing to do with this input other than removing dummy batch dims
            new_inputs.append(_squeeze_left(inp, batch_ndim))

    if not can_push_any_alloc:
        return None

    # Rebuild the Blockwise on the squeezed inputs.
    new_outs = node.op.make_node(*new_inputs).outputs

    new_out_type = new_outs[0].type
    old_out_type = node.outputs[0].type
    if new_out_type.broadcastable != old_out_type.broadcastable:
        # An Alloc is still needed to broadcast the new output to the original shape
        # We pick the most parsimonious batch dim from the pushed Alloc
        missing_ndim = old_out_type.ndim - new_out_type.ndim
        batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim]
        for i, batch_dims in enumerate(zip(*batch_shapes)):  # Transpose shape tuples
            for batch_dim in batch_dims:
                if batch_dim == 1:
                    continue
                if isinstance(batch_dim, Constant):
                    # Give preference to Constants
                    batch_shape[i] = batch_dim
                    break
                elif old_out_type.broadcastable[i]:
                    # Only use non Constant shapes if absolutely necessary
                    # Otherwise, we use the shape of the non-alloc output
                    batch_shape[i] = batch_dim

        copy_stack_trace(node.outputs, new_outs)
        new_outs = [
            alloc(
                new_out,
                *batch_shape,
                *tuple(new_out.shape)[batch_ndim - missing_ndim :],
            )
            for new_out in new_outs
        ]
    # The replacement must preserve the original broadcast pattern exactly.
    assert new_outs[0].type.broadcastable == old_out_type.broadcastable
    copy_stack_trace(node.outputs, new_outs)
    return new_outs
tests/tensor/rewriting/test_blockwise.py
浏览文件 @
2e2c871e
from
functools
import
partial
from
pytensor
import
function
from
pytensor.graph
import
FunctionGraph
from
pytensor.graph
import
FunctionGraph
,
rewrite_graph
from
pytensor.graph.basic
import
equal_computations
from
pytensor.scalar
import
log
as
scalar_log
from
pytensor.tensor
import
matrix
,
tensor3
from
pytensor.tensor
import
add
,
alloc
,
matrix
,
tensor
,
tensor3
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.nlinalg
import
MatrixPinv
...
...
@@ -36,3 +39,82 @@ def test_useless_unbatched_blockwise():
fn
=
function
([
x
],
out
,
mode
=
"FAST_COMPILE"
)
assert
isinstance
(
fn
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
,
Blockwise
)
assert
isinstance
(
fn
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
.
core_op
,
MatrixPinv
)
def test_blockwise_alloc():
    """Exercise `local_blockwise_alloc`: Allocs feeding a Blockwise should be
    pushed to the output (fully or partially) or left alone when impossible."""
    rewrite = partial(
        rewrite_graph,
        include=("ShapeOpt", "specialize"),
        # Keep the Blockwise in the graph so the Alloc rewrite is what we observe.
        exclude=("local_useless_unbatched_blockwise",),
    )

    vector_add = Blockwise(core_op=add, signature="(x),(x)->(x)")

    # Depending on the rewrites the Alloc shape may be upcast to int64 or not
    # We do not care about that for the purposes of this test
    equal = partial(equal_computations, strict_dtype=False)

    # Case where Alloc is not necessary
    x = tensor("x", shape=(7, 5))
    y = tensor("y", shape=(5,))
    out = vector_add(x, alloc(y, 7, 5))
    expected_out = vector_add(x, y)
    assert equal([rewrite(out)], [expected_out])

    # Cases where Alloc can be fully pushed
    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(5,))
    out = vector_add(x, alloc(y, 7, 5))
    expected_out = alloc(vector_add(x, y), 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(1, 5))
    y = tensor("y", shape=(5,))
    out = vector_add(x, alloc(y, 7, 5))
    # The dummy batch dim of `x` is squeezed away as well.
    expected_out = alloc(vector_add(x.squeeze(0), y), 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(7, 5))
    y = tensor("y", shape=(7, 5))
    out = vector_add(x, alloc(y, 3, 7, 5))
    expected_out = alloc(vector_add(x, y), 3, 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(7, 1, 5))
    out = vector_add(x, alloc(y, 7, 2, 5))
    expected_out = alloc(vector_add(x, y), 7, 2, 5)
    assert equal([rewrite(out)], [expected_out])

    # Case where Alloc can be partially pushed
    x = tensor("x", shape=(5,))
    y = tensor("y", shape=())
    out = vector_add(x, alloc(y, 7, 5))
    # A core-dims-only Alloc of `y` must remain inside the Blockwise.
    expected_out = alloc(vector_add(x, alloc(y, 5)), 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(7, 1, 1))
    out = vector_add(x, alloc(y, 7, 2, 5))
    expected_out = alloc(vector_add(x, alloc(y, 7, 1, 5)), 7, 2, 5)
    assert equal([rewrite(out)], [expected_out], strict_dtype=False)

    # Cases involving multiple Allocs being pushed
    x = tensor("x", shape=())
    y = tensor("y", shape=())
    out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5))
    # Batch dims of both Allocs merge into a single output Alloc (3, 7, 5).
    expected_out = alloc(vector_add(alloc(x, 5), alloc(y, 5)), 3, 7, 5)
    assert equal([rewrite(out)], [expected_out])

    x = tensor("x", shape=(5,))
    y = tensor("y", shape=())
    out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5))
    expected_out = alloc(vector_add(x, alloc(y, 5)), 3, 7, 5)
    assert equal([rewrite(out)], [expected_out])

    # Case where Alloc cannot be pushed
    x = tensor("x", shape=(5,))
    y = tensor("y", shape=(1,))
    # The Alloc only broadcasts a core dim, so the graph must stay unchanged.
    out = vector_add(x, alloc(y, 5))
    expected_out = out
    assert equal([rewrite(out)], [expected_out])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论