Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
08538340
提交
08538340
authored
6月 25, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
6月 25, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove Hints and use instance checks in aesara.sandbox.linalg.ops
上级
2b12a455
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
30 行增加
和
239 行删除
+30
-239
__init__.py
aesara/sandbox/linalg/__init__.py
+1
-3
ops.py
aesara/sandbox/linalg/ops.py
+25
-224
test_linalg.py
tests/sandbox/linalg/test_linalg.py
+4
-12
没有找到文件。
aesara/sandbox/linalg/__init__.py
浏览文件 @
08538340
from
aesara.sandbox.linalg.ops
import
psd
,
spectral_radius_bound
from
aesara.tensor.nlinalg
import
det
,
eig
,
eigh
,
matrix_inverse
,
trace
from
aesara.tensor.slinalg
import
cholesky
,
eigvalsh
,
solve
from
aesara.sandbox.linalg.ops
import
spectral_radius_bound
aesara/sandbox/linalg/ops.py
浏览文件 @
08538340
import
logging
import
aesara.tensor
from
aesara.graph.basic
import
Apply
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
GlobalOptimizer
,
local_optimizer
from
aesara.graph.opt
import
local_optimizer
from
aesara.tensor
import
basic
as
aet
from
aesara.tensor.basic_opt
import
(
register_canonicalize
,
...
...
@@ -12,208 +9,16 @@ from aesara.tensor.basic_opt import (
)
from
aesara.tensor.blas
import
Dot22
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.exceptions
import
NotScalarConstantError
from
aesara.tensor.math
import
Dot
,
Prod
,
dot
,
log
from
aesara.tensor.math
import
pow
as
aet_pow
from
aesara.tensor.math
import
prod
from
aesara.tensor.nlinalg
import
MatrixInverse
,
det
,
matrix_i
nverse
,
trace
from
aesara.tensor.nlinalg
import
Det
,
MatrixI
nverse
,
trace
from
aesara.tensor.slinalg
import
Cholesky
,
Solve
,
cholesky
,
solve
logger
=
logging
.
getLogger
(
__name__
)
class
Hint
(
Op
):
"""
Provide arbitrary information to the optimizer.
These ops are removed from the graph during canonicalization
in order to not interfere with other optimizations.
The idea is that prior to canonicalization, one or more Features of the
fgraph should register the information contained in any Hint node, and
transfer that information out of the graph.
"""
__props__
=
(
"hints"
,)
def
__init__
(
self
,
**
kwargs
):
self
.
hints
=
tuple
(
kwargs
.
items
())
self
.
view_map
=
{
0
:
[
0
]}
def
make_node
(
self
,
x
):
return
Apply
(
self
,
[
x
],
[
x
.
type
()])
def
perform
(
self
,
node
,
inputs
,
outstor
):
outstor
[
0
][
0
]
=
inputs
[
0
]
def
grad
(
self
,
inputs
,
g_out
):
return
g_out
def
hints
(
variable
):
if
variable
.
owner
and
isinstance
(
variable
.
owner
.
op
,
Hint
):
return
dict
(
variable
.
owner
.
op
.
hints
)
else
:
return
{}
@register_canonicalize
@local_optimizer
([
Hint
])
def
remove_hint_nodes
(
fgraph
,
node
):
if
isinstance
(
node
,
Hint
):
# transfer hints from graph to Feature
try
:
for
k
,
v
in
node
.
op
.
hints
:
fgraph
.
hints_feature
.
add_hint
(
node
.
inputs
[
0
],
k
,
v
)
except
AttributeError
:
pass
return
node
.
inputs
class
HintsFeature
:
"""
FunctionGraph Feature to track matrix properties.
This is a similar feature to variable 'tags'. In fact, tags are one way
to provide hints.
This class exists because tags were not documented well, and the
semantics of how tag information should be moved around during
optimizations was never clearly spelled out.
Hints are assumptions about mathematical properties of variables.
If one variable is substituted for another by an optimization,
then it means that the assumptions should be transferred to the
new variable.
Hints are attached to 'positions in a graph' rather than to variables
in particular, although Hints are originally attached to a particular
positition in a graph *via* a variable in that original graph.
Examples of hints are:
- shape information
- matrix properties (e.g. symmetry, psd, banded, diagonal)
Hint information is propagated through the graph similarly to graph
optimizations, except that adding a hint does not change the graph.
Adding a hint is not something that debugmode will check.
#TODO: should a Hint be an object that can actually evaluate its
# truthfulness?
# Should the PSD property be an object that can check the
# PSD-ness of a variable?
"""
def
add_hint
(
self
,
r
,
k
,
v
):
logger
.
debug
(
f
"adding hint; {r}, {k}, {v}"
)
self
.
hints
[
r
][
k
]
=
v
def
ensure_init_r
(
self
,
r
):
if
r
not
in
self
.
hints
:
self
.
hints
[
r
]
=
{}
#
#
# Feature interface
#
#
def
on_attach
(
self
,
fgraph
):
assert
not
hasattr
(
fgraph
,
"hints_feature"
)
fgraph
.
hints_feature
=
self
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
self
.
hints
=
{}
for
node
in
fgraph
.
toposort
():
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_import
(
self
,
fgraph
,
node
,
reason
):
if
node
.
outputs
[
0
]
in
self
.
hints
:
# this is a revert, not really an import
for
r
in
node
.
outputs
+
node
.
inputs
:
assert
r
in
self
.
hints
return
for
i
,
r
in
enumerate
(
node
.
inputs
+
node
.
outputs
):
# make sure we have shapes for the inputs
self
.
ensure_init_r
(
r
)
def
update_second_from_first
(
self
,
r0
,
r1
):
old_hints
=
self
.
hints
[
r0
]
new_hints
=
self
.
hints
[
r1
]
for
k
,
v
in
old_hints
.
items
():
if
k
in
new_hints
and
new_hints
[
k
]
is
not
v
:
raise
NotImplementedError
()
if
k
not
in
new_hints
:
new_hints
[
k
]
=
v
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
):
# TODO:
# This tells us that r and new_r must have the same shape
# if we didn't know that the shapes are related, now we do.
self
.
ensure_init_r
(
new_r
)
self
.
update_second_from_first
(
r
,
new_r
)
self
.
update_second_from_first
(
new_r
,
r
)
# change_input happens in two cases:
# 1) we are trying to get rid of r, or
# 2) we are putting things back after a failed transaction.
class
HintsOptimizer
(
GlobalOptimizer
):
"""
Optimizer that serves to add HintsFeature as an fgraph feature.
"""
def
__init__
(
self
):
super
()
.
__init__
()
def
add_requirements
(
self
,
fgraph
):
fgraph
.
attach_feature
(
HintsFeature
())
def
apply
(
self
,
fgraph
):
pass
# -1 should make it run right before the first merge
aesara
.
compile
.
mode
.
optdb
.
register
(
"HintsOpt"
,
HintsOptimizer
(),
-
1
,
"fast_run"
,
"fast_compile"
)
def
psd
(
v
):
r"""
Apply a hint that the variable `v` is positive semi-definite, i.e.
it is a symmetric matrix and :math:`x^T A x \ge 0` for any vector x.
"""
return
Hint
(
psd
=
True
,
symmetric
=
True
)(
v
)
def
is_psd
(
v
):
return
hints
(
v
)
.
get
(
"psd"
,
False
)
def
is_symmetric
(
v
):
return
hints
(
v
)
.
get
(
"symmetric"
,
False
)
def
is_positive
(
v
):
if
hints
(
v
)
.
get
(
"positive"
,
False
):
return
True
# TODO: how to handle this - a registry?
# infer_hints on Ops?
logger
.
debug
(
f
"is_positive: {v}"
)
if
v
.
owner
and
v
.
owner
.
op
==
aet_pow
:
try
:
exponent
=
aet
.
get_scalar_constant_value
(
v
.
owner
.
inputs
[
1
])
except
NotScalarConstantError
:
return
False
if
0
==
exponent
%
2
:
return
True
return
False
@register_canonicalize
@local_optimizer
([
DimShuffle
])
def
transinv_to_invtrans
(
fgraph
,
node
):
...
...
@@ -229,15 +34,19 @@ def transinv_to_invtrans(fgraph, node):
@register_stabilize
@local_optimizer
([
Dot
,
Dot22
])
def
inv_as_solve
(
fgraph
,
node
):
"""
This utilizes a boolean `symmetric` tag on the matrices.
"""
if
isinstance
(
node
.
op
,
(
Dot
,
Dot22
)):
l
,
r
=
node
.
inputs
if
l
.
owner
and
l
.
owner
.
op
==
matrix_inverse
:
if
l
.
owner
and
isinstance
(
l
.
owner
.
op
,
MatrixInverse
)
:
return
[
solve
(
l
.
owner
.
inputs
[
0
],
r
)]
if
r
.
owner
and
r
.
owner
.
op
==
matrix_inverse
:
if
is_symmetric
(
r
.
owner
.
inputs
[
0
]):
return
[
solve
(
r
.
owner
.
inputs
[
0
],
l
.
T
)
.
T
]
if
r
.
owner
and
isinstance
(
r
.
owner
.
op
,
MatrixInverse
):
x
=
r
.
owner
.
inputs
[
0
]
if
getattr
(
x
.
tag
,
"symmetric"
,
None
)
is
True
:
return
[
solve
(
x
,
l
.
T
)
.
T
]
else
:
return
[
solve
(
r
.
owner
.
inputs
[
0
]
.
T
,
l
.
T
)
.
T
]
return
[
solve
(
x
.
T
,
l
.
T
)
.
T
]
@register_stabilize
...
...
@@ -277,18 +86,20 @@ def tag_solve_triangular(fgraph, node):
def
no_transpose_symmetric
(
fgraph
,
node
):
if
isinstance
(
node
.
op
,
DimShuffle
):
x
=
node
.
inputs
[
0
]
if
x
.
type
.
ndim
==
2
and
is_symmetric
(
x
):
# print 'UNDOING TRANSPOSE', is_symmetric(x), x.ndim
if
x
.
type
.
ndim
==
2
and
getattr
(
x
.
tag
,
"symmetric"
,
None
)
is
True
:
if
node
.
op
.
new_order
==
[
1
,
0
]:
return
[
x
]
@register_stabilize
@local_optimizer
(
None
)
# XXX: solve is defined later and can't be used here
@local_optimizer
(
[
Solve
])
def
psd_solve_with_chol
(
fgraph
,
node
):
"""
This utilizes a boolean `psd` tag on matrices.
"""
if
isinstance
(
node
.
op
,
Solve
):
A
,
b
=
node
.
inputs
# result is solution Ax=b
if
is_psd
(
A
)
:
if
getattr
(
A
.
tag
,
"psd"
,
None
)
is
True
:
L
=
cholesky
(
A
)
# N.B. this can be further reduced to a yet-unwritten cho_solve Op
# __if__ no other Op makes use of the the L matrix during the
...
...
@@ -300,14 +111,14 @@ def psd_solve_with_chol(fgraph, node):
@register_stabilize
@register_specialize
@local_optimizer
(
None
)
# XXX: det is defined later and can't be used here
@local_optimizer
(
[
Det
])
def
local_det_chol
(
fgraph
,
node
):
"""
If we have det(X) and there is already an L=cholesky(X)
floating around, then we can use prod(diag(L)) to get the determinant.
"""
if
node
.
op
==
det
:
if
isinstance
(
node
.
op
,
Det
)
:
(
x
,)
=
node
.
inputs
for
(
cl
,
xpos
)
in
fgraph
.
clients
[
x
]:
if
isinstance
(
cl
.
op
,
Cholesky
):
...
...
@@ -320,6 +131,9 @@ def local_det_chol(fgraph, node):
@register_specialize
@local_optimizer
([
log
])
def
local_log_prod_sqr
(
fgraph
,
node
):
"""
This utilizes a boolean `positive` tag on matrices.
"""
if
node
.
op
==
log
:
(
x
,)
=
node
.
inputs
if
x
.
owner
and
isinstance
(
x
.
owner
.
op
,
Prod
):
...
...
@@ -328,29 +142,16 @@ def local_log_prod_sqr(fgraph, node):
p
=
x
.
owner
.
inputs
[
0
]
# p is the matrix we're reducing with prod
if
is_positive
(
p
)
:
if
getattr
(
p
.
tag
,
"positive"
,
None
)
is
True
:
return
[
log
(
p
)
.
sum
(
axis
=
x
.
owner
.
op
.
axis
)]
# TODO: have a reduction like prod and sum that simply
# returns the sign of the prod multiplication.
@register_canonicalize
@register_stabilize
@register_specialize
@local_optimizer
([
log
])
def
local_log_pow
(
fgraph
,
node
):
if
node
.
op
==
log
:
(
x
,)
=
node
.
inputs
if
x
.
owner
and
x
.
owner
.
op
==
aet_pow
:
base
,
exponent
=
x
.
owner
.
inputs
# TODO: reason to be careful with dtypes?
return
[
exponent
*
log
(
base
)]
# returns the sign of the prod multiplication.
def
spectral_radius_bound
(
X
,
log2_exponent
):
"""
Returns upper bound on the largest eigenvalue of square symmetri
x
matrix X.
Returns upper bound on the largest eigenvalue of square symmetri
c
matrix X.
log2_exponent must be a positive-valued integer. The larger it is, the
slower and tighter the bound. Values up to 5 should usually suffice. The
...
...
tests/sandbox/linalg/test_linalg.py
浏览文件 @
08538340
...
...
@@ -5,19 +5,11 @@ import aesara
from
aesara
import
function
from
aesara
import
tensor
as
aet
from
aesara.configdefaults
import
config
# The one in comment are not tested...
from
aesara.sandbox.linalg.ops
import
Cholesky
# PSD_hint,; op class
from
aesara.sandbox.linalg.ops
import
(
Solve
,
inv_as_solve
,
matrix_inverse
,
solve
,
spectral_radius_bound
,
)
from
aesara.sandbox.linalg.ops
import
inv_as_solve
,
spectral_radius_bound
from
aesara.tensor.elemwise
import
DimShuffle
from
aesara.tensor.math
import
_allclose
from
aesara.tensor.nlinalg
import
MatrixInverse
from
aesara.tensor.nlinalg
import
MatrixInverse
,
matrix_inverse
from
aesara.tensor.slinalg
import
Cholesky
,
Solve
,
solve
from
aesara.tensor.type
import
dmatrix
,
matrix
,
vector
from
tests
import
unittest_tools
as
utt
from
tests.test_rop
import
break_op
...
...
@@ -120,7 +112,7 @@ def test_spectral_radius_bound():
def
test_transinv_to_invtrans
():
X
=
matrix
(
"X"
)
Y
=
aesara
.
tensor
.
nlinalg
.
matrix_inverse
(
X
)
Y
=
matrix_inverse
(
X
)
Z
=
Y
.
transpose
()
f
=
aesara
.
function
([
X
],
Z
)
if
config
.
mode
!=
"FAST_COMPILE"
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论