Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
b0a1d33c
提交
b0a1d33c
authored
4月 10, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
4月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add view_map and destroy_map variables and docstrings to Op
上级
cae78759
隐藏空白字符变更
内嵌
并排
正在显示
15 个修改的文件
包含
81 行增加
和
64 行删除
+81
-64
debugmode.py
aesara/compile/debugmode.py
+14
-14
types.py
aesara/compile/function/types.py
+5
-5
profiling.py
aesara/compile/profiling.py
+6
-6
formatting.py
aesara/d3viz/formatting.py
+2
-2
destroyhandler.py
aesara/graph/destroyhandler.py
+10
-10
fg.py
aesara/graph/fg.py
+2
-2
op.py
aesara/graph/op.py
+24
-1
opt.py
aesara/graph/opt.py
+1
-1
toolbox.py
aesara/graph/toolbox.py
+1
-1
utils.py
aesara/link/utils.py
+7
-7
vm.py
aesara/link/vm.py
+3
-5
printing.py
aesara/printing.py
+2
-2
op.py
aesara/scan/op.py
+1
-5
opt.py
aesara/scan/opt.py
+1
-1
test_mode.py
tests/compile/test_mode.py
+2
-2
没有找到文件。
aesara/compile/debugmode.py
浏览文件 @
b0a1d33c
...
...
@@ -166,7 +166,7 @@ class BadDestroyMap(DebugModeError):
print
(
" node:"
,
self
.
node
,
file
=
sio
)
print
(
" perform:"
,
self
.
perform
,
file
=
sio
)
print
(
" node.inputs:"
,
[(
str
(
i
),
id
(
i
))
for
i
in
self
.
node
.
inputs
],
file
=
sio
)
print
(
" destroy_map:"
,
getattr
(
self
.
node
.
op
,
"destroy_map"
,
{})
,
file
=
sio
)
print
(
" destroy_map:"
,
self
.
node
.
op
.
destroy_map
,
file
=
sio
)
print
(
" changed input idx:"
,
self
.
idx
,
file
=
sio
)
print
(
" changed input type:"
,
self
.
node
.
inputs
[
self
.
idx
]
.
type
,
file
=
sio
)
print
(
" repr (old val):"
,
repr
(
self
.
old_val
),
file
=
sio
)
...
...
@@ -250,8 +250,8 @@ class BadViewMap(DebugModeError):
print
(
" node:"
,
self
.
node
,
file
=
sio
)
print
(
" node.inputs:"
,
[(
str
(
i
),
id
(
i
))
for
i
in
self
.
node
.
inputs
],
file
=
sio
)
print
(
" node.outputs:"
,
[(
str
(
i
),
id
(
i
))
for
i
in
self
.
node
.
outputs
],
file
=
sio
)
print
(
" view_map:"
,
getattr
(
self
.
node
.
op
,
"view_map"
,
{})
,
file
=
sio
)
print
(
" destroy_map:"
,
getattr
(
self
.
node
.
op
,
"destroy_map"
,
{})
,
file
=
sio
)
print
(
" view_map:"
,
self
.
node
.
op
.
view_map
,
file
=
sio
)
print
(
" destroy_map:"
,
self
.
node
.
op
.
destroy_map
,
file
=
sio
)
print
(
" aliased output:"
,
self
.
output_idx
,
file
=
sio
)
print
(
" aliased output storage:"
,
self
.
out_storage
,
file
=
sio
)
if
self
.
in_alias_idx
:
...
...
@@ -554,12 +554,12 @@ def debugprint(
r_name
=
""
if
print_destroy_map
:
destroy_map_str
=
str
(
getattr
(
r
.
owner
.
op
,
"destroy_map"
,
""
)
)
destroy_map_str
=
str
(
r
.
owner
.
op
.
destroy_map
)
else
:
destroy_map_str
=
""
if
print_view_map
:
view_map_str
=
str
(
getattr
(
r
.
owner
.
op
,
"view_map"
,
""
)
)
view_map_str
=
str
(
r
.
owner
.
op
.
view_map
)
else
:
view_map_str
=
""
if
destroy_map_str
and
destroy_map_str
!=
"{}"
:
...
...
@@ -742,13 +742,13 @@ def _check_inputs(
"""
destroyed_idx_list
=
[]
destroy_map
=
getattr
(
node
.
op
,
"destroy_map"
,
{})
destroy_map
=
node
.
op
.
destroy_map
for
o_pos
,
i_pos_list
in
destroy_map
.
items
():
destroyed_idx_list
.
extend
(
i_pos_list
)
destroyed_res_list
=
[
node
.
inputs
[
i
]
for
i
in
destroyed_idx_list
]
actually_inplace_outputs
=
[]
dmap
=
getattr
(
node
.
op
,
"destroy_map"
,
{})
dmap
=
node
.
op
.
destroy_map
for
oo
,
ii
in
dmap
.
items
():
var
=
node
.
outputs
[
oo
]
out_var
=
storage_map
[
var
][
0
]
...
...
@@ -769,7 +769,7 @@ def _check_inputs(
f
"as destroyed was not changed for node '{node}'"
)
vmap
=
getattr
(
node
.
op
,
"view_map"
,
{})
vmap
=
node
.
op
.
view_map
for
oo
,
ii
in
vmap
.
items
():
var
=
node
.
outputs
[
oo
]
out_var
=
storage_map
[
var
][
0
]
...
...
@@ -836,8 +836,8 @@ def _check_viewmap(fgraph, node, storage_map):
outstorage
=
storage_map
[
onode
][
0
]
# first find out which input it aliases
view_map
=
getattr
(
node
.
op
,
"view_map"
,
{})
destroy_map
=
getattr
(
node
.
op
,
"destroy_map"
,
{})
view_map
=
node
.
op
.
view_map
destroy_map
=
node
.
op
.
destroy_map
# In theory, aesara's view_map only allows for 1 output to
# alias 1 input. Checking for multiple aliases just in
...
...
@@ -1395,8 +1395,8 @@ def _check_preallocated_output(
# Set of inputs that are marked as destroyed or viewed
aliased_inputs
=
set
()
dmap
=
getattr
(
node
.
op
,
"destroy_map"
,
{})
vmap
=
getattr
(
node
.
op
,
"view_map"
,
{})
dmap
=
node
.
op
.
destroy_map
vmap
=
node
.
op
.
view_map
for
i
,
r
in
enumerate
(
node
.
inputs
):
if
any
(
i
in
v
for
v
in
chain
(
dmap
.
values
(),
vmap
.
values
())):
aliased_inputs
.
add
(
r
)
...
...
@@ -2082,8 +2082,8 @@ class _Linker(LocalLinker):
clobber
=
True
if
thunk_py
:
dmap
=
getattr
(
node
.
op
,
"destroy_map"
,
{})
vmap
=
getattr
(
node
.
op
,
"view_map"
,
{})
dmap
=
node
.
op
.
destroy_map
vmap
=
node
.
op
.
view_map
for
i
,
r
in
enumerate
(
node
.
inputs
):
# if thunk_py ran, and we still got
# this far, it means that the
...
...
aesara/compile/function/types.py
浏览文件 @
b0a1d33c
...
...
@@ -57,8 +57,8 @@ def alias_root(v):
"""
if
v
.
owner
is
None
:
return
v
vmap
=
getattr
(
v
.
owner
.
op
,
"view_map"
,
{})
dmap
=
getattr
(
v
.
owner
.
op
,
"destroy_map"
,
{})
vmap
=
v
.
owner
.
op
.
view_map
dmap
=
v
.
owner
.
op
.
destroy_map
outpos
=
v
.
owner
.
outputs
.
index
(
v
)
v_views
=
vmap
.
get
(
outpos
,
[])
+
dmap
.
get
(
outpos
,
[])
if
len
(
v_views
)
>
1
:
...
...
@@ -83,8 +83,8 @@ def view_tree_set(fgraph, v, treeset):
for
cl
,
v_input_pos_to_cl
in
fgraph
.
clients
[
v
]:
if
cl
==
"output"
:
continue
vmap
=
getattr
(
cl
.
op
,
"view_map"
,
{})
dmap
=
getattr
(
cl
.
op
,
"destroy_map"
,
{})
vmap
=
cl
.
op
.
view_map
dmap
=
cl
.
op
.
destroy_map
for
opos
,
iposlist
in
chain
(
vmap
.
items
(),
dmap
.
items
()):
if
v_input_pos_to_cl
in
iposlist
:
if
cl
.
outputs
[
opos
]
not
in
treeset
:
...
...
@@ -189,7 +189,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
fgraph
=
FunctionGraph
(
orig_inputs
,
orig_outputs
,
update_mapping
=
update_mapping
)
for
node
in
fgraph
.
apply_nodes
:
if
getattr
(
node
.
op
,
"destroy_map"
,
None
)
:
if
node
.
op
.
destroy_map
:
if
not
accept_inplace
:
raise
TypeError
(
"Graph must not contain inplace operations"
,
node
,
node
.
op
...
...
aesara/compile/profiling.py
浏览文件 @
b0a1d33c
...
...
@@ -962,8 +962,8 @@ class ProfileStats:
if
ignore_dmap
:
dmap
=
None
else
:
dmap
=
getattr
(
node
.
op
,
"destroy_map"
,
None
)
vmap
=
getattr
(
node
.
op
,
"view_map"
,
None
)
dmap
=
node
.
op
.
destroy_map
vmap
=
node
.
op
.
view_map
val
=
nodes_mem
[
node
]
for
v
in
val
:
...
...
@@ -1125,8 +1125,8 @@ class ProfileStats:
mem_freed
=
0
max_storage
=
max_mem_count
dmap
=
getattr
(
node
.
op
,
"destroy_map"
,
None
)
vmap
=
getattr
(
node
.
op
,
"view_map"
,
None
)
dmap
=
node
.
op
.
destroy_map
vmap
=
node
.
op
.
view_map
idx
=
0
# Update the Python emulating dicts and add the
...
...
@@ -1426,9 +1426,9 @@ class ProfileStats:
items
.
sort
(
key
=
lambda
a
:
a
[
1
],
reverse
=
True
)
for
idx
,
((
fgraph
,
node
),
node_outputs_size
)
in
enumerate
(
items
[:
N
]):
code
=
[
"c"
]
*
len
(
node
.
outputs
)
for
out
,
inp
in
getattr
(
node
.
op
,
"destroy_map"
,
{})
.
items
():
for
out
,
inp
in
node
.
op
.
destroy_map
.
items
():
code
[
out
]
=
"i"
for
out
,
inp
in
getattr
(
node
.
op
,
"view_map"
,
{})
.
items
():
for
out
,
inp
in
node
.
op
.
view_map
.
items
():
code
[
out
]
=
"v"
shapes
=
str
(
fct_shapes
[
fgraph
][
node
])
...
...
aesara/d3viz/formatting.py
浏览文件 @
b0a1d33c
...
...
@@ -186,11 +186,11 @@ class PyDotFormatter:
graph
.
add_node
(
pd_var
)
edge_params
=
{}
if
hasattr
(
node
.
op
,
"view_map"
)
and
id
in
reduce
(
if
node
.
op
.
view_map
and
id
in
reduce
(
list
.
__add__
,
node
.
op
.
view_map
.
values
(),
[]
):
edge_params
[
"color"
]
=
self
.
node_colors
[
"output"
]
elif
hasattr
(
node
.
op
,
"destroy_map"
)
and
id
in
reduce
(
elif
node
.
op
.
destroy_map
and
id
in
reduce
(
list
.
__add__
,
node
.
op
.
destroy_map
.
values
(),
[]
):
edge_params
[
"color"
]
=
"red"
...
...
aesara/graph/destroyhandler.py
浏览文件 @
b0a1d33c
...
...
@@ -413,11 +413,11 @@ class DestroyHandler(Bookkeeper): # noqa
for
(
app
,
idx
)
in
fgraph
.
clients
[
protected_var
]:
if
app
==
"output"
:
continue
destroy_maps
=
getattr
(
app
.
op
,
"destroy_map"
,
{})
.
values
()
destroy_maps
=
app
.
op
.
destroy_map
.
values
()
# If True means that the apply node, destroys the protected_var.
if
idx
in
[
dmap
for
sublist
in
destroy_maps
for
dmap
in
sublist
]:
return
True
for
var_idx
in
getattr
(
app
.
op
,
"view_map"
,
{})
.
keys
():
for
var_idx
in
app
.
op
.
view_map
.
keys
():
if
idx
in
app
.
op
.
view_map
[
var_idx
]:
# We need to recursivly check the destroy_map of all the
# outputs that we have a view_map on.
...
...
@@ -467,7 +467,7 @@ class DestroyHandler(Bookkeeper): # noqa
- Allow sequence of view.
- But don't allow to destroy view
"""
dm
=
getattr
(
app
.
op
,
"destroy_map"
,
None
)
dm
=
app
.
op
.
destroy_map
if
not
dm
:
return
inputs
=
set
(
...
...
@@ -486,8 +486,8 @@ class DestroyHandler(Bookkeeper): # noqa
elif
inp
.
owner
:
app2
=
inp
.
owner
inp_idx2
=
app2
.
outputs
.
index
(
inp
)
v
=
getattr
(
app2
.
op
,
"view_map"
,
{})
d
=
getattr
(
app2
.
op
,
"destroy_map"
,
{})
v
=
app2
.
op
.
view_map
d
=
app2
.
op
.
destroy_map
if
v
:
v
=
v
.
get
(
inp_idx2
,
[])
if
len
(
v
)
>
0
:
...
...
@@ -517,8 +517,8 @@ class DestroyHandler(Bookkeeper): # noqa
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list
dmap
=
getattr
(
app
.
op
,
"destroy_map"
,
None
)
vmap
=
getattr
(
app
.
op
,
"view_map"
,
{})
dmap
=
app
.
op
.
destroy_map
vmap
=
app
.
op
.
view_map
if
dmap
:
self
.
destroyers
.
add
(
app
)
if
self
.
algo
==
"fast"
:
...
...
@@ -558,7 +558,7 @@ class DestroyHandler(Bookkeeper): # noqa
for
input
in
set
(
app
.
inputs
):
del
self
.
clients
[
input
][
app
]
if
getattr
(
app
.
op
,
"destroy_map"
,
OrderedDict
())
:
if
app
.
op
.
destroy_map
:
self
.
destroyers
.
remove
(
app
)
# Note: leaving empty client dictionaries in the struct.
...
...
@@ -566,7 +566,7 @@ class DestroyHandler(Bookkeeper): # noqa
# deleted on_detach().
# UPDATE self.view_i, self.view_o
for
o_idx
,
i_idx_list
in
getattr
(
app
.
op
,
"view_map"
,
OrderedDict
())
.
items
():
for
o_idx
,
i_idx_list
in
app
.
op
.
view_map
.
items
():
if
len
(
i_idx_list
)
>
1
:
# destroying this output invalidates multiple inputs
raise
NotImplementedError
()
...
...
@@ -605,7 +605,7 @@ class DestroyHandler(Bookkeeper): # noqa
self
.
clients
[
new_r
][
app
]
+=
1
# UPDATE self.view_i, self.view_o
for
o_idx
,
i_idx_list
in
getattr
(
app
.
op
,
"view_map"
,
OrderedDict
())
.
items
():
for
o_idx
,
i_idx_list
in
app
.
op
.
view_map
.
items
():
if
len
(
i_idx_list
)
>
1
:
# destroying this output invalidates multiple inputs
raise
NotImplementedError
()
...
...
aesara/graph/fg.py
浏览文件 @
b0a1d33c
...
...
@@ -205,14 +205,14 @@ class FunctionGraph(MetaObject):
node : aesara.graph.basic.Apply
"""
if
hasattr
(
node
.
op
,
"view_map"
)
and
not
all
(
if
node
.
op
.
view_map
and
not
all
(
isinstance
(
view
,
(
list
,
tuple
))
for
view
in
node
.
op
.
view_map
.
values
()
):
raise
Exception
(
f
"Op '{node.op}' have a bad view map '{node.op.view_map}',"
" the values must be tuples or lists."
)
if
hasattr
(
node
.
op
,
"destroy_map"
)
and
not
all
(
if
node
.
op
.
destroy_map
and
not
all
(
isinstance
(
destroy
,
(
list
,
tuple
))
for
destroy
in
node
.
op
.
destroy_map
.
values
()
):
...
...
aesara/graph/op.py
浏览文件 @
b0a1d33c
...
...
@@ -107,7 +107,7 @@ def compute_test_value(node: Apply):
# The original values should not be destroyed, so we copy the values of the
# inputs in `destroy_map`
destroyed_inputs_idx
=
set
()
if
getattr
(
node
.
op
,
"destroy_map"
,
None
)
:
if
node
.
op
.
destroy_map
:
for
i_pos_list
in
node
.
op
.
destroy_map
.
values
():
destroyed_inputs_idx
.
update
(
i_pos_list
)
for
inp_idx
in
destroyed_inputs_idx
:
...
...
@@ -167,6 +167,29 @@ class Op(MetaObject):
"""
view_map
:
Dict
[
int
,
List
[
int
]]
=
{}
"""
A ``dict`` that maps output indices to the input indices of which they are
a view.
Examples
========
view_map = {0: [1]} # first output is a view of second input
view_map = {1: [0]} # second output is a view of first input
"""
destroy_map
:
Dict
[
int
,
List
[
int
]]
=
{}
"""
A ``dict`` that maps output indices to the input indices upon which they
operate in-place.
Examples
========
destroy_map = {0: [1]} # first output operates in-place on second input
destroy_map = {1: [0]} # second output operates in-place on first input
"""
def
make_node
(
self
,
*
inputs
:
Variable
)
->
Apply
:
"""Construct an `Apply` node that represent the application of this operation to the given inputs.
...
...
aesara/graph/opt.py
浏览文件 @
b0a1d33c
...
...
@@ -835,7 +835,7 @@ class MergeOptimizer(GlobalOptimizer):
[
i
in
flatten
(
c
.
op
.
destroy_map
.
values
())
for
c
,
i
in
clients
if
c
!=
"output"
and
hasattr
(
c
.
op
,
"destroy_map"
)
if
c
!=
"output"
and
c
.
op
.
destroy_map
]
)
>
1
...
...
aesara/graph/toolbox.py
浏览文件 @
b0a1d33c
...
...
@@ -812,7 +812,7 @@ class NoOutputFromInplace(Feature):
node
=
out
.
owner
op
=
node
.
op
out_idx
=
node
.
outputs
.
index
(
out
)
if
hasattr
(
op
,
"destroy_map"
)
and
out_idx
in
op
.
destroy_map
:
if
op
.
destroy_map
and
out_idx
in
op
.
destroy_map
:
raise
aesara
.
graph
.
fg
.
InconsistencyError
(
"A function graph Feature has requested that outputs of the graph "
"be prevented from being the result of in-place "
...
...
aesara/link/utils.py
浏览文件 @
b0a1d33c
...
...
@@ -430,8 +430,8 @@ def raise_with_op(
total_size_inputs
+=
sz
else
:
# If it is a view, don't count it twice.
if
getattr
(
k
.
owner
.
op
,
"view_map"
,
None
):
vmap
=
k
.
owner
.
op
.
view_map
vmap
=
k
.
owner
.
op
.
view_map
if
vmap
:
out_idx
=
k
.
owner
.
outputs
.
index
(
k
)
data
=
storage_map
[
k
][
0
]
if
out_idx
in
vmap
:
...
...
@@ -445,14 +445,14 @@ def raise_with_op(
# shouldn't be in the storage_map anymore
# except if there is a special flag used. So
# we still must check it.
if
getattr
(
k
.
owner
.
op
,
"destroy_map"
,
None
):
vmap
=
k
.
owner
.
op
.
destroy_map
dmap
=
k
.
owner
.
op
.
destroy_map
if
dmap
:
out_idx
=
k
.
owner
.
outputs
.
index
(
k
)
data
=
storage_map
[
k
][
0
]
if
out_idx
in
v
map
:
assert
len
(
v
map
[
out_idx
])
==
1
if
out_idx
in
d
map
:
assert
len
(
d
map
[
out_idx
])
==
1
input_data
=
storage_map
[
k
.
owner
.
inputs
[
v
map
[
out_idx
][
0
]]
k
.
owner
.
inputs
[
d
map
[
out_idx
][
0
]]
][
0
]
if
k
.
type
.
may_share_memory
(
data
,
input_data
):
total_size
-=
sz
...
...
aesara/link/vm.py
浏览文件 @
b0a1d33c
...
...
@@ -36,8 +36,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for
idx
in
range
(
len
(
order
)):
node
=
order
[
idx
]
dmap
=
getattr
(
node
.
op
,
"destroy_map"
,
None
)
vmap
=
getattr
(
node
.
op
,
"view_map"
,
None
)
dmap
=
node
.
op
.
destroy_map
vmap
=
node
.
op
.
view_map
idx_o
=
0
for
out
in
node
.
outputs
:
...
...
@@ -574,9 +574,7 @@ class Stack(VM):
if
(
config
.
warn__vm_gc_bug
and
current_apply
in
apply_stack
and
getattr
(
current_apply
.
op
,
"destroy_map"
,
False
)
and
current_apply
.
op
.
destroy_map
):
warnings
.
warn
(
"There was a bug that existed in "
...
...
aesara/printing.py
浏览文件 @
b0a1d33c
...
...
@@ -997,11 +997,11 @@ def pydotprint(
param
=
{}
if
label
:
param
[
"label"
]
=
label
if
hasattr
(
node
.
op
,
"view_map"
)
and
idx
in
reduce
(
if
node
.
op
.
view_map
and
idx
in
reduce
(
list
.
__add__
,
node
.
op
.
view_map
.
values
(),
[]
):
param
[
"color"
]
=
colorCodes
[
"Output"
]
elif
hasattr
(
node
.
op
,
"destroy_map"
)
and
idx
in
reduce
(
elif
node
.
op
.
destroy_map
and
idx
in
reduce
(
list
.
__add__
,
node
.
op
.
destroy_map
.
values
(),
[]
):
param
[
"color"
]
=
"red"
...
...
aesara/scan/op.py
浏览文件 @
b0a1d33c
...
...
@@ -794,8 +794,6 @@ class Scan(Op):
else
:
name
=
"for"
aux_txt
=
"
%
s"
if
getattr
(
self
,
"destroy_map"
,
None
)
is
None
:
self
.
destroy_map
=
OrderedDict
()
if
len
(
self
.
destroy_map
.
keys
())
>
0
:
# Check if all outputs are inplace
if
sorted
(
self
.
destroy_map
.
keys
())
==
sorted
(
...
...
@@ -1027,7 +1025,7 @@ class Scan(Op):
cython_inps_is_tensor
=
np
.
asarray
(
self
.
inps_is_tensor
,
dtype
=
"int32"
)
cython_outs_is_tensor
=
np
.
asarray
(
self
.
outs_is_tensor
,
dtype
=
"int32"
)
if
hasattr
(
self
,
"destroy_map"
)
:
if
self
.
destroy_map
:
cython_destroy_map
=
[
x
in
self
.
destroy_map
for
x
in
range
(
len
(
node
.
outputs
))
]
...
...
@@ -1321,8 +1319,6 @@ class Scan(Op):
(
-
self
.
mintaps
[
idx
])
%
store_steps
[
idx
]
for
idx
in
range
(
self
.
n_outs
+
self
.
n_nit_sot
)
]
if
not
getattr
(
self
,
"destroy_map"
,
None
):
self
.
destroy_map
=
OrderedDict
()
# 2.1 Create storage space for outputs
for
idx
in
range
(
self
.
n_outs
):
if
idx
in
self
.
destroy_map
:
...
...
aesara/scan/opt.py
浏览文件 @
b0a1d33c
...
...
@@ -1119,7 +1119,7 @@ class ScanInplaceOptimizer(GlobalOptimizer):
# Get the indices of this client's inputs on which it
# operates inplace
if
hasattr
(
client
.
op
,
"destroy_map"
)
:
if
client
.
op
.
destroy_map
:
# This flattens the content of destroy_map.values()
# which is a list of lists
inplace_inp_indices
=
sum
(
client
.
op
.
destroy_map
.
values
(),
[])
...
...
tests/compile/test_mode.py
浏览文件 @
b0a1d33c
...
...
@@ -20,7 +20,7 @@ def test_no_output_from_implace():
# using a mode that does not include the optimization
fct_no_opt
=
aesara
.
function
([
x
,
y
],
b
,
mode
=
"FAST_RUN"
)
op
=
fct_no_opt
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
assert
hasattr
(
op
,
"destroy_map"
)
and
0
in
op
.
destroy_map
assert
op
.
destroy_map
and
0
in
op
.
destroy_map
# Ensure that the elemwise op that produces the output is not inplace when
# using a mode that includes the optimization
...
...
@@ -29,7 +29,7 @@ def test_no_output_from_implace():
fct_opt
=
aesara
.
function
([
x
,
y
],
b
,
mode
=
mode_opt
)
op
=
fct_opt
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
assert
not
hasattr
(
op
,
"destroy_map"
)
or
0
not
in
op
.
destroy_map
assert
not
op
.
destroy_map
or
0
not
in
op
.
destroy_map
def
test_including
():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论