Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
321c0108
提交
321c0108
authored
1月 29, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
2月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove unused aesara.scan.utils.map_variables
上级
97f9ad48
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
1 行增加
和
354 行删除
+1
-354
utils.py
aesara/scan/utils.py
+0
-213
test_utils.py
tests/scan/test_utils.py
+1
-141
没有找到文件。
aesara/scan/utils.py
浏览文件 @
321c0108
...
@@ -20,9 +20,7 @@ from aesara.graph.basic import (
...
@@ -20,9 +20,7 @@ from aesara.graph.basic import (
equal_computations
,
equal_computations
,
graph_inputs
,
graph_inputs
,
)
)
from
aesara.graph.fg
import
FunctionGraph
from
aesara.graph.op
import
get_test_value
from
aesara.graph.op
import
get_test_value
from
aesara.graph.opt
import
TopoOptimizer
,
local_optimizer
from
aesara.graph.utils
import
TestValueError
from
aesara.graph.utils
import
TestValueError
from
aesara.tensor.basic
import
AllocEmpty
,
get_scalar_constant_value
from
aesara.tensor.basic
import
AllocEmpty
,
get_scalar_constant_value
from
aesara.tensor.subtensor
import
set_subtensor
from
aesara.tensor.subtensor
import
set_subtensor
...
@@ -229,217 +227,6 @@ def traverse(out, x, x_copy, d, visited=None):
...
@@ -229,217 +227,6 @@ def traverse(out, x, x_copy, d, visited=None):
return
d
return
d
def
map_variables
(
replacer
,
graphs
,
additional_inputs
=
None
):
"""Construct new graphs based on 'graphs' with some variables replaced
according to 'replacer'.
:param replacer: function that takes a variable and returns its
replacement.
:param graphs: an iterable of graphs in which to replace variables
:param additional_inputs: an iterable of graph inputs not used in any
of 'graphs' but possibly used in the graphs returned by `replacer`
:return: the new graphs, in the same order as 'graphs'
Example:
.. code-block:: python
tag = "replaceme"
a = aesara.tensor.type.scalar("a")
b = aesara.tensor.type.scalar("b")
c = aesara.tensor.type.scalar("c")
ab = a + b
ab.tag.replacement = a * b
u = ab + c
v, = map_variables(lambda graph:
return getattr(graph.tag, "replacement", graph),
[u])
# v is now equal to a * b + c
"""
if
additional_inputs
is
None
:
additional_inputs
=
[]
# wrap replacer to avoid replacing things we just put there.
graphs_seen
=
set
()
def
wrapped_replacer
(
graph
):
if
graph
in
graphs_seen
:
return
graph
else
:
new_graph
=
replacer
(
graph
)
graphs_seen
.
add
(
new_graph
)
return
new_graph
graphs
=
list
(
graphs
)
inputs_
=
list
(
set
(
list
(
graph_inputs
(
graphs
))
+
list
(
additional_inputs
)))
# perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are
# not outputs of any Apply node.
new_inputs
=
[
wrapped_replacer
(
i
)
for
i
in
inputs_
]
replacements
=
[
(
input_
,
new_input
)
for
input_
,
new_input
in
zip
(
inputs_
,
new_inputs
)
if
new_input
is
not
input_
]
graphs
=
clone_replace
(
graphs
,
share_inputs
=
True
,
replace
=
replacements
)
inputs_
=
list
(
set
(
list
(
graph_inputs
(
graphs
))
+
list
(
additional_inputs
)))
fg
=
FunctionGraph
(
inputs_
,
graphs
,
clone
=
False
)
nodes_seen
=
set
()
@local_optimizer
(
None
)
def
local_transform
(
fgraph
,
node
):
if
node
in
nodes_seen
:
return
False
# importing Scan into module scope would be circular
from
aesara.compile.builders
import
OpFromGraph
from
aesara.scan.op
import
Scan
if
isinstance
(
node
.
op
,
(
Scan
,
OpFromGraph
)):
# recurse on the inner graph
(
new_inner_inputs
,
new_outer_inputs
,
new_inner_outputs
,
)
=
_map_variables_inner
(
wrapped_replacer
,
inner_inputs
=
node
.
op
.
inputs
,
outer_inputs
=
node
.
inputs
,
inner_outputs
=
node
.
op
.
outputs
,
containing_op
=
node
.
op
,
)
# reinstantiate the op
if
isinstance
(
node
.
op
,
Scan
):
new_op
=
Scan
(
new_inner_inputs
,
new_inner_outputs
,
node
.
op
.
info
,
node
.
op
.
mode
,
# FIXME: infer this someday?
typeConstructor
=
None
,
)
elif
isinstance
(
node
.
op
,
OpFromGraph
):
new_op
=
OpFromGraph
(
new_inner_inputs
,
new_inner_outputs
,
**
node
.
op
.
kwargs
)
# make a new node to replace the old one
new_node
=
new_op
.
make_node
(
*
new_outer_inputs
)
nodes_seen
.
add
(
new_node
)
return
new_node
.
outputs
else
:
nodes_seen
.
add
(
node
)
replacements
=
[
wrapped_replacer
(
o
)
for
o
in
node
.
outputs
]
# Add inputs to replacement graphs as inputs to this `fgraph`
for
i
in
graph_inputs
(
replacements
):
fgraph
.
add_input
(
i
)
return
replacements
topo_transform
=
TopoOptimizer
(
local_transform
,
"out_to_in"
)
topo_transform
.
optimize
(
fg
)
new_graphs
=
fg
.
outputs
fg
.
disown
()
return
new_graphs
def
_map_variables_inner
(
replacer
,
inner_inputs
,
outer_inputs
,
inner_outputs
,
containing_op
):
# the replacements returned by the replacer may involve variables
# that are already owned by the outer fgraph (`fg` in the caller)
# and so cannot be added to the inner fgraph (`fg` in the
# recursive call). wrap the replacer to catch these before they
# are added.
# additionally, some of these may be fgraph inputs or shared
# variables, which we cannot directly use inside the inner graph.
# we need to create inner inputs to access them through.
outer_to_inner
=
dict
(
zip
(
outer_inputs
,
inner_inputs
))
extra_inner_inputs
=
[]
extra_outer_inputs
=
[]
from
itertools
import
chain
from
aesara.scan
import
utils
def
inner_replacer
(
graph
):
new_graph
=
replacer
(
graph
)
other_inputs
=
[]
constants
=
[]
for
input_
in
graph_inputs
([
new_graph
]):
if
isinstance
(
input_
,
Variable
):
if
isinstance
(
input_
,
Constant
):
constants
.
append
(
input_
)
else
:
other_inputs
.
append
(
input_
)
# foreign inputs are fgraph inputs and shared variables that we need
# to access through inner inputs
foreign_inputs
=
list
(
set
(
other_inputs
)
-
set
(
outer_to_inner
.
values
()))
# skip further processing if there is nothing to do
if
not
constants
and
not
foreign_inputs
:
return
new_graph
replacements
=
[]
# constants just need to be replaced by copies that the inner
# `fg` can take ownership of
for
input_
in
constants
:
new_input
=
input_
.
clone
()
new_input
.
name
=
f
"{new_input.name}_copied"
replacements
.
append
((
input_
,
new_input
))
for
outer_input
in
foreign_inputs
:
if
getattr
(
outer_input
,
"update"
,
False
):
# when aesara.scan() constructs a scan node, it detects
# shared variables with updates and returns these updates
# to the user. we need to do the same thing for every new
# use of such a variable that is introduced. it's hard to
# do that at this point.
# shared variables with updates inside the inner graph of
# OpFromGraph are not supported at all, so we don't support
# introducing those either.
raise
NotImplementedError
(
f
"Replacement introduces shared variable {outer_input} "
"which has an update associated with it into "
f
"the inner graph of {containing_op}. This is not currently "
"supported."
)
# if this foreign input is not already available
# as an inner input, connect it through a new
# inner input
if
outer_input
not
in
outer_to_inner
.
keys
():
inner_input
=
utils
.
safe_new
(
outer_input
,
tag
=
"_copy"
)
outer_to_inner
[
outer_input
]
=
inner_input
extra_inner_inputs
.
append
(
inner_input
)
extra_outer_inputs
.
append
(
outer_input
)
replacements
.
extend
(
outer_to_inner
.
items
())
(
new_graph
,)
=
clone_replace
(
[
new_graph
],
share_inputs
=
True
,
replace
=
replacements
)
return
new_graph
new_inner_outputs
=
map_variables
(
inner_replacer
,
inner_outputs
)
new_inner_inputs
=
list
(
chain
(
inner_inputs
,
extra_inner_inputs
))
new_outer_inputs
=
list
(
chain
(
outer_inputs
,
extra_outer_inputs
))
return
new_inner_inputs
,
new_outer_inputs
,
new_inner_outputs
def
get_updates_and_outputs
(
ls
):
def
get_updates_and_outputs
(
ls
):
"""
"""
This function tries to recognize the updates OrderedDict, the
This function tries to recognize the updates OrderedDict, the
...
...
tests/scan/test_utils.py
浏览文件 @
321c0108
import
itertools
from
copy
import
copy
from
copy
import
copy
import
numpy
as
np
import
numpy
as
np
...
@@ -6,8 +5,7 @@ import pytest
...
@@ -6,8 +5,7 @@ import pytest
import
aesara
import
aesara
from
aesara
import
tensor
as
at
from
aesara
import
tensor
as
at
from
aesara.scan.utils
import
ScanArgs
,
map_variables
from
aesara.scan.utils
import
ScanArgs
from
aesara.tensor.type
import
scalar
,
vector
@pytest.fixture
(
scope
=
"module"
,
autouse
=
True
)
@pytest.fixture
(
scope
=
"module"
,
autouse
=
True
)
...
@@ -16,144 +14,6 @@ def set_aesara_flags():
...
@@ -16,144 +14,6 @@ def set_aesara_flags():
yield
yield
class
TestMapVariables
:
@staticmethod
def
replacer
(
graph
):
return
getattr
(
graph
.
tag
,
"replacement"
,
graph
)
def
test_leaf
(
self
):
a
=
scalar
(
"a"
)
b
=
scalar
(
"b"
)
c
=
scalar
(
"c"
)
b
.
tag
.
replacement
=
c
u
=
a
+
b
(
v
,)
=
map_variables
(
self
.
replacer
,
[
u
])
assert
u
.
owner
.
inputs
==
[
a
,
b
]
assert
v
.
owner
.
inputs
==
[
a
,
c
]
def
test_leaf_inside_scan
(
self
):
x
=
vector
(
"x"
)
y
=
scalar
(
"y"
)
z
=
scalar
(
"z"
)
y
.
tag
.
replacement
=
z
s
,
_
=
aesara
.
scan
(
lambda
x
:
x
*
y
,
sequences
=
x
)
(
s2
,)
=
map_variables
(
self
.
replacer
,
[
s
])
f
=
aesara
.
function
([
x
,
y
,
z
],
[
s
,
s2
])
rval
=
f
(
x
=
np
.
array
([
1
,
2
,
3
],
dtype
=
np
.
float32
),
y
=
1
,
z
=
2
)
assert
np
.
array_equal
(
rval
,
[[
1
,
2
,
3
],
[
2
,
4
,
6
]])
def
test_scan
(
self
):
x
=
vector
(
"x"
)
# we will insert a subgraph involving these variables into the inner
# graph of scan. since they were not previously in the inner graph,
# they are like non_sequences to scan(). scan() infers these and
# imports them into the inner graph properly, and map_variables()
# should do this as well.
outer
=
scalar
(
"outer"
)
shared
=
aesara
.
shared
(
np
.
array
(
1.0
,
dtype
=
aesara
.
config
.
floatX
),
name
=
"shared"
)
constant
=
at
.
constant
(
1
,
name
=
"constant"
)
# z will equal 1 so multiplying by it doesn't change any values
z
=
outer
*
(
shared
+
constant
)
def
step
(
x
,
a
):
r
=
a
+
x
r
.
tag
.
replacement
=
z
*
(
a
-
x
)
return
r
s
,
_
=
aesara
.
scan
(
step
,
sequences
=
x
,
outputs_info
=
[
np
.
array
(
0.0
)])
# ensure z is owned by the outer graph so map_variables() will need to
# jump through additional hoops to placate FunctionGraph.
t
=
z
*
s
(
s2
,)
=
map_variables
(
self
.
replacer
,
[
t
])
t2
=
z
*
s2
f
=
aesara
.
function
([
x
,
outer
],
[
t
,
t2
])
rval
=
f
(
x
=
np
.
array
([
1
,
2
,
3
],
dtype
=
np
.
float32
),
outer
=
0.5
)
assert
np
.
array_equal
(
rval
,
[[
1
,
3
,
6
],
[
-
1
,
-
3
,
-
6
]])
def
test_scan_with_shared_update
(
self
):
x
=
vector
(
"x"
)
# counts how many times its value is used
counter
=
aesara
.
shared
(
0
,
name
=
"shared"
)
counter
.
update
=
counter
+
1
def
step
(
x
,
a
):
r
=
a
+
x
# introducing a shared variable with an update into the
# inner graph is unsupported and the code must crash rather
# than silently produce the wrong result.
r
.
tag
.
replacement
=
counter
*
(
a
-
x
)
return
r
s
,
_
=
aesara
.
scan
(
step
,
sequences
=
x
,
outputs_info
=
[
np
.
array
(
0.0
)])
with
pytest
.
raises
(
NotImplementedError
):
map_variables
(
self
.
replacer
,
[
s
])
def
test_scan_with_shared_update2
(
self
):
x
=
vector
(
"x"
)
# counts how many times its value is used
counter
=
aesara
.
shared
(
0
,
name
=
"shared"
)
counter
.
update
=
counter
+
1
def
step
(
x
,
a
):
r
=
a
+
x
# introducing a shared variable with an update into the
# inner graph is unsupported and the code must crash rather
# than silently produce the wrong result.
r
.
tag
.
replacement
=
counter
*
(
a
-
x
)
# the shared variable was already present, but the
# replacement changes the number of times it is used,
# which would have to change the updates, which is
# unsupported.
return
r
+
counter
s
,
_
=
aesara
.
scan
(
step
,
sequences
=
x
,
outputs_info
=
[
np
.
array
(
0.0
)])
with
pytest
.
raises
(
NotImplementedError
):
map_variables
(
self
.
replacer
,
[
s
])
def
test_opfromgraph
(
self
):
# as with the scan tests above, insert foreign inputs into the
# inner graph.
outer
=
scalar
(
"outer"
)
shared
=
aesara
.
shared
(
np
.
array
(
1.0
,
dtype
=
aesara
.
config
.
floatX
),
name
=
"shared"
)
constant
=
at
.
constant
(
1.0
,
name
=
"constant"
)
z
=
outer
*
(
shared
+
constant
)
# construct the inner graph
a
=
scalar
()
b
=
scalar
()
r
=
a
+
b
r
.
tag
.
replacement
=
z
*
(
a
-
b
)
# construct the outer graph
c
=
scalar
()
d
=
scalar
()
u
=
aesara
.
compile
.
builders
.
OpFromGraph
([
a
,
b
],
[
r
])(
c
,
d
)
t
=
z
*
u
(
v
,)
=
map_variables
(
self
.
replacer
,
[
t
])
t2
=
z
*
v
f
=
aesara
.
function
([
c
,
d
,
outer
],
[
t
,
t2
])
for
m
,
n
in
itertools
.
combinations
(
range
(
10
),
2
):
assert
f
(
m
,
n
,
outer
=
0.5
)
==
[
m
+
n
,
m
-
n
]
# test that the unsupported case of replacement with a shared
# variable with updates crashes
shared
.
update
=
shared
+
1
with
pytest
.
raises
(
NotImplementedError
):
map_variables
(
self
.
replacer
,
[
t
])
def
create_test_hmm
():
def
create_test_hmm
():
rng_state
=
np
.
random
.
default_rng
(
23422
)
rng_state
=
np
.
random
.
default_rng
(
23422
)
rng_tt
=
aesara
.
shared
(
rng_state
,
name
=
"rng"
,
borrow
=
True
)
rng_tt
=
aesara
.
shared
(
rng_state
,
name
=
"rng"
,
borrow
=
True
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论