Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
11c4882a
提交
11c4882a
authored
10月 15, 2013
作者:
Pascal Lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1556 from nouiz/mixed2
Faster compilation and sparse stuff.
上级
6ec66344
a6ab5b22
显示空白字符变更
内嵌
并排
正在显示
12 个修改的文件
包含
166 行增加
和
74 行删除
+166
-74
index.txt
doc/library/sparse/index.txt
+1
-0
compiledir.py
theano/gof/compiledir.py
+4
-0
destroyhandler.py
theano/gof/destroyhandler.py
+2
-2
fg.py
theano/gof/fg.py
+1
-8
opt.py
theano/gof/opt.py
+67
-20
ops.py
theano/sandbox/linalg/ops.py
+1
-1
basic.py
theano/sparse/basic.py
+8
-0
test_basic.py
theano/sparse/tests/test_basic.py
+22
-25
blas_c.py
theano/tensor/blas_c.py
+8
-4
elemwise.py
theano/tensor/elemwise.py
+1
-1
opt.py
theano/tensor/opt.py
+31
-7
test_opt.py
theano/tensor/tests/test_opt.py
+20
-6
没有找到文件。
doc/library/sparse/index.txt
浏览文件 @
11c4882a
...
@@ -123,6 +123,7 @@ List of Implemented Operations
...
@@ -123,6 +123,7 @@ List of Implemented Operations
Both grad are implemented. Structured by default.
Both grad are implemented. Structured by default.
- :class:`SparseFromDense <theano.sparse.basic.SparseFromDense>` and ``csr_from_dense``, ``csc_from_dense``.
- :class:`SparseFromDense <theano.sparse.basic.SparseFromDense>` and ``csr_from_dense``, ``csc_from_dense``.
The grad implemented is structured.
The grad implemented is structured.
- Theano SparseVariable object have a method ``toarray()`` that is the same as ``dense_from_sparse``.
- Construction of Sparses and their Properties
- Construction of Sparses and their Properties
- :class:`CSM <theano.sparse.basic.CSM>` and ``CSC``, ``CSR`` to construct a matrix.
- :class:`CSM <theano.sparse.basic.CSM>` and ``CSC``, ``CSR`` to construct a matrix.
...
...
theano/gof/compiledir.py
浏览文件 @
11c4882a
...
@@ -208,6 +208,8 @@ def cleanup():
...
@@ -208,6 +208,8 @@ def cleanup():
have_c_compiler
=
False
have_c_compiler
=
False
for
obj
in
flatten
(
key
):
for
obj
in
flatten
(
key
):
if
isinstance
(
obj
,
numpy
.
ndarray
):
if
isinstance
(
obj
,
numpy
.
ndarray
):
#Reuse have_npy_abi_version to
#force the removing of key
have_npy_abi_version
=
False
have_npy_abi_version
=
False
break
break
elif
isinstance
(
obj
,
basestring
):
elif
isinstance
(
obj
,
basestring
):
...
@@ -219,6 +221,8 @@ def cleanup():
...
@@ -219,6 +221,8 @@ def cleanup():
hasattr
(
obj
,
'c_code_cache_version'
)):
hasattr
(
obj
,
'c_code_cache_version'
)):
v
=
obj
.
c_code_cache_version
()
v
=
obj
.
c_code_cache_version
()
if
v
not
in
[(),
None
]
and
v
not
in
key
[
0
]:
if
v
not
in
[(),
None
]
and
v
not
in
key
[
0
]:
#Reuse have_npy_abi_version to
#force the removing of key
have_npy_abi_version
=
False
have_npy_abi_version
=
False
break
break
...
...
theano/gof/destroyhandler.py
浏览文件 @
11c4882a
...
@@ -442,7 +442,7 @@ if 0:
...
@@ -442,7 +442,7 @@ if 0:
self
.
stale_droot
=
True
self
.
stale_droot
=
True
def
on_change_input
(
self
,
fgraph
,
app
,
i
,
old_r
,
new_r
):
def
on_change_input
(
self
,
fgraph
,
app
,
i
,
old_r
,
new_r
,
reason
):
"""app.inputs[i] changed from old_r to new_r """
"""app.inputs[i] changed from old_r to new_r """
if
app
==
'output'
:
if
app
==
'output'
:
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
...
@@ -827,7 +827,7 @@ class DestroyHandler(toolbox.Bookkeeper):
...
@@ -827,7 +827,7 @@ class DestroyHandler(toolbox.Bookkeeper):
self
.
stale_droot
=
True
self
.
stale_droot
=
True
def
on_change_input
(
self
,
fgraph
,
app
,
i
,
old_r
,
new_r
):
def
on_change_input
(
self
,
fgraph
,
app
,
i
,
old_r
,
new_r
,
reason
):
"""app.inputs[i] changed from old_r to new_r """
"""app.inputs[i] changed from old_r to new_r """
if
app
==
'output'
:
if
app
==
'output'
:
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
...
...
theano/gof/fg.py
浏览文件 @
11c4882a
...
@@ -376,7 +376,7 @@ class FunctionGraph(utils.object2):
...
@@ -376,7 +376,7 @@ class FunctionGraph(utils.object2):
current value of node.inputs[i] which we want to replace.
current value of node.inputs[i] which we want to replace.
For each feature that has a 'on_change_input' method, calls:
For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(function_graph, node, i, old_r, new_r,
[reason]
)
feature.on_change_input(function_graph, node, i, old_r, new_r,
reason
)
"""
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if
node
==
'output'
:
if
node
==
'output'
:
...
@@ -512,14 +512,7 @@ class FunctionGraph(utils.object2):
...
@@ -512,14 +512,7 @@ class FunctionGraph(utils.object2):
# not existing
# not existing
continue
continue
#####HORRIBLE OPTIONAL ARGUMENT HACK
try
:
fn
(
self
,
*
args
,
**
kwargs
)
fn
(
self
,
*
args
,
**
kwargs
)
except
TypeError
,
e
:
if
str
(
e
)
==
"on_change_input() got an unexpected keyword argument 'reason'"
and
len
(
kwargs
)
==
1
:
fn
(
self
,
*
args
)
else
:
raise
def
collect_callbacks
(
self
,
name
,
*
args
):
def
collect_callbacks
(
self
,
name
,
*
args
):
"""WRITEME
"""WRITEME
...
...
theano/gof/opt.py
浏览文件 @
11c4882a
...
@@ -423,7 +423,7 @@ class MergeFeature(object):
...
@@ -423,7 +423,7 @@ class MergeFeature(object):
for
node
in
fgraph
.
toposort
():
for
node
in
fgraph
.
toposort
():
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
self
.
on_import
(
fgraph
,
node
,
"on_attach"
)
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
):
# If inputs to node change, it is not guaranteed that it is distinct
# If inputs to node change, it is not guaranteed that it is distinct
# from the other nodes in nodes_seen
# from the other nodes in nodes_seen
if
node
in
self
.
nodes_seen
:
if
node
in
self
.
nodes_seen
:
...
@@ -555,6 +555,9 @@ class MergeOptimizer(Optimizer):
...
@@ -555,6 +555,9 @@ class MergeOptimizer(Optimizer):
# clear blacklist
# clear blacklist
fgraph
.
merge_feature
.
blacklist
=
[]
fgraph
.
merge_feature
.
blacklist
=
[]
def
__str__
(
self
):
return
self
.
__class__
.
__name__
merge_optimizer
=
MergeOptimizer
()
merge_optimizer
=
MergeOptimizer
()
...
@@ -1171,7 +1174,7 @@ class NavigatorOptimizer(Optimizer):
...
@@ -1171,7 +1174,7 @@ class NavigatorOptimizer(Optimizer):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
def
on_prune
(
self
,
fgraph
,
node
,
reason
):
pruner
(
node
)
pruner
(
node
)
if
chin
is
not
None
:
if
chin
is
not
None
:
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
):
chin
(
node
,
i
,
r
,
new_r
)
chin
(
node
,
i
,
r
,
new_r
)
u
=
Updater
()
u
=
Updater
()
...
@@ -1302,6 +1305,10 @@ class TopoOptimizer(NavigatorOptimizer):
...
@@ -1302,6 +1305,10 @@ class TopoOptimizer(NavigatorOptimizer):
raise
raise
self
.
detach_updater
(
fgraph
,
u
)
self
.
detach_updater
(
fgraph
,
u
)
def
__str__
(
self
):
return
getattr
(
self
,
'__name__'
,
'<TopoOptimizer instance>'
)
class
OpKeyOptimizer
(
NavigatorOptimizer
):
class
OpKeyOptimizer
(
NavigatorOptimizer
):
"""WRITEME"""
"""WRITEME"""
...
@@ -1360,7 +1367,7 @@ class ChangeTracker:
...
@@ -1360,7 +1367,7 @@ class ChangeTracker:
def
on_import
(
self
,
fgraph
,
node
,
reason
):
def
on_import
(
self
,
fgraph
,
node
,
reason
):
self
.
changed
=
True
self
.
changed
=
True
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
):
self
.
changed
=
True
self
.
changed
=
True
def
reset
(
self
):
def
reset
(
self
):
...
@@ -1415,23 +1422,29 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1415,23 +1422,29 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def
apply
(
self
,
fgraph
,
start_from
=
None
):
def
apply
(
self
,
fgraph
,
start_from
=
None
):
if
start_from
is
None
:
if
start_from
is
None
:
start_from
=
fgraph
.
outputs
start_from
=
fgraph
.
outputs
else
:
for
node
in
start_from
:
assert
node
in
fgraph
.
outputs
changed
=
True
changed
=
True
max_use_abort
=
False
max_use_abort
=
False
opt_name
=
None
opt_name
=
None
process_count
=
{}
global_
process_count
=
{}
max_nb_nodes
=
len
(
fgraph
.
apply_nodes
)
max_nb_nodes
=
len
(
fgraph
.
apply_nodes
)
max_use
=
max_nb_nodes
*
self
.
max_use_ratio
max_use
=
max_nb_nodes
*
self
.
max_use_ratio
loop_timing
=
[]
loop_timing
=
[]
loop_process_count
=
[]
global_opt_timing
=
[]
global_opt_timing
=
[]
time_opts
=
{}
time_opts
=
{}
io_toposort_timing
=
[]
io_toposort_timing
=
[]
nb_nodes
=
[]
nb_nodes
=
[]
for
opt
in
self
.
global_optimizers
+
self
.
local_optimizers
:
for
opt
in
self
.
global_optimizers
+
self
.
local_optimizers
:
process_count
.
setdefault
(
opt
,
0
)
global_
process_count
.
setdefault
(
opt
,
0
)
time_opts
.
setdefault
(
opt
,
0
)
time_opts
.
setdefault
(
opt
,
0
)
while
changed
and
not
max_use_abort
:
while
changed
and
not
max_use_abort
:
process_count
=
{}
t0
=
time
.
time
()
t0
=
time
.
time
()
changed
=
False
changed
=
False
...
@@ -1442,9 +1455,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1442,9 +1455,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
gopt
.
apply
(
fgraph
)
gopt
.
apply
(
fgraph
)
time_opts
[
gopt
]
+=
time
.
time
()
-
t_opt
time_opts
[
gopt
]
+=
time
.
time
()
-
t_opt
if
fgraph
.
change_tracker
.
changed
:
if
fgraph
.
change_tracker
.
changed
:
process_count
.
setdefault
(
gopt
,
0
)
process_count
[
gopt
]
+=
1
process_count
[
gopt
]
+=
1
global_process_count
[
gopt
]
+=
1
changed
=
True
changed
=
True
if
process_count
[
gopt
]
>
max_use
:
if
global_
process_count
[
gopt
]
>
max_use
:
max_use_abort
=
True
max_use_abort
=
True
opt_name
=
(
getattr
(
gopt
,
"name"
,
None
)
opt_name
=
(
getattr
(
gopt
,
"name"
,
None
)
or
getattr
(
gopt
,
"__name__"
,
""
))
or
getattr
(
gopt
,
"__name__"
,
""
))
...
@@ -1452,9 +1467,6 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1452,9 +1467,6 @@ class EquilibriumOptimizer(NavigatorOptimizer):
global_opt_timing
.
append
(
float
(
time
.
time
()
-
t0
))
global_opt_timing
.
append
(
float
(
time
.
time
()
-
t0
))
#apply local optimizer
#apply local optimizer
for
node
in
start_from
:
assert
node
in
fgraph
.
outputs
topo_t0
=
time
.
time
()
topo_t0
=
time
.
time
()
q
=
deque
(
graph
.
io_toposort
(
fgraph
.
inputs
,
start_from
))
q
=
deque
(
graph
.
io_toposort
(
fgraph
.
inputs
,
start_from
))
io_toposort_timing
.
append
(
time
.
time
()
-
topo_t0
)
io_toposort_timing
.
append
(
time
.
time
()
-
topo_t0
)
...
@@ -1485,9 +1497,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1485,9 +1497,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
lopt_change
=
self
.
process_node
(
fgraph
,
node
,
lopt
)
lopt_change
=
self
.
process_node
(
fgraph
,
node
,
lopt
)
time_opts
[
lopt
]
+=
time
.
time
()
-
t_opt
time_opts
[
lopt
]
+=
time
.
time
()
-
t_opt
if
lopt_change
:
if
lopt_change
:
process_count
.
setdefault
(
lopt
,
0
)
process_count
[
lopt
]
+=
1
process_count
[
lopt
]
+=
1
global_process_count
[
lopt
]
+=
1
changed
=
True
changed
=
True
if
process_count
[
lopt
]
>
max_use
:
if
global_
process_count
[
lopt
]
>
max_use
:
max_use_abort
=
True
max_use_abort
=
True
opt_name
=
(
getattr
(
lopt
,
"name"
,
None
)
opt_name
=
(
getattr
(
lopt
,
"name"
,
None
)
or
getattr
(
lopt
,
"__name__"
,
""
))
or
getattr
(
lopt
,
"__name__"
,
""
))
...
@@ -1497,6 +1511,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1497,6 +1511,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
finally
:
finally
:
self
.
detach_updater
(
fgraph
,
u
)
self
.
detach_updater
(
fgraph
,
u
)
loop_process_count
.
append
(
process_count
)
loop_timing
.
append
(
float
(
time
.
time
()
-
t0
))
loop_timing
.
append
(
float
(
time
.
time
()
-
t0
))
if
max_use_abort
:
if
max_use_abort
:
...
@@ -1505,7 +1520,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1505,7 +1520,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
+
"
%
f with the theano flag 'optdb.max_use_ratio'."
%
+
"
%
f with the theano flag 'optdb.max_use_ratio'."
%
config
.
optdb
.
max_use_ratio
)
config
.
optdb
.
max_use_ratio
)
return
(
self
,
loop_timing
,
process_count
,
max_nb_nodes
,
return
(
self
,
loop_timing
,
loop_
process_count
,
max_nb_nodes
,
global_opt_timing
,
nb_nodes
,
time_opts
,
io_toposort_timing
)
global_opt_timing
,
nb_nodes
,
time_opts
,
io_toposort_timing
)
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
def
print_summary
(
self
,
stream
=
sys
.
stdout
,
level
=
0
,
depth
=-
1
):
...
@@ -1519,8 +1534,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1519,8 +1534,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
@staticmethod
@staticmethod
def
print_profile
(
stream
,
prof
,
level
=
0
):
def
print_profile
(
stream
,
prof
,
level
=
0
):
(
opt
,
loop_timing
,
process_count
,
max_nb_nodes
,
(
opt
,
loop_timing
,
loop_
process_count
,
max_nb_nodes
,
global_opt_timing
,
nb_nodes
,
time_opts
,
io_toposort_timing
)
=
prof
global_opt_timing
,
nb_nodes
,
time_opts
,
io_toposort_timing
)
=
prof
blanc
=
(
' '
*
level
)
blanc
=
(
' '
*
level
)
print
>>
stream
,
blanc
,
"EquilibriumOptimizer"
,
print
>>
stream
,
blanc
,
"EquilibriumOptimizer"
,
print
>>
stream
,
blanc
,
getattr
(
opt
,
"name"
,
print
>>
stream
,
blanc
,
getattr
(
opt
,
"name"
,
...
@@ -1529,30 +1545,57 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1529,30 +1545,57 @@ class EquilibriumOptimizer(NavigatorOptimizer):
sum
(
loop_timing
),
len
(
loop_timing
),
max_nb_nodes
)
sum
(
loop_timing
),
len
(
loop_timing
),
max_nb_nodes
)
print
>>
stream
,
blanc
,
" time io_toposort
%.3
fs"
%
sum
(
print
>>
stream
,
blanc
,
" time io_toposort
%.3
fs"
%
sum
(
io_toposort_timing
)
io_toposort_timing
)
s
=
sum
([
time_opts
[
o
]
for
o
in
opt
.
local_optimizers
])
print
>>
stream
,
blanc
,
" time in local optimizers
%.3
fs"
%
s
s
=
sum
([
time_opts
[
o
]
for
o
in
opt
.
global_optimizers
])
print
>>
stream
,
blanc
,
" time in global optimizers
%.3
fs"
%
s
for
i
in
range
(
len
(
loop_timing
)):
for
i
in
range
(
len
(
loop_timing
)):
print
>>
stream
,
blanc
,
(
'
%
d -
%.3
fs (
%.3
fs in global opts, '
lopt
=
""
'
%.3
fs io_toposort) -
%
d nodes'
%
(
if
loop_process_count
[
i
]:
d
=
list
(
reversed
(
sorted
(
loop_process_count
[
i
]
.
iteritems
(),
key
=
lambda
a
:
a
[
1
])))
lopt
=
" "
.
join
([
str
((
str
(
k
),
v
))
for
k
,
v
in
d
[:
5
]])
if
len
(
d
)
>
5
:
lopt
+=
" ..."
print
>>
stream
,
blanc
,
(
'
%2
d -
%.3
fs
%
d (
%.3
fs in global opts, '
'
%.3
fs io_toposort) -
%
d nodes -
%
s'
%
(
i
,
loop_timing
[
i
],
i
,
loop_timing
[
i
],
sum
(
loop_process_count
[
i
]
.
values
()),
global_opt_timing
[
i
],
global_opt_timing
[
i
],
io_toposort_timing
[
i
],
nb_nodes
[
i
]))
io_toposort_timing
[
i
],
nb_nodes
[
i
],
lopt
))
count_opt
=
[]
count_opt
=
[]
not_used
=
0
not_used_time
=
0
process_count
=
{}
for
o
in
opt
.
global_optimizers
+
opt
.
local_optimizers
:
process_count
.
setdefault
(
o
,
0
)
for
count
in
loop_process_count
:
for
o
,
v
in
count
.
iteritems
():
process_count
[
o
]
+=
v
for
opt
,
count
in
process_count
.
iteritems
():
for
opt
,
count
in
process_count
.
iteritems
():
if
count
>
0
:
if
count
>
0
:
count_opt
.
append
((
time_opts
[
opt
],
count
,
opt
))
count_opt
.
append
((
time_opts
[
opt
],
count
,
opt
))
else
:
not_used
+=
1
not_used_time
+=
time_opts
[
opt
]
if
count_opt
:
if
count_opt
:
print
>>
stream
,
blanc
,
\
print
>>
stream
,
blanc
,
\
'
times applied - optimizer (only those applied)
:'
'
times - times applied - name
:'
count_opt
.
sort
()
count_opt
.
sort
()
for
(
t
,
count
,
opt
)
in
count_opt
[::
-
1
]:
for
(
t
,
count
,
opt
)
in
count_opt
[::
-
1
]:
print
>>
stream
,
blanc
,
'
%.3
fs -
%
d -
%
s'
%
(
print
>>
stream
,
blanc
,
'
%.3
fs -
%
d -
%
s'
%
(
t
,
count
,
opt
)
t
,
count
,
opt
)
print
>>
stream
,
blanc
,
'
%.3
fs - in
%
d optimization that where not used'
%
(
not_used_time
,
not_used
)
print
>>
stream
print
>>
stream
@staticmethod
@staticmethod
def
merge_profile
(
prof1
,
prof2
):
def
merge_profile
(
prof1
,
prof2
):
#(opt, loop_timing, process_count, max_nb_nodes,
#(opt, loop_timing,
loop_
process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
local_optimizers
=
set
(
prof1
[
0
]
.
local_optimizers
)
.
union
(
local_optimizers
=
set
(
prof1
[
0
]
.
local_optimizers
)
.
union
(
...
@@ -1574,12 +1617,16 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1574,12 +1617,16 @@ class EquilibriumOptimizer(NavigatorOptimizer):
loop_timing
=
merge_list
(
prof1
[
1
],
prof2
[
1
])
loop_timing
=
merge_list
(
prof1
[
1
],
prof2
[
1
])
process_count
=
prof1
[
2
]
.
copy
()
loop_process_count
=
prof1
[
2
]
.
copy
()
for
process
,
count
in
prof2
[
2
]
.
iteritems
():
for
i
in
range
(
len
(
loop_process_count
)):
process_count
=
loop_process_count
[
i
]
for
process
,
count
in
prof2
[
2
][
i
]
.
iteritems
():
if
process
in
process_count
:
if
process
in
process_count
:
process_count
[
process
]
+=
count
process_count
[
process
]
+=
count
else
:
else
:
process_count
[
process
]
=
count
process_count
[
process
]
=
count
for
i
in
range
(
len
(
loop_process_count
),
len
(
prof2
[
2
])):
loop_process_count
.
append
(
prof2
[
2
]
.
copy
())
max_nb_nodes
=
max
(
prof1
[
3
],
prof2
[
3
])
max_nb_nodes
=
max
(
prof1
[
3
],
prof2
[
3
])
...
@@ -1601,7 +1648,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
...
@@ -1601,7 +1648,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
assert
len
(
loop_timing
)
==
max
(
len
(
prof1
[
1
]),
len
(
prof2
[
1
]))
assert
len
(
loop_timing
)
==
max
(
len
(
prof1
[
1
]),
len
(
prof2
[
1
]))
return
(
new_opt
,
return
(
new_opt
,
loop_timing
,
loop_timing
,
process_count
,
loop_
process_count
,
max_nb_nodes
,
max_nb_nodes
,
global_opt_timing
,
global_opt_timing
,
nb_nodes
,
nb_nodes
,
...
...
theano/sandbox/linalg/ops.py
浏览文件 @
11c4882a
...
@@ -159,7 +159,7 @@ class HintsFeature(object):
...
@@ -159,7 +159,7 @@ class HintsFeature(object):
if
k
not
in
new_hints
:
if
k
not
in
new_hints
:
new_hints
[
k
]
=
v
new_hints
[
k
]
=
v
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
):
# TODO:
# TODO:
# This tells us that r and new_r must have the same shape
# This tells us that r and new_r must have the same shape
# if we didn't know that the shapes are related, now we do.
# if we didn't know that the shapes are related, now we do.
...
...
theano/sparse/basic.py
浏览文件 @
11c4882a
...
@@ -300,6 +300,8 @@ class _sparse_py_operators:
...
@@ -300,6 +300,8 @@ class _sparse_py_operators:
# def _as_TensorVariable(self):
# def _as_TensorVariable(self):
# return dense_from_sparse(self)
# return dense_from_sparse(self)
def
toarray
(
self
):
return
dense_from_sparse
(
self
)
shape
=
property
(
lambda
self
:
tensor
.
shape
(
dense_from_sparse
(
self
)))
shape
=
property
(
lambda
self
:
tensor
.
shape
(
dense_from_sparse
(
self
)))
# don't worry!
# don't worry!
# the plan is that the ShapeFeature in tensor.opt will do shape propagation
# the plan is that the ShapeFeature in tensor.opt will do shape propagation
...
@@ -1843,6 +1845,8 @@ class AddSD(gof.op.Op):
...
@@ -1843,6 +1845,8 @@ class AddSD(gof.op.Op):
def
infer_shape
(
self
,
node
,
shapes
):
def
infer_shape
(
self
,
node
,
shapes
):
return
[
shapes
[
3
]]
return
[
shapes
[
3
]]
def
c_code_cache_version
(
self
):
return
(
1
,)
add_s_d
=
AddSD
()
add_s_d
=
AddSD
()
...
@@ -1918,6 +1922,10 @@ def add(x, y):
...
@@ -1918,6 +1922,10 @@ def add(x, y):
x
=
as_sparse_variable
(
x
)
x
=
as_sparse_variable
(
x
)
if
hasattr
(
y
,
'getnnz'
):
if
hasattr
(
y
,
'getnnz'
):
y
=
as_sparse_variable
(
y
)
y
=
as_sparse_variable
(
y
)
if
not
isinstance
(
x
,
theano
.
Variable
):
x
=
theano
.
tensor
.
as_tensor_variable
(
x
)
if
not
isinstance
(
y
,
theano
.
Variable
):
y
=
theano
.
tensor
.
as_tensor_variable
(
y
)
x_is_sparse_variable
=
_is_sparse_variable
(
x
)
x_is_sparse_variable
=
_is_sparse_variable
(
x
)
y_is_sparse_variable
=
_is_sparse_variable
(
y
)
y_is_sparse_variable
=
_is_sparse_variable
(
y
)
...
...
theano/sparse/tests/test_basic.py
浏览文件 @
11c4882a
...
@@ -567,67 +567,57 @@ class T_AddMul(unittest.TestCase):
...
@@ -567,67 +567,57 @@ class T_AddMul(unittest.TestCase):
def
_testSD
(
self
,
op
,
array1
=
numpy
.
array
([[
1.
,
0
],
[
3
,
0
],
[
0
,
6
]]),
def
_testSD
(
self
,
op
,
array1
=
numpy
.
array
([[
1.
,
0
],
[
3
,
0
],
[
0
,
6
]]),
array2
=
numpy
.
asarray
([[
0
,
2.
],
[
0
,
4
],
[
5
,
0
]])):
array2
=
numpy
.
asarray
([[
0
,
2.
],
[
0
,
4
],
[
5
,
0
]])):
for
mtype
in
_mtypes
:
for
mtype
in
_mtypes
:
a
=
numpy
.
array
(
array1
)
for
a
in
[
numpy
.
array
(
array1
),
tensor
.
as_tensor_variable
(
array1
)]:
aR
=
tensor
.
as_tensor_variable
(
a
)
self
.
assertFalse
(
aR
.
data
is
a
)
# constants are copied
self
.
assertTrue
(
_is_dense
(
a
))
self
.
assertTrue
(
_is_dense_variable
(
aR
))
b
=
mtype
(
array2
)
b
=
mtype
(
array2
)
bR
=
as_sparse_variable
(
b
)
bR
=
as_sparse_variable
(
b
)
self
.
assertFalse
(
bR
.
data
is
b
)
# constants are copied
self
.
assertFalse
(
bR
.
data
is
b
)
# constants are copied
self
.
assertTrue
(
_is_sparse
(
b
))
self
.
assertTrue
(
_is_sparse
(
b
))
self
.
assertTrue
(
_is_sparse_variable
(
bR
))
self
.
assertTrue
(
_is_sparse_variable
(
bR
))
apb
=
op
(
aR
,
bR
)
apb
=
op
(
a
,
bR
)
self
.
assertTrue
(
apb
.
type
.
dtype
==
aR
.
type
.
dtype
,
apb
.
type
.
dtype
)
self
.
assertTrue
(
apb
.
type
.
dtype
==
a
.
dtype
,
apb
.
type
.
dtype
)
self
.
assertTrue
(
apb
.
type
.
dtype
==
bR
.
type
.
dtype
,
apb
.
type
.
dtype
)
self
.
assertTrue
(
apb
.
type
.
dtype
==
bR
.
type
.
dtype
,
apb
.
type
.
dtype
)
val
=
eval_outputs
([
apb
])
val
=
eval_outputs
([
apb
])
self
.
assertTrue
(
val
.
shape
==
(
3
,
2
))
self
.
assertTrue
(
val
.
shape
==
(
3
,
2
))
if
op
is
add
:
if
op
is
add
:
self
.
assertTrue
(
_is_dense_variable
(
apb
))
self
.
assertTrue
(
_is_dense_variable
(
apb
))
self
.
assertTrue
(
numpy
.
all
(
val
==
(
a
+
b
)))
self
.
assertTrue
(
numpy
.
all
(
val
==
(
array1
+
b
)))
ans
=
numpy
.
array
([[
1.
,
2
],
[
3
,
4
],
[
5
,
6
]])
ans
=
numpy
.
array
([[
1.
,
2
],
[
3
,
4
],
[
5
,
6
]])
self
.
assertTrue
(
numpy
.
all
(
val
==
ans
))
self
.
assertTrue
(
numpy
.
all
(
val
==
ans
))
elif
op
is
mul
:
elif
op
is
mul
:
self
.
assertTrue
(
_is_sparse_variable
(
apb
))
self
.
assertTrue
(
_is_sparse_variable
(
apb
))
self
.
assertTrue
(
numpy
.
all
(
val
.
todense
()
==
(
b
.
multiply
(
a
))))
self
.
assertTrue
(
numpy
.
all
(
val
.
todense
()
==
(
b
.
multiply
(
array1
))))
self
.
assertTrue
(
numpy
.
all
(
val
.
todense
()
==
numpy
.
array
(
self
.
assertTrue
(
numpy
.
all
(
val
.
todense
()
==
numpy
.
array
(
[[
1
,
0
],
[
9
,
0
],
[
0
,
36
]])))
[[
1
,
0
],
[
9
,
0
],
[
0
,
36
]])))
def
_testDS
(
self
,
op
,
array1
=
numpy
.
array
([[
1.
,
0
],
[
3
,
0
],
[
0
,
6
]]),
def
_testDS
(
self
,
op
,
array1
=
numpy
.
array
([[
1.
,
0
],
[
3
,
0
],
[
0
,
6
]]),
array2
=
numpy
.
asarray
([[
0
,
2.
],
[
0
,
4
],
[
5
,
0
]])):
array2
=
numpy
.
asarray
([[
0
,
2.
],
[
0
,
4
],
[
5
,
0
]])):
for
mtype
in
_mtypes
:
for
mtype
in
_mtypes
:
for
b
in
[
numpy
.
asarray
(
array2
),
tensor
.
as_tensor_variable
(
array2
)]:
a
=
mtype
(
array1
)
a
=
mtype
(
array1
)
aR
=
as_sparse_variable
(
a
)
aR
=
as_sparse_variable
(
a
)
self
.
assertFalse
(
aR
.
data
is
a
)
self
.
assertFalse
(
aR
.
data
is
a
)
self
.
assertTrue
(
_is_sparse
(
a
))
self
.
assertTrue
(
_is_sparse
(
a
))
self
.
assertTrue
(
_is_sparse_variable
(
aR
))
self
.
assertTrue
(
_is_sparse_variable
(
aR
))
b
=
numpy
.
asarray
(
array2
)
apb
=
op
(
aR
,
b
)
bR
=
tensor
.
as_tensor_variable
(
b
)
self
.
assertFalse
(
bR
.
data
is
b
)
self
.
assertTrue
(
_is_dense
(
b
))
self
.
assertTrue
(
_is_dense_variable
(
bR
))
apb
=
op
(
aR
,
bR
)
self
.
assertTrue
(
apb
.
type
.
dtype
==
aR
.
type
.
dtype
,
apb
.
type
.
dtype
)
self
.
assertTrue
(
apb
.
type
.
dtype
==
aR
.
type
.
dtype
,
apb
.
type
.
dtype
)
self
.
assertTrue
(
apb
.
type
.
dtype
==
bR
.
type
.
dtype
,
apb
.
type
.
dtype
)
self
.
assertTrue
(
apb
.
type
.
dtype
==
b
.
dtype
,
apb
.
type
.
dtype
)
val
=
eval_outputs
([
apb
])
val
=
eval_outputs
([
apb
])
self
.
assertTrue
(
val
.
shape
==
(
3
,
2
))
self
.
assertTrue
(
val
.
shape
==
(
3
,
2
))
if
op
is
add
:
if
op
is
add
:
self
.
assertTrue
(
_is_dense_variable
(
apb
))
self
.
assertTrue
(
_is_dense_variable
(
apb
))
self
.
assertTrue
(
numpy
.
all
(
val
==
(
a
+
b
)))
self
.
assertTrue
(
numpy
.
all
(
val
==
(
a
+
array2
)))
ans
=
numpy
.
array
([[
1.
,
2
],
[
3
,
4
],
[
5
,
6
]])
ans
=
numpy
.
array
([[
1.
,
2
],
[
3
,
4
],
[
5
,
6
]])
self
.
assertTrue
(
numpy
.
all
(
val
==
ans
))
self
.
assertTrue
(
numpy
.
all
(
val
==
ans
))
elif
op
is
mul
:
elif
op
is
mul
:
self
.
assertTrue
(
_is_sparse_variable
(
apb
))
self
.
assertTrue
(
_is_sparse_variable
(
apb
))
ans
=
numpy
.
array
([[
1
,
0
],
[
9
,
0
],
[
0
,
36
]])
ans
=
numpy
.
array
([[
1
,
0
],
[
9
,
0
],
[
0
,
36
]])
self
.
assertTrue
(
numpy
.
all
(
val
.
todense
()
==
(
a
.
multiply
(
b
))))
self
.
assertTrue
(
numpy
.
all
(
val
.
todense
()
==
(
a
.
multiply
(
array2
))))
self
.
assertTrue
(
numpy
.
all
(
val
.
todense
()
==
ans
))
self
.
assertTrue
(
numpy
.
all
(
val
.
todense
()
==
ans
))
def
test_upcast
(
self
):
def
test_upcast
(
self
):
...
@@ -718,16 +708,23 @@ class T_conversion(unittest.TestCase):
...
@@ -718,16 +708,23 @@ class T_conversion(unittest.TestCase):
self
.
assertTrue
(
str
(
val
.
dtype
)
==
'float64'
)
self
.
assertTrue
(
str
(
val
.
dtype
)
==
'float64'
)
self
.
assertTrue
(
val
.
format
==
'csr'
)
self
.
assertTrue
(
val
.
format
==
'csr'
)
if
1
:
def
test_dense_from_sparse
(
self
):
def
test2
(
self
):
#call dense_from_sparse
#call dense_from_sparse
for
t
in
_mtypes
:
for
t
in
_mtypes
:
s
=
t
(
scipy
.
sparse
.
identity
(
5
))
s
=
t
(
scipy
.
sparse
.
identity
(
5
))
s
=
as_sparse_variable
(
s
)
d
=
dense_from_sparse
(
s
)
d
=
dense_from_sparse
(
s
)
# s should be copied into the graph as a constant
s
[
0
,
0
]
=
3.0
# changes s, but not the copy
val
=
eval_outputs
([
d
])
val
=
eval_outputs
([
d
])
return
self
.
assertTrue
(
str
(
val
.
dtype
)
==
s
.
dtype
)
self
.
assertTrue
(
numpy
.
all
(
val
[
0
]
==
[
1
,
0
,
0
,
0
,
0
]))
def
test_todense
(
self
):
#call sparse_var.todense()
for
t
in
_mtypes
:
s
=
t
(
scipy
.
sparse
.
identity
(
5
))
s
=
as_sparse_variable
(
s
)
d
=
s
.
toarray
()
val
=
eval_outputs
([
d
])
self
.
assertTrue
(
str
(
val
.
dtype
)
==
s
.
dtype
)
self
.
assertTrue
(
str
(
val
.
dtype
)
==
s
.
dtype
)
self
.
assertTrue
(
numpy
.
all
(
val
[
0
]
==
[
1
,
0
,
0
,
0
,
0
]))
self
.
assertTrue
(
numpy
.
all
(
val
[
0
]
==
[
1
,
0
,
0
,
0
,
0
]))
...
...
theano/tensor/blas_c.py
浏览文件 @
11c4882a
...
@@ -252,6 +252,8 @@ class CGer(BaseBLAS, Ger):
...
@@ -252,6 +252,8 @@ class CGer(BaseBLAS, Ger):
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
8
,
blas_header_version
())
return
(
8
,
blas_header_version
())
cger_inplace
=
CGer
(
True
)
cger_no_inplace
=
CGer
(
False
)
@local_optimizer
([
ger
,
ger_destructive
])
@local_optimizer
([
ger
,
ger_destructive
])
...
@@ -269,8 +271,8 @@ def use_c_ger(node):
...
@@ -269,8 +271,8 @@ def use_c_ger(node):
@local_optimizer
([
CGer
(
False
)])
@local_optimizer
([
CGer
(
False
)])
def
make_c_ger_destructive
(
node
):
def
make_c_ger_destructive
(
node
):
if
node
.
op
==
CGer
(
False
)
:
if
node
.
op
==
cger_no_inplace
:
return
[
CGer
(
True
)
(
*
node
.
inputs
)]
return
[
cger_inplace
(
*
node
.
inputs
)]
####### ####### #######
####### ####### #######
...
@@ -579,6 +581,8 @@ class CGemv(BaseBLAS, Gemv):
...
@@ -579,6 +581,8 @@ class CGemv(BaseBLAS, Gemv):
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
return
(
10
,
blas_header_version
())
return
(
10
,
blas_header_version
())
cgemv_inplace
=
CGemv
(
inplace
=
True
)
cgemv_no_inplace
=
CGemv
(
inplace
=
False
)
@local_optimizer
([
gemv_inplace
,
gemv_no_inplace
])
@local_optimizer
([
gemv_inplace
,
gemv_no_inplace
])
...
@@ -596,8 +600,8 @@ def use_c_gemv(node):
...
@@ -596,8 +600,8 @@ def use_c_gemv(node):
@local_optimizer
([
CGemv
(
inplace
=
False
)])
@local_optimizer
([
CGemv
(
inplace
=
False
)])
def
make_c_gemv_destructive
(
node
):
def
make_c_gemv_destructive
(
node
):
if
node
.
op
==
CGemv
(
inplace
=
False
)
:
if
node
.
op
==
cgemv_no_inplace
:
return
[
CGemv
(
inplace
=
True
)
(
*
node
.
inputs
)]
return
[
cgemv_inplace
(
*
node
.
inputs
)]
####### ####### #######
####### ####### #######
...
...
theano/tensor/elemwise.py
浏览文件 @
11c4882a
...
@@ -546,7 +546,7 @@ class Elemwise(Op):
...
@@ -546,7 +546,7 @@ class Elemwise(Op):
args
.
append
(
DimShuffle
(
args
.
append
(
DimShuffle
(
input
.
type
.
broadcastable
,
input
.
type
.
broadcastable
,
[
'x'
]
*
difference
+
range
(
length
),
[
'x'
]
*
difference
+
range
(
length
),
inplace
=
Tru
e
)(
input
))
inplace
=
Fals
e
)(
input
))
inputs
=
args
inputs
=
args
#HERE: all the broadcast dims have the same length now
#HERE: all the broadcast dims have the same length now
...
...
theano/tensor/opt.py
浏览文件 @
11c4882a
...
@@ -47,29 +47,43 @@ theano.configparser.AddConfigVar('on_shape_error',
...
@@ -47,29 +47,43 @@ theano.configparser.AddConfigVar('on_shape_error',
# Utilities
# Utilities
def
out2in
(
*
local_opts
):
def
out2in
(
*
local_opts
,
**
kwargs
):
"""WRITEME """
"""WRITEME """
name
=
(
kwargs
and
kwargs
.
pop
(
'name'
,
None
))
if
len
(
local_opts
)
>
1
:
if
len
(
local_opts
)
>
1
:
# Don't wrap it uselessly if their is only 1 optimization.
# Don't wrap it uselessly if their is only 1 optimization.
local_opts
=
opt
.
LocalOptGroup
(
*
local_opts
),
local_opts
=
opt
.
LocalOptGroup
(
*
local_opts
),
else
:
else
:
local_opts
,
=
local_opts
local_opts
,
=
local_opts
return
opt
.
TopoOptimizer
(
local_opts
,
if
not
name
:
name
=
local_opts
.
__name__
ret
=
opt
.
TopoOptimizer
(
local_opts
,
order
=
'out_to_in'
,
order
=
'out_to_in'
,
failure_callback
=
TopoOptimizer
.
warn_inplace
)
failure_callback
=
TopoOptimizer
.
warn_inplace
,
**
kwargs
)
if
name
:
ret
.
__name__
=
name
return
ret
def
in2out
(
*
local_opts
,
**
kwargs
):
def
in2out
(
*
local_opts
,
**
kwargs
):
"""WRITEME """
"""WRITEME """
name
=
(
kwargs
and
kwargs
.
pop
(
'name'
,
None
))
if
len
(
local_opts
)
>
1
:
if
len
(
local_opts
)
>
1
:
# Don't wrap it uselessly if their is only 1 optimization.
# Don't wrap it uselessly if their is only 1 optimization.
local_opts
=
opt
.
LocalOptGroup
(
*
local_opts
),
local_opts
=
opt
.
LocalOptGroup
(
*
local_opts
),
else
:
else
:
local_opts
,
=
local_opts
local_opts
,
=
local_opts
return
opt
.
TopoOptimizer
(
local_opts
,
if
not
name
:
#import pdb;pdb.set_trace()
name
=
local_opts
.
__name__
ret
=
opt
.
TopoOptimizer
(
local_opts
,
order
=
'in_to_out'
,
order
=
'in_to_out'
,
failure_callback
=
TopoOptimizer
.
warn_inplace
,
failure_callback
=
TopoOptimizer
.
warn_inplace
,
**
kwargs
)
**
kwargs
)
if
name
:
ret
.
__name__
=
name
return
ret
def
_fill_chain
(
new_out
,
orig_inputs
):
def
_fill_chain
(
new_out
,
orig_inputs
):
...
@@ -1075,7 +1089,7 @@ class ShapeFeature(object):
...
@@ -1075,7 +1089,7 @@ class ShapeFeature(object):
for
r
,
s
in
izip
(
node
.
outputs
,
o_shapes
):
for
r
,
s
in
izip
(
node
.
outputs
,
o_shapes
):
self
.
set_shape
(
r
,
s
)
self
.
set_shape
(
r
,
s
)
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
):
def
on_change_input
(
self
,
fgraph
,
node
,
i
,
r
,
new_r
,
reason
):
if
new_r
not
in
self
.
shape_of
:
if
new_r
not
in
self
.
shape_of
:
# It happen that the fgraph didn't called on_import for some
# It happen that the fgraph didn't called on_import for some
# new_r. This happen when new_r don't have an
# new_r. This happen when new_r don't have an
...
@@ -2102,6 +2116,14 @@ def local_IncSubtensor_serialize(node):
...
@@ -2102,6 +2116,14 @@ def local_IncSubtensor_serialize(node):
#print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs]
#print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs]
# We register it in a TopoOptimizer inside the canonizer EQ optimizer.
# Otherwise, in some cases, it was making the EQ optimizer use 45 passes. In
# the TopoOptimizer, the EQ only uses 6 passes.
compile
.
optdb
.
register
(
'pre_local_IncSubtensor_serialize'
,
in2out
(
local_IncSubtensor_serialize
),
#Just before canonizer
0.99
,
'fast_run'
)
#after priority 50 Destructive inplace operations
#after priority 50 Destructive inplace operations
#gemm is the first one now, at priority 70
#gemm is the first one now, at priority 70
...
@@ -3717,7 +3739,8 @@ register_specialize(local_add_specialize)
...
@@ -3717,7 +3739,8 @@ register_specialize(local_add_specialize)
# mul_to_neg = out2in(gof.LocalOptGroup(local_mul_to_neg))
# mul_to_neg = out2in(gof.LocalOptGroup(local_mul_to_neg))
mul_canonizer
=
in2out
(
gof
.
LocalOptGroup
(
local_mul_canonizer
,
local_fill_cut
,
mul_canonizer
=
in2out
(
gof
.
LocalOptGroup
(
local_mul_canonizer
,
local_fill_cut
,
local_fill_sink
))
local_fill_sink
),
name
=
'mul_canonizer_groups'
)
def
check_for_x_over_absX
(
numerators
,
denominators
):
def
check_for_x_over_absX
(
numerators
,
denominators
):
...
@@ -3859,7 +3882,8 @@ def add_calculate(num, denum, aslist=False, out_type=None):
...
@@ -3859,7 +3882,8 @@ def add_calculate(num, denum, aslist=False, out_type=None):
local_add_canonizer
=
Canonizer
(
T
.
add
,
T
.
sub
,
T
.
neg
,
add_calculate
)
local_add_canonizer
=
Canonizer
(
T
.
add
,
T
.
sub
,
T
.
neg
,
add_calculate
)
add_canonizer
=
in2out
(
gof
.
LocalOptGroup
(
local_add_canonizer
,
local_fill_cut
,
add_canonizer
=
in2out
(
gof
.
LocalOptGroup
(
local_add_canonizer
,
local_fill_cut
,
local_fill_sink
))
local_fill_sink
),
name
=
'add_canonizer_group'
)
register_canonicalize
(
local_add_canonizer
,
name
=
'local_add_canonizer'
)
register_canonicalize
(
local_add_canonizer
,
name
=
'local_add_canonizer'
)
...
...
theano/tensor/tests/test_opt.py
浏览文件 @
11c4882a
...
@@ -124,13 +124,27 @@ class test_dimshuffle_lift(unittest.TestCase):
...
@@ -124,13 +124,27 @@ class test_dimshuffle_lift(unittest.TestCase):
x
,
y
,
z
=
inputs
([
False
]
*
1
,
[
False
]
*
2
,
[
False
]
*
3
)
x
,
y
,
z
=
inputs
([
False
]
*
1
,
[
False
]
*
2
,
[
False
]
*
3
)
e
=
x
+
y
+
z
e
=
x
+
y
+
z
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
g
=
FunctionGraph
([
x
,
y
,
z
],
[
e
])
self
.
assertTrue
(
str
(
g
)
==
(
"[Elemwise{add,no_inplace}("
"InplaceDimShuffle{x,0,1}(Elemwise{add,no_inplace}"
# It does not really matter if the DimShuffles are inplace
"(InplaceDimShuffle{x,0}(x), y)), z)]"
),
str
(
g
))
# or not.
init_str_g_inplace
=
(
"[Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z)]"
)
init_str_g_noinplace
=
(
"[Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
"(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z)]"
)
self
.
assertTrue
(
str
(
g
)
in
(
init_str_g_inplace
,
init_str_g_noinplace
),
str
(
g
))
opt_str_g_inplace
=
(
"[Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z)]"
)
opt_str_g_noinplace
=
(
"[Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
"(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z)]"
)
dimshuffle_lift
.
optimize
(
g
)
dimshuffle_lift
.
optimize
(
g
)
self
.
assertTrue
(
str
(
g
)
==
(
"[Elemwise{add,no_inplace}(Elemwise"
self
.
assertTrue
(
str
(
g
)
in
(
opt_str_g_inplace
,
opt_str_g_noinplace
),
"{add,no_inplace}(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle"
str
(
g
))
"{x,0,1}(y)), z)]"
),
str
(
g
))
def
test_add_canonizer_problem0
():
def
test_add_canonizer_problem0
():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论