Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
ed6ca162
提交
ed6ca162
authored
1月 06, 2024
作者:
jessegrabowski
提交者:
Ricardo Vieira
10月 11, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Inplace Blockwise and core versions of Cholesky and Solve Ops.
Co-authored-by:
Ricardo Vieira
<
28983449+ricardov94@users.noreply.github.com
>
上级
b8dbd4ca
全部展开
显示空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
215 行增加
和
7 行删除
+215
-7
op.py
pytensor/graph/op.py
+6
-0
blockwise.py
pytensor/tensor/blockwise.py
+10
-0
blockwise.py
pytensor/tensor/rewriting/blockwise.py
+80
-2
slinalg.py
pytensor/tensor/slinalg.py
+0
-0
test_blockwise.py
tests/tensor/test_blockwise.py
+114
-3
test_slinalg.py
tests/tensor/test_slinalg.py
+5
-2
没有找到文件。
pytensor/graph/op.py
浏览文件 @
ed6ca162
...
@@ -583,6 +583,12 @@ class Op(MetaObject):
...
@@ -583,6 +583,12 @@ class Op(MetaObject):
)
)
return
self
.
make_py_thunk
(
node
,
storage_map
,
compute_map
,
no_recycling
)
return
self
.
make_py_thunk
(
node
,
storage_map
,
compute_map
,
no_recycling
)
def
inplace_on_inputs
(
self
,
allowed_inplace_inputs
:
list
[
int
])
->
"Op"
:
"""Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs`."""
# TODO: Document this in the Create your own Op docs
# By default, do nothing
return
self
def
__str__
(
self
):
def
__str__
(
self
):
return
getattr
(
type
(
self
),
"__name__"
,
super
()
.
__str__
())
return
getattr
(
type
(
self
),
"__name__"
,
super
()
.
__str__
())
...
...
pytensor/tensor/blockwise.py
浏览文件 @
ed6ca162
...
@@ -45,6 +45,7 @@ class Blockwise(Op):
...
@@ -45,6 +45,7 @@ class Blockwise(Op):
signature
:
str
|
None
=
None
,
signature
:
str
|
None
=
None
,
name
:
str
|
None
=
None
,
name
:
str
|
None
=
None
,
gufunc_spec
:
tuple
[
str
,
int
,
int
]
|
None
=
None
,
gufunc_spec
:
tuple
[
str
,
int
,
int
]
|
None
=
None
,
destroy_map
=
None
,
**
kwargs
,
**
kwargs
,
):
):
"""
"""
...
@@ -79,6 +80,15 @@ class Blockwise(Op):
...
@@ -79,6 +80,15 @@ class Blockwise(Op):
self
.
inputs_sig
,
self
.
outputs_sig
=
_parse_gufunc_signature
(
signature
)
self
.
inputs_sig
,
self
.
outputs_sig
=
_parse_gufunc_signature
(
signature
)
self
.
gufunc_spec
=
gufunc_spec
self
.
gufunc_spec
=
gufunc_spec
self
.
_gufunc
=
None
self
.
_gufunc
=
None
if
destroy_map
is
not
None
:
self
.
destroy_map
=
destroy_map
if
self
.
destroy_map
!=
core_op
.
destroy_map
:
# Note: Should be fine for destroy_map of Blockwise to be more extensive than that of core_op
# But we are not using that anywhere yet, so this check is fine for now
raise
ValueError
(
f
"Blockwise destroy_map {self.destroy_map} must be the same as that of the core_op {core_op} {core_op.destroy_map}"
)
super
()
.
__init__
(
**
kwargs
)
super
()
.
__init__
(
**
kwargs
)
def
__getstate__
(
self
):
def
__getstate__
(
self
):
...
...
pytensor/tensor/rewriting/blockwise.py
浏览文件 @
ed6ca162
import
itertools
from
pytensor.compile
import
Supervisor
from
pytensor.compile.mode
import
optdb
from
pytensor.compile.mode
import
optdb
from
pytensor.graph
import
Constant
,
node_rewriter
from
pytensor.graph
import
Constant
,
node_rewriter
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
out2in
from
pytensor.graph.rewriting.basic
import
copy_stack_trace
,
in2out
,
out2in
from
pytensor.tensor.basic
import
Alloc
,
ARange
,
alloc
,
shape_padleft
from
pytensor.tensor.basic
import
Alloc
,
ARange
,
alloc
,
shape_padleft
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.math
import
Dot
from
pytensor.tensor.math
import
Dot
...
@@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node):
...
@@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node):
# We register this rewrite late, so that other rewrites need only target Blockwise Ops
# We register this rewrite late, so that other rewrites need only target Blockwise Ops
# We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
optdb
.
register
(
optdb
.
register
(
"local_useless_unbatched_blockwise"
,
"local_useless_unbatched_blockwise"
,
out2in
(
local_useless_unbatched_blockwise
,
ignore_newtrees
=
True
),
out2in
(
local_useless_unbatched_blockwise
,
ignore_newtrees
=
True
),
"fast_run"
,
"fast_run"
,
"fast_compile"
,
"fast_compile"
,
"blockwise"
,
"blockwise"
,
position
=
49
,
position
=
60
,
)
)
...
@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node):
...
@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node):
new_out
=
x
.
reshape
([
*
tuple
(
batched_shape
),
*
tuple
(
core_reshape
)])
new_out
=
x
.
reshape
([
*
tuple
(
batched_shape
),
*
tuple
(
core_reshape
)])
copy_stack_trace
(
node
.
outputs
[
0
],
new_out
)
copy_stack_trace
(
node
.
outputs
[
0
],
new_out
)
return
[
new_out
]
return
[
new_out
]
@node_rewriter
(
tracks
=
[
Blockwise
],
inplace
=
True
)
def
blockwise_inplace
(
fgraph
,
node
):
blockwise_op
=
node
.
op
if
blockwise_op
.
destroy_map
:
# Op already has inplace
return
# Find out valid inputs for inplacing
batch_ndim
=
blockwise_op
.
batch_ndim
(
node
)
out_batch_bcast
=
node
.
outputs
[
0
]
.
type
.
broadcastable
[:
batch_ndim
]
protected_inputs
=
[
f
.
protected
for
f
in
fgraph
.
_features
if
isinstance
(
f
,
Supervisor
)
]
protected_inputs
=
list
(
itertools
.
chain
.
from_iterable
(
protected_inputs
))
protected_inputs
.
extend
(
fgraph
.
outputs
)
allowed_inplace_inputs
=
[
idx
for
idx
,
inp
in
enumerate
(
node
.
inputs
)
if
(
# Constants would need to be recreated every time if inplaced
not
isinstance
(
inp
,
Constant
)
# We can only inplace on inputs that are not being broadcasted
# As those are reused across iterations of Blockwise
and
node
.
inputs
[
idx
]
.
type
.
broadcastable
[:
batch_ndim
]
==
out_batch_bcast
# Inputs that are marked as protected or destroyed can't be inplaced
and
not
fgraph
.
has_destroyers
([
inp
])
and
inp
not
in
protected_inputs
)
]
if
not
allowed_inplace_inputs
:
return
None
inplace_core_op
=
blockwise_op
.
core_op
.
inplace_on_inputs
(
allowed_inplace_inputs
=
allowed_inplace_inputs
)
if
not
inplace_core_op
.
destroy_map
:
return
None
# Check Op is not trying to inplace on non-candidate inputs
for
destroyed_inputs
in
inplace_core_op
.
destroy_map
.
values
():
for
destroyed_input
in
destroyed_inputs
:
if
destroyed_input
not
in
allowed_inplace_inputs
:
raise
ValueError
(
f
"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
)
# Recreate core_op with inplace
inplace_blockwise_op
=
Blockwise
(
core_op
=
inplace_core_op
,
signature
=
blockwise_op
.
signature
,
name
=
blockwise_op
.
name
,
gufunc_spec
=
blockwise_op
.
gufunc_spec
,
destroy_map
=
inplace_core_op
.
destroy_map
,
)
out
=
inplace_blockwise_op
.
make_node
(
*
node
.
inputs
)
.
outputs
copy_stack_trace
(
node
.
outputs
,
out
)
return
out
optdb
.
register
(
"blockwise_inplace"
,
in2out
(
blockwise_inplace
),
"fast_run"
,
"inplace"
,
position
=
50.1
,
)
pytensor/tensor/slinalg.py
浏览文件 @
ed6ca162
差异被折叠。
点击展开。
tests/tensor/test_blockwise.py
浏览文件 @
ed6ca162
...
@@ -3,10 +3,11 @@ from itertools import product
...
@@ -3,10 +3,11 @@ from itertools import product
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
scipy.linalg
import
pytensor
import
pytensor
from
pytensor
import
config
,
function
from
pytensor
import
In
,
config
,
function
from
pytensor.compile
import
get_mode
from
pytensor.compile
import
get_
default_mode
,
get_
mode
from
pytensor.gradient
import
grad
from
pytensor.gradient
import
grad
from
pytensor.graph
import
Apply
,
Op
from
pytensor.graph
import
Apply
,
Op
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.replace
import
vectorize_node
...
@@ -15,7 +16,15 @@ from pytensor.tensor import diagonal, log, tensor
...
@@ -15,7 +16,15 @@ from pytensor.tensor import diagonal, log, tensor
from
pytensor.tensor.blockwise
import
Blockwise
,
vectorize_node_fallback
from
pytensor.tensor.blockwise
import
Blockwise
,
vectorize_node_fallback
from
pytensor.tensor.nlinalg
import
MatrixInverse
from
pytensor.tensor.nlinalg
import
MatrixInverse
from
pytensor.tensor.rewriting.blas
import
specialize_matmul_to_batched_dot
from
pytensor.tensor.rewriting.blas
import
specialize_matmul_to_batched_dot
from
pytensor.tensor.slinalg
import
Cholesky
,
Solve
,
cholesky
,
solve_triangular
from
pytensor.tensor.slinalg
import
(
Cholesky
,
Solve
,
SolveBase
,
cho_solve
,
cholesky
,
solve
,
solve_triangular
,
)
from
pytensor.tensor.utils
import
_parse_gufunc_signature
from
pytensor.tensor.utils
import
_parse_gufunc_signature
...
@@ -398,3 +407,105 @@ def test_cop_with_params():
...
@@ -398,3 +407,105 @@ def test_cop_with_params():
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
fn
(
np
.
zeros
((
5
,
3
,
2
))
-
1
)
fn
(
np
.
zeros
((
5
,
3
,
2
))
-
1
)
@pytest.mark.skipif
(
config
.
mode
==
"FAST_COMPILE"
,
reason
=
"inplace rewrites disabled when mode is FAST_COMPILE"
,
)
class
TestInplace
:
@pytest.mark.parametrize
(
"is_batched"
,
(
False
,
True
))
def
test_cholesky
(
self
,
is_batched
):
X
=
tensor
(
"X"
,
shape
=
(
5
,
None
,
None
)
if
is_batched
else
(
None
,
None
))
L
=
cholesky
(
X
,
lower
=
True
)
f
=
function
([
In
(
X
,
mutable
=
True
)],
L
)
assert
not
L
.
owner
.
op
.
core_op
.
destroy_map
if
is_batched
:
[
cholesky_op
]
=
[
node
.
op
.
core_op
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Blockwise
)
and
isinstance
(
node
.
op
.
core_op
,
Cholesky
)
]
else
:
[
cholesky_op
]
=
[
node
.
op
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Cholesky
)
]
assert
cholesky_op
.
destroy_map
==
{
0
:
[
0
]}
rng
=
np
.
random
.
default_rng
(
441
+
is_batched
)
X_val
=
rng
.
normal
(
size
=
(
10
,
10
))
.
astype
(
config
.
floatX
)
X_val_in
=
X_val
@
X_val
.
T
if
is_batched
:
X_val_in
=
np
.
broadcast_to
(
X_val_in
,
(
5
,
*
X_val_in
.
shape
))
.
copy
()
X_val_in_copy
=
X_val_in
.
copy
()
f
(
X_val_in
)
np
.
testing
.
assert_allclose
(
X_val_in
,
np
.
linalg
.
cholesky
(
X_val_in_copy
),
atol
=
1e-5
if
config
.
floatX
==
"float32"
else
0
,
)
@pytest.mark.parametrize
(
"batched_A"
,
(
False
,
True
))
@pytest.mark.parametrize
(
"batched_b"
,
(
False
,
True
))
@pytest.mark.parametrize
(
"solve_fn"
,
(
solve
,
solve_triangular
,
cho_solve
))
def
test_solve
(
self
,
solve_fn
,
batched_A
,
batched_b
):
A
=
tensor
(
"A"
,
shape
=
(
5
,
3
,
3
)
if
batched_A
else
(
3
,
3
))
b
=
tensor
(
"b"
,
shape
=
(
5
,
3
)
if
batched_b
else
(
3
,))
if
solve_fn
==
cho_solve
:
# Special signature for cho_solve
x
=
solve_fn
((
A
,
True
),
b
,
b_ndim
=
1
)
else
:
x
=
solve_fn
(
A
,
b
,
b_ndim
=
1
)
mode
=
get_default_mode
()
.
excluding
(
"batched_vector_b_solve_to_matrix_b_solve"
)
fn
=
function
([
In
(
A
,
mutable
=
True
),
In
(
b
,
mutable
=
True
)],
x
,
mode
=
mode
)
op
=
fn
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
if
batched_A
or
batched_b
:
assert
isinstance
(
op
,
Blockwise
)
and
isinstance
(
op
.
core_op
,
SolveBase
)
if
batched_A
and
not
batched_b
:
if
solve_fn
==
solve
:
assert
op
.
destroy_map
==
{
0
:
[
0
]}
else
:
# SolveTriangular does not destroy A
assert
op
.
destroy_map
==
{}
else
:
assert
op
.
destroy_map
==
{
0
:
[
1
]}
else
:
assert
isinstance
(
op
,
SolveBase
)
assert
op
.
destroy_map
==
{
0
:
[
1
]}
# We test with an F_CONTIGUOUS (core) A as only that will be destroyed by scipy
rng
=
np
.
random
.
default_rng
(
487
+
batched_A
+
2
*
batched_b
+
sum
(
map
(
ord
,
solve_fn
.
__name__
))
)
A_val
=
np
.
swapaxes
(
rng
.
normal
(
size
=
A
.
type
.
shape
)
.
astype
(
A
.
type
.
dtype
),
-
1
,
-
2
)
b_val
=
np
.
random
.
normal
(
size
=
b
.
type
.
shape
)
.
astype
(
b
.
type
.
dtype
)
A_val_copy
=
A_val
.
copy
()
b_val_copy
=
b_val
.
copy
()
out
=
fn
(
A_val
,
b_val
)
if
solve_fn
==
cho_solve
:
def
core_scipy_fn
(
A
,
b
):
return
scipy
.
linalg
.
cho_solve
((
A
,
True
),
b
)
else
:
core_scipy_fn
=
getattr
(
scipy
.
linalg
,
solve_fn
.
__name__
)
expected_out
=
np
.
vectorize
(
core_scipy_fn
,
signature
=
"(m,m),(m)->(m)"
)(
A_val_copy
,
b_val_copy
)
np
.
testing
.
assert_allclose
(
out
,
expected_out
,
atol
=
1e-6
if
config
.
floatX
==
"float32"
else
0
)
# Confirm input was destroyed
assert
(
A_val
==
A_val_copy
)
.
all
()
==
(
op
.
destroy_map
.
get
(
0
,
None
)
!=
[
0
])
assert
(
b_val
==
b_val_copy
)
.
all
()
==
(
op
.
destroy_map
.
get
(
0
,
None
)
!=
[
1
])
tests/tensor/test_slinalg.py
浏览文件 @
ed6ca162
...
@@ -197,7 +197,10 @@ class TestSolveBase(utt.InferShapeTester):
...
@@ -197,7 +197,10 @@ class TestSolveBase(utt.InferShapeTester):
A
=
matrix
()
A
=
matrix
()
b
=
matrix
()
b
=
matrix
()
y
=
SolveBase
(
b_ndim
=
2
)(
A
,
b
)
y
=
SolveBase
(
b_ndim
=
2
)(
A
,
b
)
assert
y
.
__repr__
()
==
"SolveBase{lower=False, check_finite=True, b_ndim=2}.0"
assert
(
y
.
__repr__
()
==
"SolveBase{lower=False, check_finite=True, b_ndim=2, overwrite_a=False, overwrite_b=False}.0"
)
class
TestSolve
(
utt
.
InferShapeTester
):
class
TestSolve
(
utt
.
InferShapeTester
):
...
@@ -361,7 +364,7 @@ class TestCholeskySolve(utt.InferShapeTester):
...
@@ -361,7 +364,7 @@ class TestCholeskySolve(utt.InferShapeTester):
def
test_repr
(
self
):
def
test_repr
(
self
):
assert
(
assert
(
repr
(
CholeskySolve
(
lower
=
True
,
b_ndim
=
1
))
repr
(
CholeskySolve
(
lower
=
True
,
b_ndim
=
1
))
==
"CholeskySolve(lower=True,check_finite=True,b_ndim=1)"
==
"CholeskySolve(lower=True,check_finite=True,b_ndim=1
,overwrite_b=False
)"
)
)
def
test_infer_shape
(
self
):
def
test_infer_shape
(
self
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论