Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
b0a1d33c
提交
b0a1d33c
authored
4月 10, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
4月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add view_map and destroy_map variables and docstrings to Op
上级
cae78759
隐藏空白字符变更
内嵌
并排
正在显示
15 个修改的文件
包含
81 行增加
和
64 行删除
+81
-64
debugmode.py
aesara/compile/debugmode.py
+14
-14
types.py
aesara/compile/function/types.py
+5
-5
profiling.py
aesara/compile/profiling.py
+6
-6
formatting.py
aesara/d3viz/formatting.py
+2
-2
destroyhandler.py
aesara/graph/destroyhandler.py
+10
-10
fg.py
aesara/graph/fg.py
+2
-2
op.py
aesara/graph/op.py
+24
-1
opt.py
aesara/graph/opt.py
+1
-1
toolbox.py
aesara/graph/toolbox.py
+1
-1
utils.py
aesara/link/utils.py
+7
-7
vm.py
aesara/link/vm.py
+3
-5
printing.py
aesara/printing.py
+2
-2
op.py
aesara/scan/op.py
+1
-5
opt.py
aesara/scan/opt.py
+1
-1
test_mode.py
tests/compile/test_mode.py
+2
-2
没有找到文件。
aesara/compile/debugmode.py
浏览文件 @
b0a1d33c
...
@@ -166,7 +166,7 @@ class BadDestroyMap(DebugModeError):
...
@@ -166,7 +166,7 @@ class BadDestroyMap(DebugModeError):
print
(
" node:"
,
self
.
node
,
file
=
sio
)
print
(
" node:"
,
self
.
node
,
file
=
sio
)
print
(
" perform:"
,
self
.
perform
,
file
=
sio
)
print
(
" perform:"
,
self
.
perform
,
file
=
sio
)
print
(
" node.inputs:"
,
[(
str
(
i
),
id
(
i
))
for
i
in
self
.
node
.
inputs
],
file
=
sio
)
print
(
" node.inputs:"
,
[(
str
(
i
),
id
(
i
))
for
i
in
self
.
node
.
inputs
],
file
=
sio
)
print
(
" destroy_map:"
,
getattr
(
self
.
node
.
op
,
"destroy_map"
,
{})
,
file
=
sio
)
print
(
" destroy_map:"
,
self
.
node
.
op
.
destroy_map
,
file
=
sio
)
print
(
" changed input idx:"
,
self
.
idx
,
file
=
sio
)
print
(
" changed input idx:"
,
self
.
idx
,
file
=
sio
)
print
(
" changed input type:"
,
self
.
node
.
inputs
[
self
.
idx
]
.
type
,
file
=
sio
)
print
(
" changed input type:"
,
self
.
node
.
inputs
[
self
.
idx
]
.
type
,
file
=
sio
)
print
(
" repr (old val):"
,
repr
(
self
.
old_val
),
file
=
sio
)
print
(
" repr (old val):"
,
repr
(
self
.
old_val
),
file
=
sio
)
...
@@ -250,8 +250,8 @@ class BadViewMap(DebugModeError):
...
@@ -250,8 +250,8 @@ class BadViewMap(DebugModeError):
print
(
" node:"
,
self
.
node
,
file
=
sio
)
print
(
" node:"
,
self
.
node
,
file
=
sio
)
print
(
" node.inputs:"
,
[(
str
(
i
),
id
(
i
))
for
i
in
self
.
node
.
inputs
],
file
=
sio
)
print
(
" node.inputs:"
,
[(
str
(
i
),
id
(
i
))
for
i
in
self
.
node
.
inputs
],
file
=
sio
)
print
(
" node.outputs:"
,
[(
str
(
i
),
id
(
i
))
for
i
in
self
.
node
.
outputs
],
file
=
sio
)
print
(
" node.outputs:"
,
[(
str
(
i
),
id
(
i
))
for
i
in
self
.
node
.
outputs
],
file
=
sio
)
print
(
" view_map:"
,
getattr
(
self
.
node
.
op
,
"view_map"
,
{})
,
file
=
sio
)
print
(
" view_map:"
,
self
.
node
.
op
.
view_map
,
file
=
sio
)
print
(
" destroy_map:"
,
getattr
(
self
.
node
.
op
,
"destroy_map"
,
{})
,
file
=
sio
)
print
(
" destroy_map:"
,
self
.
node
.
op
.
destroy_map
,
file
=
sio
)
print
(
" aliased output:"
,
self
.
output_idx
,
file
=
sio
)
print
(
" aliased output:"
,
self
.
output_idx
,
file
=
sio
)
print
(
" aliased output storage:"
,
self
.
out_storage
,
file
=
sio
)
print
(
" aliased output storage:"
,
self
.
out_storage
,
file
=
sio
)
if
self
.
in_alias_idx
:
if
self
.
in_alias_idx
:
...
@@ -554,12 +554,12 @@ def debugprint(
...
@@ -554,12 +554,12 @@ def debugprint(
r_name
=
""
r_name
=
""
if
print_destroy_map
:
if
print_destroy_map
:
destroy_map_str
=
str
(
getattr
(
r
.
owner
.
op
,
"destroy_map"
,
""
)
)
destroy_map_str
=
str
(
r
.
owner
.
op
.
destroy_map
)
else
:
else
:
destroy_map_str
=
""
destroy_map_str
=
""
if
print_view_map
:
if
print_view_map
:
view_map_str
=
str
(
getattr
(
r
.
owner
.
op
,
"view_map"
,
""
)
)
view_map_str
=
str
(
r
.
owner
.
op
.
view_map
)
else
:
else
:
view_map_str
=
""
view_map_str
=
""
if
destroy_map_str
and
destroy_map_str
!=
"{}"
:
if
destroy_map_str
and
destroy_map_str
!=
"{}"
:
...
@@ -742,13 +742,13 @@ def _check_inputs(
...
@@ -742,13 +742,13 @@ def _check_inputs(
"""
"""
destroyed_idx_list
=
[]
destroyed_idx_list
=
[]
destroy_map
=
getattr
(
node
.
op
,
"destroy_map"
,
{})
destroy_map
=
node
.
op
.
destroy_map
for
o_pos
,
i_pos_list
in
destroy_map
.
items
():
for
o_pos
,
i_pos_list
in
destroy_map
.
items
():
destroyed_idx_list
.
extend
(
i_pos_list
)
destroyed_idx_list
.
extend
(
i_pos_list
)
destroyed_res_list
=
[
node
.
inputs
[
i
]
for
i
in
destroyed_idx_list
]
destroyed_res_list
=
[
node
.
inputs
[
i
]
for
i
in
destroyed_idx_list
]
actually_inplace_outputs
=
[]
actually_inplace_outputs
=
[]
dmap
=
getattr
(
node
.
op
,
"destroy_map"
,
{})
dmap
=
node
.
op
.
destroy_map
for
oo
,
ii
in
dmap
.
items
():
for
oo
,
ii
in
dmap
.
items
():
var
=
node
.
outputs
[
oo
]
var
=
node
.
outputs
[
oo
]
out_var
=
storage_map
[
var
][
0
]
out_var
=
storage_map
[
var
][
0
]
...
@@ -769,7 +769,7 @@ def _check_inputs(
...
@@ -769,7 +769,7 @@ def _check_inputs(
f
"as destroyed was not changed for node '{node}'"
f
"as destroyed was not changed for node '{node}'"
)
)
vmap
=
getattr
(
node
.
op
,
"view_map"
,
{})
vmap
=
node
.
op
.
view_map
for
oo
,
ii
in
vmap
.
items
():
for
oo
,
ii
in
vmap
.
items
():
var
=
node
.
outputs
[
oo
]
var
=
node
.
outputs
[
oo
]
out_var
=
storage_map
[
var
][
0
]
out_var
=
storage_map
[
var
][
0
]
...
@@ -836,8 +836,8 @@ def _check_viewmap(fgraph, node, storage_map):
...
@@ -836,8 +836,8 @@ def _check_viewmap(fgraph, node, storage_map):
outstorage
=
storage_map
[
onode
][
0
]
outstorage
=
storage_map
[
onode
][
0
]
# first find out which input it aliases
# first find out which input it aliases
view_map
=
getattr
(
node
.
op
,
"view_map"
,
{})
view_map
=
node
.
op
.
view_map
destroy_map
=
getattr
(
node
.
op
,
"destroy_map"
,
{})
destroy_map
=
node
.
op
.
destroy_map
# In theory, aesara's view_map only allows for 1 output to
# In theory, aesara's view_map only allows for 1 output to
# alias 1 input. Checking for multiple aliases just in
# alias 1 input. Checking for multiple aliases just in
...
@@ -1395,8 +1395,8 @@ def _check_preallocated_output(
...
@@ -1395,8 +1395,8 @@ def _check_preallocated_output(
# Set of inputs that are marked as destroyed or viewed
# Set of inputs that are marked as destroyed or viewed
aliased_inputs
=
set
()
aliased_inputs
=
set
()
dmap
=
getattr
(
node
.
op
,
"destroy_map"
,
{})
dmap
=
node
.
op
.
destroy_map
vmap
=
getattr
(
node
.
op
,
"view_map"
,
{})
vmap
=
node
.
op
.
view_map
for
i
,
r
in
enumerate
(
node
.
inputs
):
for
i
,
r
in
enumerate
(
node
.
inputs
):
if
any
(
i
in
v
for
v
in
chain
(
dmap
.
values
(),
vmap
.
values
())):
if
any
(
i
in
v
for
v
in
chain
(
dmap
.
values
(),
vmap
.
values
())):
aliased_inputs
.
add
(
r
)
aliased_inputs
.
add
(
r
)
...
@@ -2082,8 +2082,8 @@ class _Linker(LocalLinker):
...
@@ -2082,8 +2082,8 @@ class _Linker(LocalLinker):
clobber
=
True
clobber
=
True
if
thunk_py
:
if
thunk_py
:
dmap
=
getattr
(
node
.
op
,
"destroy_map"
,
{})
dmap
=
node
.
op
.
destroy_map
vmap
=
getattr
(
node
.
op
,
"view_map"
,
{})
vmap
=
node
.
op
.
view_map
for
i
,
r
in
enumerate
(
node
.
inputs
):
for
i
,
r
in
enumerate
(
node
.
inputs
):
# if thunk_py ran, and we still got
# if thunk_py ran, and we still got
# this far, it means that the
# this far, it means that the
...
...
aesara/compile/function/types.py
浏览文件 @
b0a1d33c
...
@@ -57,8 +57,8 @@ def alias_root(v):
...
@@ -57,8 +57,8 @@ def alias_root(v):
"""
"""
if
v
.
owner
is
None
:
if
v
.
owner
is
None
:
return
v
return
v
vmap
=
getattr
(
v
.
owner
.
op
,
"view_map"
,
{})
vmap
=
v
.
owner
.
op
.
view_map
dmap
=
getattr
(
v
.
owner
.
op
,
"destroy_map"
,
{})
dmap
=
v
.
owner
.
op
.
destroy_map
outpos
=
v
.
owner
.
outputs
.
index
(
v
)
outpos
=
v
.
owner
.
outputs
.
index
(
v
)
v_views
=
vmap
.
get
(
outpos
,
[])
+
dmap
.
get
(
outpos
,
[])
v_views
=
vmap
.
get
(
outpos
,
[])
+
dmap
.
get
(
outpos
,
[])
if
len
(
v_views
)
>
1
:
if
len
(
v_views
)
>
1
:
...
@@ -83,8 +83,8 @@ def view_tree_set(fgraph, v, treeset):
...
@@ -83,8 +83,8 @@ def view_tree_set(fgraph, v, treeset):
for
cl
,
v_input_pos_to_cl
in
fgraph
.
clients
[
v
]:
for
cl
,
v_input_pos_to_cl
in
fgraph
.
clients
[
v
]:
if
cl
==
"output"
:
if
cl
==
"output"
:
continue
continue
vmap
=
getattr
(
cl
.
op
,
"view_map"
,
{})
vmap
=
cl
.
op
.
view_map
dmap
=
getattr
(
cl
.
op
,
"destroy_map"
,
{})
dmap
=
cl
.
op
.
destroy_map
for
opos
,
iposlist
in
chain
(
vmap
.
items
(),
dmap
.
items
()):
for
opos
,
iposlist
in
chain
(
vmap
.
items
(),
dmap
.
items
()):
if
v_input_pos_to_cl
in
iposlist
:
if
v_input_pos_to_cl
in
iposlist
:
if
cl
.
outputs
[
opos
]
not
in
treeset
:
if
cl
.
outputs
[
opos
]
not
in
treeset
:
...
@@ -189,7 +189,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
...
@@ -189,7 +189,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
fgraph
=
FunctionGraph
(
orig_inputs
,
orig_outputs
,
update_mapping
=
update_mapping
)
fgraph
=
FunctionGraph
(
orig_inputs
,
orig_outputs
,
update_mapping
=
update_mapping
)
for
node
in
fgraph
.
apply_nodes
:
for
node
in
fgraph
.
apply_nodes
:
if
getattr
(
node
.
op
,
"destroy_map"
,
None
)
:
if
node
.
op
.
destroy_map
:
if
not
accept_inplace
:
if
not
accept_inplace
:
raise
TypeError
(
raise
TypeError
(
"Graph must not contain inplace operations"
,
node
,
node
.
op
"Graph must not contain inplace operations"
,
node
,
node
.
op
...
...
aesara/compile/profiling.py
浏览文件 @
b0a1d33c
...
@@ -962,8 +962,8 @@ class ProfileStats:
...
@@ -962,8 +962,8 @@ class ProfileStats:
if
ignore_dmap
:
if
ignore_dmap
:
dmap
=
None
dmap
=
None
else
:
else
:
dmap
=
getattr
(
node
.
op
,
"destroy_map"
,
None
)
dmap
=
node
.
op
.
destroy_map
vmap
=
getattr
(
node
.
op
,
"view_map"
,
None
)
vmap
=
node
.
op
.
view_map
val
=
nodes_mem
[
node
]
val
=
nodes_mem
[
node
]
for
v
in
val
:
for
v
in
val
:
...
@@ -1125,8 +1125,8 @@ class ProfileStats:
...
@@ -1125,8 +1125,8 @@ class ProfileStats:
mem_freed
=
0
mem_freed
=
0
max_storage
=
max_mem_count
max_storage
=
max_mem_count
dmap
=
getattr
(
node
.
op
,
"destroy_map"
,
None
)
dmap
=
node
.
op
.
destroy_map
vmap
=
getattr
(
node
.
op
,
"view_map"
,
None
)
vmap
=
node
.
op
.
view_map
idx
=
0
idx
=
0
# Update the Python emulating dicts and add the
# Update the Python emulating dicts and add the
...
@@ -1426,9 +1426,9 @@ class ProfileStats:
...
@@ -1426,9 +1426,9 @@ class ProfileStats:
items
.
sort
(
key
=
lambda
a
:
a
[
1
],
reverse
=
True
)
items
.
sort
(
key
=
lambda
a
:
a
[
1
],
reverse
=
True
)
for
idx
,
((
fgraph
,
node
),
node_outputs_size
)
in
enumerate
(
items
[:
N
]):
for
idx
,
((
fgraph
,
node
),
node_outputs_size
)
in
enumerate
(
items
[:
N
]):
code
=
[
"c"
]
*
len
(
node
.
outputs
)
code
=
[
"c"
]
*
len
(
node
.
outputs
)
for
out
,
inp
in
getattr
(
node
.
op
,
"destroy_map"
,
{})
.
items
():
for
out
,
inp
in
node
.
op
.
destroy_map
.
items
():
code
[
out
]
=
"i"
code
[
out
]
=
"i"
for
out
,
inp
in
getattr
(
node
.
op
,
"view_map"
,
{})
.
items
():
for
out
,
inp
in
node
.
op
.
view_map
.
items
():
code
[
out
]
=
"v"
code
[
out
]
=
"v"
shapes
=
str
(
fct_shapes
[
fgraph
][
node
])
shapes
=
str
(
fct_shapes
[
fgraph
][
node
])
...
...
aesara/d3viz/formatting.py
浏览文件 @
b0a1d33c
...
@@ -186,11 +186,11 @@ class PyDotFormatter:
...
@@ -186,11 +186,11 @@ class PyDotFormatter:
graph
.
add_node
(
pd_var
)
graph
.
add_node
(
pd_var
)
edge_params
=
{}
edge_params
=
{}
if
hasattr
(
node
.
op
,
"view_map"
)
and
id
in
reduce
(
if
node
.
op
.
view_map
and
id
in
reduce
(
list
.
__add__
,
node
.
op
.
view_map
.
values
(),
[]
list
.
__add__
,
node
.
op
.
view_map
.
values
(),
[]
):
):
edge_params
[
"color"
]
=
self
.
node_colors
[
"output"
]
edge_params
[
"color"
]
=
self
.
node_colors
[
"output"
]
elif
hasattr
(
node
.
op
,
"destroy_map"
)
and
id
in
reduce
(
elif
node
.
op
.
destroy_map
and
id
in
reduce
(
list
.
__add__
,
node
.
op
.
destroy_map
.
values
(),
[]
list
.
__add__
,
node
.
op
.
destroy_map
.
values
(),
[]
):
):
edge_params
[
"color"
]
=
"red"
edge_params
[
"color"
]
=
"red"
...
...
aesara/graph/destroyhandler.py
浏览文件 @
b0a1d33c
...
@@ -413,11 +413,11 @@ class DestroyHandler(Bookkeeper): # noqa
...
@@ -413,11 +413,11 @@ class DestroyHandler(Bookkeeper): # noqa
for
(
app
,
idx
)
in
fgraph
.
clients
[
protected_var
]:
for
(
app
,
idx
)
in
fgraph
.
clients
[
protected_var
]:
if
app
==
"output"
:
if
app
==
"output"
:
continue
continue
destroy_maps
=
getattr
(
app
.
op
,
"destroy_map"
,
{})
.
values
()
destroy_maps
=
app
.
op
.
destroy_map
.
values
()
# If True means that the apply node, destroys the protected_var.
# If True means that the apply node, destroys the protected_var.
if
idx
in
[
dmap
for
sublist
in
destroy_maps
for
dmap
in
sublist
]:
if
idx
in
[
dmap
for
sublist
in
destroy_maps
for
dmap
in
sublist
]:
return
True
return
True
for
var_idx
in
getattr
(
app
.
op
,
"view_map"
,
{})
.
keys
():
for
var_idx
in
app
.
op
.
view_map
.
keys
():
if
idx
in
app
.
op
.
view_map
[
var_idx
]:
if
idx
in
app
.
op
.
view_map
[
var_idx
]:
# We need to recursivly check the destroy_map of all the
# We need to recursivly check the destroy_map of all the
# outputs that we have a view_map on.
# outputs that we have a view_map on.
...
@@ -467,7 +467,7 @@ class DestroyHandler(Bookkeeper): # noqa
...
@@ -467,7 +467,7 @@ class DestroyHandler(Bookkeeper): # noqa
- Allow sequence of view.
- Allow sequence of view.
- But don't allow to destroy view
- But don't allow to destroy view
"""
"""
dm
=
getattr
(
app
.
op
,
"destroy_map"
,
None
)
dm
=
app
.
op
.
destroy_map
if
not
dm
:
if
not
dm
:
return
return
inputs
=
set
(
inputs
=
set
(
...
@@ -486,8 +486,8 @@ class DestroyHandler(Bookkeeper): # noqa
...
@@ -486,8 +486,8 @@ class DestroyHandler(Bookkeeper): # noqa
elif
inp
.
owner
:
elif
inp
.
owner
:
app2
=
inp
.
owner
app2
=
inp
.
owner
inp_idx2
=
app2
.
outputs
.
index
(
inp
)
inp_idx2
=
app2
.
outputs
.
index
(
inp
)
v
=
getattr
(
app2
.
op
,
"view_map"
,
{})
v
=
app2
.
op
.
view_map
d
=
getattr
(
app2
.
op
,
"destroy_map"
,
{})
d
=
app2
.
op
.
destroy_map
if
v
:
if
v
:
v
=
v
.
get
(
inp_idx2
,
[])
v
=
v
.
get
(
inp_idx2
,
[])
if
len
(
v
)
>
0
:
if
len
(
v
)
>
0
:
...
@@ -517,8 +517,8 @@ class DestroyHandler(Bookkeeper): # noqa
...
@@ -517,8 +517,8 @@ class DestroyHandler(Bookkeeper): # noqa
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list
# If it's a destructive op, add it to our watch list
dmap
=
getattr
(
app
.
op
,
"destroy_map"
,
None
)
dmap
=
app
.
op
.
destroy_map
vmap
=
getattr
(
app
.
op
,
"view_map"
,
{})
vmap
=
app
.
op
.
view_map
if
dmap
:
if
dmap
:
self
.
destroyers
.
add
(
app
)
self
.
destroyers
.
add
(
app
)
if
self
.
algo
==
"fast"
:
if
self
.
algo
==
"fast"
:
...
@@ -558,7 +558,7 @@ class DestroyHandler(Bookkeeper): # noqa
...
@@ -558,7 +558,7 @@ class DestroyHandler(Bookkeeper): # noqa
for
input
in
set
(
app
.
inputs
):
for
input
in
set
(
app
.
inputs
):
del
self
.
clients
[
input
][
app
]
del
self
.
clients
[
input
][
app
]
if
getattr
(
app
.
op
,
"destroy_map"
,
OrderedDict
())
:
if
app
.
op
.
destroy_map
:
self
.
destroyers
.
remove
(
app
)
self
.
destroyers
.
remove
(
app
)
# Note: leaving empty client dictionaries in the struct.
# Note: leaving empty client dictionaries in the struct.
...
@@ -566,7 +566,7 @@ class DestroyHandler(Bookkeeper): # noqa
...
@@ -566,7 +566,7 @@ class DestroyHandler(Bookkeeper): # noqa
# deleted on_detach().
# deleted on_detach().
# UPDATE self.view_i, self.view_o
# UPDATE self.view_i, self.view_o
for
o_idx
,
i_idx_list
in
getattr
(
app
.
op
,
"view_map"
,
OrderedDict
())
.
items
():
for
o_idx
,
i_idx_list
in
app
.
op
.
view_map
.
items
():
if
len
(
i_idx_list
)
>
1
:
if
len
(
i_idx_list
)
>
1
:
# destroying this output invalidates multiple inputs
# destroying this output invalidates multiple inputs
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -605,7 +605,7 @@ class DestroyHandler(Bookkeeper): # noqa
...
@@ -605,7 +605,7 @@ class DestroyHandler(Bookkeeper): # noqa
self
.
clients
[
new_r
][
app
]
+=
1
self
.
clients
[
new_r
][
app
]
+=
1
# UPDATE self.view_i, self.view_o
# UPDATE self.view_i, self.view_o
for
o_idx
,
i_idx_list
in
getattr
(
app
.
op
,
"view_map"
,
OrderedDict
())
.
items
():
for
o_idx
,
i_idx_list
in
app
.
op
.
view_map
.
items
():
if
len
(
i_idx_list
)
>
1
:
if
len
(
i_idx_list
)
>
1
:
# destroying this output invalidates multiple inputs
# destroying this output invalidates multiple inputs
raise
NotImplementedError
()
raise
NotImplementedError
()
...
...
aesara/graph/fg.py
浏览文件 @
b0a1d33c
...
@@ -205,14 +205,14 @@ class FunctionGraph(MetaObject):
...
@@ -205,14 +205,14 @@ class FunctionGraph(MetaObject):
node : aesara.graph.basic.Apply
node : aesara.graph.basic.Apply
"""
"""
if
hasattr
(
node
.
op
,
"view_map"
)
and
not
all
(
if
node
.
op
.
view_map
and
not
all
(
isinstance
(
view
,
(
list
,
tuple
))
for
view
in
node
.
op
.
view_map
.
values
()
isinstance
(
view
,
(
list
,
tuple
))
for
view
in
node
.
op
.
view_map
.
values
()
):
):
raise
Exception
(
raise
Exception
(
f
"Op '{node.op}' have a bad view map '{node.op.view_map}',"
f
"Op '{node.op}' have a bad view map '{node.op.view_map}',"
" the values must be tuples or lists."
" the values must be tuples or lists."
)
)
if
hasattr
(
node
.
op
,
"destroy_map"
)
and
not
all
(
if
node
.
op
.
destroy_map
and
not
all
(
isinstance
(
destroy
,
(
list
,
tuple
))
isinstance
(
destroy
,
(
list
,
tuple
))
for
destroy
in
node
.
op
.
destroy_map
.
values
()
for
destroy
in
node
.
op
.
destroy_map
.
values
()
):
):
...
...
aesara/graph/op.py
浏览文件 @
b0a1d33c
...
@@ -107,7 +107,7 @@ def compute_test_value(node: Apply):
...
@@ -107,7 +107,7 @@ def compute_test_value(node: Apply):
# The original values should not be destroyed, so we copy the values of the
# The original values should not be destroyed, so we copy the values of the
# inputs in `destroy_map`
# inputs in `destroy_map`
destroyed_inputs_idx
=
set
()
destroyed_inputs_idx
=
set
()
if
getattr
(
node
.
op
,
"destroy_map"
,
None
)
:
if
node
.
op
.
destroy_map
:
for
i_pos_list
in
node
.
op
.
destroy_map
.
values
():
for
i_pos_list
in
node
.
op
.
destroy_map
.
values
():
destroyed_inputs_idx
.
update
(
i_pos_list
)
destroyed_inputs_idx
.
update
(
i_pos_list
)
for
inp_idx
in
destroyed_inputs_idx
:
for
inp_idx
in
destroyed_inputs_idx
:
...
@@ -167,6 +167,29 @@ class Op(MetaObject):
...
@@ -167,6 +167,29 @@ class Op(MetaObject):
"""
"""
view_map
:
Dict
[
int
,
List
[
int
]]
=
{}
"""
A ``dict`` that maps output indices to the input indices of which they are
a view.
Examples
========
view_map = {0: [1]} # first output is a view of second input
view_map = {1: [0]} # second output is a view of first input
"""
destroy_map
:
Dict
[
int
,
List
[
int
]]
=
{}
"""
A ``dict`` that maps output indices to the input indices upon which they
operate in-place.
Examples
========
destroy_map = {0: [1]} # first output operates in-place on second input
destroy_map = {1: [0]} # second output operates in-place on first input
"""
def
make_node
(
self
,
*
inputs
:
Variable
)
->
Apply
:
def
make_node
(
self
,
*
inputs
:
Variable
)
->
Apply
:
"""Construct an `Apply` node that represent the application of this operation to the given inputs.
"""Construct an `Apply` node that represent the application of this operation to the given inputs.
...
...
aesara/graph/opt.py
浏览文件 @
b0a1d33c
...
@@ -835,7 +835,7 @@ class MergeOptimizer(GlobalOptimizer):
...
@@ -835,7 +835,7 @@ class MergeOptimizer(GlobalOptimizer):
[
[
i
in
flatten
(
c
.
op
.
destroy_map
.
values
())
i
in
flatten
(
c
.
op
.
destroy_map
.
values
())
for
c
,
i
in
clients
for
c
,
i
in
clients
if
c
!=
"output"
and
hasattr
(
c
.
op
,
"destroy_map"
)
if
c
!=
"output"
and
c
.
op
.
destroy_map
]
]
)
)
>
1
>
1
...
...
aesara/graph/toolbox.py
浏览文件 @
b0a1d33c
...
@@ -812,7 +812,7 @@ class NoOutputFromInplace(Feature):
...
@@ -812,7 +812,7 @@ class NoOutputFromInplace(Feature):
node
=
out
.
owner
node
=
out
.
owner
op
=
node
.
op
op
=
node
.
op
out_idx
=
node
.
outputs
.
index
(
out
)
out_idx
=
node
.
outputs
.
index
(
out
)
if
hasattr
(
op
,
"destroy_map"
)
and
out_idx
in
op
.
destroy_map
:
if
op
.
destroy_map
and
out_idx
in
op
.
destroy_map
:
raise
aesara
.
graph
.
fg
.
InconsistencyError
(
raise
aesara
.
graph
.
fg
.
InconsistencyError
(
"A function graph Feature has requested that outputs of the graph "
"A function graph Feature has requested that outputs of the graph "
"be prevented from being the result of in-place "
"be prevented from being the result of in-place "
...
...
aesara/link/utils.py
浏览文件 @
b0a1d33c
...
@@ -430,8 +430,8 @@ def raise_with_op(
...
@@ -430,8 +430,8 @@ def raise_with_op(
total_size_inputs
+=
sz
total_size_inputs
+=
sz
else
:
else
:
# If it is a view, don't count it twice.
# If it is a view, don't count it twice.
if
getattr
(
k
.
owner
.
op
,
"view_map"
,
None
):
vmap
=
k
.
owner
.
op
.
view_map
vmap
=
k
.
owner
.
op
.
view_map
if
vmap
:
out_idx
=
k
.
owner
.
outputs
.
index
(
k
)
out_idx
=
k
.
owner
.
outputs
.
index
(
k
)
data
=
storage_map
[
k
][
0
]
data
=
storage_map
[
k
][
0
]
if
out_idx
in
vmap
:
if
out_idx
in
vmap
:
...
@@ -445,14 +445,14 @@ def raise_with_op(
...
@@ -445,14 +445,14 @@ def raise_with_op(
# shouldn't be in the storage_map anymore
# shouldn't be in the storage_map anymore
# except if there is a special flag used. So
# except if there is a special flag used. So
# we still must check it.
# we still must check it.
if
getattr
(
k
.
owner
.
op
,
"destroy_map"
,
None
):
dmap
=
k
.
owner
.
op
.
destroy_map
vmap
=
k
.
owner
.
op
.
destroy_map
if
dmap
:
out_idx
=
k
.
owner
.
outputs
.
index
(
k
)
out_idx
=
k
.
owner
.
outputs
.
index
(
k
)
data
=
storage_map
[
k
][
0
]
data
=
storage_map
[
k
][
0
]
if
out_idx
in
v
map
:
if
out_idx
in
d
map
:
assert
len
(
v
map
[
out_idx
])
==
1
assert
len
(
d
map
[
out_idx
])
==
1
input_data
=
storage_map
[
input_data
=
storage_map
[
k
.
owner
.
inputs
[
v
map
[
out_idx
][
0
]]
k
.
owner
.
inputs
[
d
map
[
out_idx
][
0
]]
][
0
]
][
0
]
if
k
.
type
.
may_share_memory
(
data
,
input_data
):
if
k
.
type
.
may_share_memory
(
data
,
input_data
):
total_size
-=
sz
total_size
-=
sz
...
...
aesara/link/vm.py
浏览文件 @
b0a1d33c
...
@@ -36,8 +36,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
...
@@ -36,8 +36,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for
idx
in
range
(
len
(
order
)):
for
idx
in
range
(
len
(
order
)):
node
=
order
[
idx
]
node
=
order
[
idx
]
dmap
=
getattr
(
node
.
op
,
"destroy_map"
,
None
)
dmap
=
node
.
op
.
destroy_map
vmap
=
getattr
(
node
.
op
,
"view_map"
,
None
)
vmap
=
node
.
op
.
view_map
idx_o
=
0
idx_o
=
0
for
out
in
node
.
outputs
:
for
out
in
node
.
outputs
:
...
@@ -574,9 +574,7 @@ class Stack(VM):
...
@@ -574,9 +574,7 @@ class Stack(VM):
if
(
if
(
config
.
warn__vm_gc_bug
config
.
warn__vm_gc_bug
and
current_apply
in
apply_stack
and
current_apply
in
apply_stack
and
getattr
(
and
current_apply
.
op
.
destroy_map
current_apply
.
op
,
"destroy_map"
,
False
)
):
):
warnings
.
warn
(
warnings
.
warn
(
"There was a bug that existed in "
"There was a bug that existed in "
...
...
aesara/printing.py
浏览文件 @
b0a1d33c
...
@@ -997,11 +997,11 @@ def pydotprint(
...
@@ -997,11 +997,11 @@ def pydotprint(
param
=
{}
param
=
{}
if
label
:
if
label
:
param
[
"label"
]
=
label
param
[
"label"
]
=
label
if
hasattr
(
node
.
op
,
"view_map"
)
and
idx
in
reduce
(
if
node
.
op
.
view_map
and
idx
in
reduce
(
list
.
__add__
,
node
.
op
.
view_map
.
values
(),
[]
list
.
__add__
,
node
.
op
.
view_map
.
values
(),
[]
):
):
param
[
"color"
]
=
colorCodes
[
"Output"
]
param
[
"color"
]
=
colorCodes
[
"Output"
]
elif
hasattr
(
node
.
op
,
"destroy_map"
)
and
idx
in
reduce
(
elif
node
.
op
.
destroy_map
and
idx
in
reduce
(
list
.
__add__
,
node
.
op
.
destroy_map
.
values
(),
[]
list
.
__add__
,
node
.
op
.
destroy_map
.
values
(),
[]
):
):
param
[
"color"
]
=
"red"
param
[
"color"
]
=
"red"
...
...
aesara/scan/op.py
浏览文件 @
b0a1d33c
...
@@ -794,8 +794,6 @@ class Scan(Op):
...
@@ -794,8 +794,6 @@ class Scan(Op):
else
:
else
:
name
=
"for"
name
=
"for"
aux_txt
=
"
%
s"
aux_txt
=
"
%
s"
if
getattr
(
self
,
"destroy_map"
,
None
)
is
None
:
self
.
destroy_map
=
OrderedDict
()
if
len
(
self
.
destroy_map
.
keys
())
>
0
:
if
len
(
self
.
destroy_map
.
keys
())
>
0
:
# Check if all outputs are inplace
# Check if all outputs are inplace
if
sorted
(
self
.
destroy_map
.
keys
())
==
sorted
(
if
sorted
(
self
.
destroy_map
.
keys
())
==
sorted
(
...
@@ -1027,7 +1025,7 @@ class Scan(Op):
...
@@ -1027,7 +1025,7 @@ class Scan(Op):
cython_inps_is_tensor
=
np
.
asarray
(
self
.
inps_is_tensor
,
dtype
=
"int32"
)
cython_inps_is_tensor
=
np
.
asarray
(
self
.
inps_is_tensor
,
dtype
=
"int32"
)
cython_outs_is_tensor
=
np
.
asarray
(
self
.
outs_is_tensor
,
dtype
=
"int32"
)
cython_outs_is_tensor
=
np
.
asarray
(
self
.
outs_is_tensor
,
dtype
=
"int32"
)
if
hasattr
(
self
,
"destroy_map"
)
:
if
self
.
destroy_map
:
cython_destroy_map
=
[
cython_destroy_map
=
[
x
in
self
.
destroy_map
for
x
in
range
(
len
(
node
.
outputs
))
x
in
self
.
destroy_map
for
x
in
range
(
len
(
node
.
outputs
))
]
]
...
@@ -1321,8 +1319,6 @@ class Scan(Op):
...
@@ -1321,8 +1319,6 @@ class Scan(Op):
(
-
self
.
mintaps
[
idx
])
%
store_steps
[
idx
]
(
-
self
.
mintaps
[
idx
])
%
store_steps
[
idx
]
for
idx
in
range
(
self
.
n_outs
+
self
.
n_nit_sot
)
for
idx
in
range
(
self
.
n_outs
+
self
.
n_nit_sot
)
]
]
if
not
getattr
(
self
,
"destroy_map"
,
None
):
self
.
destroy_map
=
OrderedDict
()
# 2.1 Create storage space for outputs
# 2.1 Create storage space for outputs
for
idx
in
range
(
self
.
n_outs
):
for
idx
in
range
(
self
.
n_outs
):
if
idx
in
self
.
destroy_map
:
if
idx
in
self
.
destroy_map
:
...
...
aesara/scan/opt.py
浏览文件 @
b0a1d33c
...
@@ -1119,7 +1119,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
...
@@ -1119,7 +1119,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# Get the indices of this client's inputs on which it
# Get the indices of this client's inputs on which it
# operates inplace
# operates inplace
if
hasattr
(
client
.
op
,
"destroy_map"
)
:
if
client
.
op
.
destroy_map
:
# This flattens the content of destroy_map.values()
# This flattens the content of destroy_map.values()
# which is a list of lists
# which is a list of lists
inplace_inp_indices
=
sum
(
client
.
op
.
destroy_map
.
values
(),
[])
inplace_inp_indices
=
sum
(
client
.
op
.
destroy_map
.
values
(),
[])
...
...
tests/compile/test_mode.py
浏览文件 @
b0a1d33c
...
@@ -20,7 +20,7 @@ def test_no_output_from_implace():
...
@@ -20,7 +20,7 @@ def test_no_output_from_implace():
# using a mode that does not include the optimization
# using a mode that does not include the optimization
fct_no_opt
=
aesara
.
function
([
x
,
y
],
b
,
mode
=
"FAST_RUN"
)
fct_no_opt
=
aesara
.
function
([
x
,
y
],
b
,
mode
=
"FAST_RUN"
)
op
=
fct_no_opt
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
op
=
fct_no_opt
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
assert
hasattr
(
op
,
"destroy_map"
)
and
0
in
op
.
destroy_map
assert
op
.
destroy_map
and
0
in
op
.
destroy_map
# Ensure that the elemwise op that produces the output is not inplace when
# Ensure that the elemwise op that produces the output is not inplace when
# using a mode that includes the optimization
# using a mode that includes the optimization
...
@@ -29,7 +29,7 @@ def test_no_output_from_implace():
...
@@ -29,7 +29,7 @@ def test_no_output_from_implace():
fct_opt
=
aesara
.
function
([
x
,
y
],
b
,
mode
=
mode_opt
)
fct_opt
=
aesara
.
function
([
x
,
y
],
b
,
mode
=
mode_opt
)
op
=
fct_opt
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
op
=
fct_opt
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
assert
not
hasattr
(
op
,
"destroy_map"
)
or
0
not
in
op
.
destroy_map
assert
not
op
.
destroy_map
or
0
not
in
op
.
destroy_map
def
test_including
():
def
test_including
():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论