Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
5a2fb70b
提交
5a2fb70b
authored
12月 03, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
12月 08, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add KanrenRelationSub optimizer
上级
459c570d
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
261 行增加
和
0 行删除
+261
-0
kanren.py
aesara/graph/kanren.py
+93
-0
setup.py
setup.py
+1
-0
test_kanren.py
tests/graph/test_kanren.py
+167
-0
没有找到文件。
aesara/graph/kanren.py
0 → 100644
浏览文件 @
5a2fb70b
from
typing
import
Callable
,
Iterator
,
List
,
Union
from
etuples.core
import
ExpressionTuple
from
kanren
import
run
from
unification
import
var
from
unification.variable
import
Var
from
aesara.graph.basic
import
Apply
,
Variable
from
aesara.graph.opt
import
LocalOptimizer
from
aesara.graph.unify
import
eval_if_etuple
class
KanrenRelationSub
(
LocalOptimizer
):
r"""A local optimizer that uses `kanren` to match and replace terms.
See `kanren <https://github.com/pythological/kanren>`__ for more information
miniKanren and the API for constructing `kanren` goals.
Example
-------
..code-block:: python
from kanren import eq, conso, var
import aesara.tensor as at
from aesara.graph.kanren import KanrenRelationSub
def relation(in_lv, out_lv):
# A `kanren` goal that changes `at.log` terms to `at.exp`
cdr_lv = var()
return eq(conso(at.log, cdr_lv, in_lv),
conso(at.exp, cdr_lv, out_lv))
kanren_sub_opt = KanrenRelationSub(relation)
"""
reentrant
=
True
def
__init__
(
self
,
kanren_relation
:
Callable
[[
Variable
,
Var
],
Callable
],
results_filter
:
Callable
[
[
Iterator
],
List
[
Union
[
ExpressionTuple
,
Variable
]]
]
=
lambda
x
:
next
(
x
,
None
),
node_filter
:
Callable
[[
Apply
],
bool
]
=
lambda
x
:
True
,
):
r"""Create a `KanrenRelationSub`.
Parameters
----------
kanren_relation
A function that takes an input graph and an output logic variable and
returns a `kanren` goal.
results_filter
A function that takes the direct output of `kanren.run(None, ...)`
and returns a single result. The default implementation returns
the first result.
node_filter
A function taking a single node and returns ``True`` when the node
should be processed.
"""
self
.
kanren_relation
=
kanren_relation
self
.
results_filter
=
results_filter
self
.
node_filter
=
node_filter
super
()
.
__init__
()
def
transform
(
self
,
fgraph
,
node
):
if
self
.
node_filter
(
node
)
is
False
:
return
False
try
:
input_expr
=
node
.
default_output
()
except
ValueError
:
input_expr
=
node
.
outputs
q
=
var
()
kanren_results
=
run
(
None
,
q
,
self
.
kanren_relation
(
input_expr
,
q
))
chosen_res
=
self
.
results_filter
(
kanren_results
)
if
chosen_res
:
if
isinstance
(
chosen_res
,
list
):
new_outputs
=
[
eval_if_etuple
(
v
)
for
v
in
chosen_res
]
else
:
new_outputs
=
[
eval_if_etuple
(
chosen_res
)]
return
new_outputs
else
:
return
False
setup.py
浏览文件 @
5a2fb70b
...
...
@@ -51,6 +51,7 @@ install_requires = [
"filelock"
,
"etuples"
,
"logical-unification"
,
"miniKanren"
,
"cons"
,
]
...
...
tests/graph/test_kanren.py
0 → 100644
浏览文件 @
5a2fb70b
from
copy
import
copy
import
numpy
as
np
import
pytest
from
etuples
import
etuple
from
kanren
import
eq
,
fact
,
run
from
kanren.assoccomm
import
associative
,
commutative
,
eq_assoccomm
from
kanren.core
import
lall
from
unification
import
var
,
vars
import
aesara.tensor
as
at
from
aesara.graph.basic
import
Apply
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.kanren
import
KanrenRelationSub
from
aesara.graph.op
import
Op
from
aesara.graph.opt
import
EquilibriumOptimizer
from
aesara.graph.opt_utils
import
optimize_graph
from
aesara.graph.unify
import
eval_if_etuple
from
aesara.tensor.math
import
Dot
,
_dot
from
tests.graph.utils
import
MyType
,
MyVariable
@pytest.fixture
(
autouse
=
True
)
def
clear_assoccomm
():
old_commutative_index
=
copy
(
commutative
.
index
)
old_commutative_facts
=
copy
(
commutative
.
facts
)
old_associative_index
=
copy
(
associative
.
index
)
old_associative_facts
=
copy
(
associative
.
facts
)
try
:
yield
finally
:
commutative
.
index
=
old_commutative_index
commutative
.
facts
=
old_commutative_facts
associative
.
index
=
old_associative_index
associative
.
facts
=
old_associative_facts
def
test_kanren_basic
():
A_at
=
at
.
matrix
(
"A"
)
x_at
=
at
.
vector
(
"x"
)
y_at
=
at
.
dot
(
A_at
,
x_at
)
q
=
var
()
res
=
list
(
run
(
None
,
q
,
eq
(
y_at
,
etuple
(
_dot
,
q
,
x_at
))))
assert
res
==
[
A_at
]
def
test_KanrenRelationSub_filters
():
x_at
=
at
.
vector
(
"x"
)
y_at
=
at
.
vector
(
"y"
)
z_at
=
at
.
vector
(
"z"
)
A_at
=
at
.
matrix
(
"A"
)
fact
(
commutative
,
_dot
)
fact
(
commutative
,
at
.
add
)
fact
(
associative
,
at
.
add
)
Z_at
=
A_at
.
dot
((
x_at
+
y_at
)
+
z_at
)
fgraph
=
FunctionGraph
(
outputs
=
[
Z_at
],
clone
=
False
)
def
distributes
(
in_lv
,
out_lv
):
A_lv
,
x_lv
,
y_lv
,
z_lv
=
vars
(
4
)
return
lall
(
# lhs == A * (x + y + z)
eq_assoccomm
(
etuple
(
_dot
,
A_lv
,
etuple
(
at
.
add
,
x_lv
,
etuple
(
at
.
add
,
y_lv
,
z_lv
))),
in_lv
,
),
# This relation does nothing but provide us with a means of
# generating associative-commutative matches in the `kanren`
# output.
eq
((
A_lv
,
x_lv
,
y_lv
,
z_lv
),
out_lv
),
)
def
results_filter
(
results
):
_results
=
[
eval_if_etuple
(
v
)
for
v
in
results
]
# Make sure that at least a couple permutations are present
assert
(
A_at
,
x_at
,
y_at
,
z_at
)
in
_results
assert
(
A_at
,
y_at
,
x_at
,
z_at
)
in
_results
assert
(
A_at
,
z_at
,
x_at
,
y_at
)
in
_results
return
None
_
=
KanrenRelationSub
(
distributes
,
results_filter
=
results_filter
)
.
transform
(
fgraph
,
fgraph
.
outputs
[
0
]
.
owner
)
res
=
KanrenRelationSub
(
distributes
,
node_filter
=
lambda
x
:
False
)
.
transform
(
fgraph
,
fgraph
.
outputs
[
0
]
.
owner
)
assert
res
is
False
def
test_KanrenRelationSub_multiout
():
class
MyMultiOutOp
(
Op
):
def
make_node
(
self
,
*
inputs
):
outputs
=
[
MyType
()(),
MyType
()()]
return
Apply
(
self
,
list
(
inputs
),
outputs
)
def
perform
(
self
,
node
,
inputs
,
outputs
):
outputs
[
0
]
=
np
.
array
(
inputs
[
0
])
outputs
[
1
]
=
np
.
array
(
inputs
[
0
])
x
=
MyVariable
(
"x"
)
y
=
MyVariable
(
"y"
)
multi_op
=
MyMultiOutOp
()
o1
,
o2
=
multi_op
(
x
,
y
)
fgraph
=
FunctionGraph
([
x
,
y
],
[
o1
],
clone
=
False
)
def
relation
(
in_lv
,
out_lv
):
return
eq
(
in_lv
,
out_lv
)
res
=
KanrenRelationSub
(
relation
)
.
transform
(
fgraph
,
fgraph
.
outputs
[
0
]
.
owner
)
assert
res
==
[
o1
,
o2
]
def
test_KanrenRelationSub_dot
():
"""Make sure we can run miniKanren "optimizations" over a graph until a fixed-point/normal-form is reached."""
x_at
=
at
.
vector
(
"x"
)
c_at
=
at
.
vector
(
"c"
)
d_at
=
at
.
vector
(
"d"
)
A_at
=
at
.
matrix
(
"A"
)
B_at
=
at
.
matrix
(
"B"
)
Z_at
=
A_at
.
dot
(
x_at
+
B_at
.
dot
(
c_at
+
d_at
))
fgraph
=
FunctionGraph
(
outputs
=
[
Z_at
],
clone
=
False
)
assert
isinstance
(
fgraph
.
outputs
[
0
]
.
owner
.
op
,
Dot
)
def
distributes
(
in_lv
,
out_lv
):
return
lall
(
# lhs == A * (x + b)
eq
(
etuple
(
_dot
,
var
(
"A"
),
etuple
(
at
.
add
,
var
(
"x"
),
var
(
"b"
))),
in_lv
,
),
# rhs == A * x + A * b
eq
(
etuple
(
at
.
add
,
etuple
(
_dot
,
var
(
"A"
),
var
(
"x"
)),
etuple
(
_dot
,
var
(
"A"
),
var
(
"b"
)),
),
out_lv
,
),
)
distribute_opt
=
EquilibriumOptimizer
(
[
KanrenRelationSub
(
distributes
)],
max_use_ratio
=
10
)
fgraph_opt
=
optimize_graph
(
fgraph
,
custom_opt
=
distribute_opt
)
(
expr_opt
,)
=
fgraph_opt
.
outputs
assert
expr_opt
.
owner
.
op
==
at
.
add
assert
isinstance
(
expr_opt
.
owner
.
inputs
[
0
]
.
owner
.
op
,
Dot
)
assert
fgraph_opt
.
inputs
[
0
]
is
A_at
assert
expr_opt
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
0
]
.
name
==
"A"
assert
expr_opt
.
owner
.
inputs
[
1
]
.
owner
.
op
==
at
.
add
assert
isinstance
(
expr_opt
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
.
op
,
Dot
)
assert
isinstance
(
expr_opt
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
1
]
.
owner
.
op
,
Dot
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论