Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1e8b0b00
提交
1e8b0b00
authored
11月 01, 2011
作者:
Frederic
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pep8
上级
e2e178e1
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
115 行增加
和
114 行删除
+115
-114
scan_utils.py
theano/scan_module/scan_utils.py
+115
-114
没有找到文件。
theano/scan_module/scan_utils.py
浏览文件 @
1e8b0b00
...
@@ -4,7 +4,7 @@ This module provides utility functions for the Scan Op
...
@@ -4,7 +4,7 @@ This module provides utility functions for the Scan Op
See scan.py for details on scan
See scan.py for details on scan
"""
"""
__docformat__
=
'restructedtext en'
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"Frederic Bastien "
"James Bergstra "
"James Bergstra "
"Pascal Lamblin "
"Pascal Lamblin "
...
@@ -16,7 +16,6 @@ import copy
...
@@ -16,7 +16,6 @@ import copy
import
logging
import
logging
import
numpy
import
numpy
from
theano
import
config
from
theano.compile.pfunc
import
rebuild_collect_shared
from
theano.compile.pfunc
import
rebuild_collect_shared
from
theano
import
gof
from
theano
import
gof
from
theano
import
tensor
,
scalar
from
theano
import
tensor
,
scalar
...
@@ -30,7 +29,8 @@ import theano
...
@@ -30,7 +29,8 @@ import theano
# Logging function for sending warning or info
# Logging function for sending warning or info
_logger
=
logging
.
getLogger
(
'theano.scan_utils'
)
_logger
=
logging
.
getLogger
(
'theano.scan_utils'
)
def
safe_new
(
x
,
tag
=
''
):
def
safe_new
(
x
,
tag
=
''
):
"""
"""
Internal function that constructs a new variable from x with the same
Internal function that constructs a new variable from x with the same
type, but with a different name ( old name + tag). This function is used
type, but with a different name ( old name + tag). This function is used
...
@@ -81,7 +81,7 @@ class until(object):
...
@@ -81,7 +81,7 @@ class until(object):
assert
self
.
condition
.
ndim
==
0
assert
self
.
condition
.
ndim
==
0
def
traverse
(
out
,
x
,
x_copy
,
d
):
def
traverse
(
out
,
x
,
x_copy
,
d
):
''' Function used by scan to parse the tree and figure out which nodes
''' Function used by scan to parse the tree and figure out which nodes
it needs to replace. There are two options :
it needs to replace. There are two options :
1) x and x_copy or on host, then you would replace x with x_copy
1) x and x_copy or on host, then you would replace x with x_copy
...
@@ -111,10 +111,10 @@ def traverse(out, x,x_copy, d):
...
@@ -111,10 +111,10 @@ def traverse(out, x,x_copy, d):
def
hash_listsDictsTuples
(
x
):
def
hash_listsDictsTuples
(
x
):
hash_value
=
0
hash_value
=
0
if
isinstance
(
x
,
dict
):
if
isinstance
(
x
,
dict
):
for
k
,
v
in
x
.
iteritems
():
for
k
,
v
in
x
.
iteritems
():
hash_value
^=
hash_listsDictsTuples
(
k
)
hash_value
^=
hash_listsDictsTuples
(
k
)
hash_value
^=
hash_listsDictsTuples
(
v
)
hash_value
^=
hash_listsDictsTuples
(
v
)
elif
isinstance
(
x
,
(
list
,
tuple
)):
elif
isinstance
(
x
,
(
list
,
tuple
)):
for
v
in
x
:
for
v
in
x
:
hash_value
^=
hash_listsDictsTuples
(
v
)
hash_value
^=
hash_listsDictsTuples
(
v
)
else
:
else
:
...
@@ -122,10 +122,10 @@ def hash_listsDictsTuples(x):
...
@@ -122,10 +122,10 @@ def hash_listsDictsTuples(x):
return
hash_value
return
hash_value
def
clone
(
output
def
clone
(
output
,
,
replace
=
None
replace
=
None
,
,
strict
=
True
strict
=
True
,
,
copy_inputs
=
True
):
copy_inputs
=
True
):
"""
"""
Function that allows replacing subgraphs of a computational
Function that allows replacing subgraphs of a computational
graph. It returns a copy of the initial subgraph with the corresponding
graph. It returns a copy of the initial subgraph with the corresponding
...
@@ -140,17 +140,16 @@ def clone( output
...
@@ -140,17 +140,16 @@ def clone( output
replaced by what
replaced by what
"""
"""
inps
,
outs
,
other_stuff
=
rebuild_collect_shared
(
output
inps
,
outs
,
other_stuff
=
rebuild_collect_shared
(
output
,
,
[]
[],
,
replace
replace
,
,
[]
[],
,
strict
strict
,
,
copy_inputs
copy_inputs
)
)
return
outs
return
outs
def
get_updates_and_outputs
(
ls
):
def
get_updates_and_outputs
(
ls
):
"""
"""
This function tries to recognize the updates dictionary, the
This function tries to recognize the updates dictionary, the
...
@@ -160,7 +159,7 @@ def get_updates_and_outputs(ls):
...
@@ -160,7 +159,7 @@ def get_updates_and_outputs(ls):
"""
"""
def
is_outputs
(
elem
):
def
is_outputs
(
elem
):
if
(
isinstance
(
elem
,
(
list
,
tuple
))
and
if
(
isinstance
(
elem
,
(
list
,
tuple
))
and
all
([
isinstance
(
x
,
theano
.
Variable
)
for
x
in
elem
])):
all
([
isinstance
(
x
,
theano
.
Variable
)
for
x
in
elem
])):
return
True
return
True
if
isinstance
(
elem
,
theano
.
Variable
):
if
isinstance
(
elem
,
theano
.
Variable
):
...
@@ -172,7 +171,7 @@ def get_updates_and_outputs(ls):
...
@@ -172,7 +171,7 @@ def get_updates_and_outputs(ls):
return
True
return
True
# Dictionaries can be given as lists of tuples
# Dictionaries can be given as lists of tuples
if
(
isinstance
(
elem
,
(
list
,
tuple
))
and
if
(
isinstance
(
elem
,
(
list
,
tuple
))
and
all
([
isinstance
(
x
,
(
list
,
tuple
))
and
len
(
x
)
==
2
all
([
isinstance
(
x
,
(
list
,
tuple
))
and
len
(
x
)
==
2
for
x
in
elem
])):
for
x
in
elem
])):
return
True
return
True
return
False
return
False
...
@@ -204,13 +203,13 @@ def get_updates_and_outputs(ls):
...
@@ -204,13 +203,13 @@ def get_updates_and_outputs(ls):
if
is_updates
(
ls
[
1
]):
if
is_updates
(
ls
[
1
]):
return
(
None
,
_list
(
ls
[
0
]),
dict
(
ls
[
1
]))
return
(
None
,
_list
(
ls
[
0
]),
dict
(
ls
[
1
]))
elif
is_condition
(
ls
[
1
]):
elif
is_condition
(
ls
[
1
]):
return
(
ls
[
1
]
.
condition
,
_list
(
ls
[
0
]),
{})
return
(
ls
[
1
]
.
condition
,
_list
(
ls
[
0
]),
{})
else
:
else
:
raise
ValueError
(
error_msg
)
raise
ValueError
(
error_msg
)
elif
is_updates
(
ls
[
0
]):
elif
is_updates
(
ls
[
0
]):
if
is_outputs
(
ls
[
1
]):
if
is_outputs
(
ls
[
1
]):
_logger
.
warning
(
deprication_msg
)
_logger
.
warning
(
deprication_msg
)
return
(
None
,
_list
(
ls
[
1
]),
dict
(
ls
[
0
])
)
return
(
None
,
_list
(
ls
[
1
]),
dict
(
ls
[
0
])
)
elif
is_condition
(
ls
[
1
]):
elif
is_condition
(
ls
[
1
]):
return
(
ls
[
1
]
.
condition
,
[],
dict
(
ls
[
0
]))
return
(
ls
[
1
]
.
condition
,
[],
dict
(
ls
[
0
]))
else
:
else
:
...
@@ -264,7 +263,7 @@ def isNaN_or_Inf_or_None(x):
...
@@ -264,7 +263,7 @@ def isNaN_or_Inf_or_None(x):
return
isNone
or
isNaN
or
isInf
or
isStr
return
isNone
or
isNaN
or
isInf
or
isStr
def
expand
(
tensor_var
,
size
):
def
expand
(
tensor_var
,
size
):
'''
'''
Transoforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..)
Transoforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..)
by adding 0s at the end of the tensor.
by adding 0s at the end of the tensor.
...
@@ -272,13 +271,14 @@ def expand( tensor_var, size):
...
@@ -272,13 +271,14 @@ def expand( tensor_var, size):
# Corner case that I might use in an optimization
# Corner case that I might use in an optimization
if
size
==
0
:
if
size
==
0
:
return
tensor_var
return
tensor_var
shapes
=
[
tensor_var
.
shape
[
x
]
for
x
in
xrange
(
tensor_var
.
ndim
)
]
shapes
=
[
tensor_var
.
shape
[
x
]
for
x
in
xrange
(
tensor_var
.
ndim
)
]
zeros_shape
=
[
size
+
shapes
[
0
]]
+
shapes
[
1
:]
zeros_shape
=
[
size
+
shapes
[
0
]]
+
shapes
[
1
:]
empty
=
tensor
.
zeros
(
zeros_shape
empty
=
tensor
.
zeros
(
zeros_shape
,
,
dtype
=
tensor_var
.
dtype
)
dtype
=
tensor_var
.
dtype
)
return
tensor
.
set_subtensor
(
empty
[:
shapes
[
0
]],
tensor_var
)
return
tensor
.
set_subtensor
(
empty
[:
shapes
[
0
]],
tensor_var
)
def
equal_computations
(
xs
,
ys
,
in_xs
=
None
,
in_ys
=
None
,
strict
=
True
):
def
equal_computations
(
xs
,
ys
,
in_xs
=
None
,
in_ys
=
None
,
strict
=
True
):
'''
'''
Checks if to theano graphs represent the same computations (with
Checks if to theano graphs represent the same computations (with
equivalence of inputs defined by map). Inputs are always assumed
equivalence of inputs defined by map). Inputs are always assumed
...
@@ -289,8 +289,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
...
@@ -289,8 +289,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if
in_ys
is
None
:
if
in_ys
is
None
:
in_ys
=
[]
in_ys
=
[]
for
x
,
y
in
zip
(
xs
,
ys
):
for
x
,
y
in
zip
(
xs
,
ys
):
if
x
.
owner
and
not
y
.
owner
:
if
x
.
owner
and
not
y
.
owner
:
return
False
return
False
if
y
.
owner
and
not
x
.
owner
:
if
y
.
owner
and
not
x
.
owner
:
...
@@ -300,7 +299,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
...
@@ -300,7 +299,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
return
False
return
False
if
len
(
in_xs
)
!=
len
(
in_ys
):
if
len
(
in_xs
)
!=
len
(
in_ys
):
return
False
return
False
for
_x
,
_y
in
zip
(
in_xs
,
in_ys
):
for
_x
,
_y
in
zip
(
in_xs
,
in_ys
):
if
_x
.
type
!=
_y
.
type
:
if
_x
.
type
!=
_y
.
type
:
return
False
return
False
...
@@ -308,17 +307,17 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
...
@@ -308,17 +307,17 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
nds_y
=
gof
.
graph
.
io_toposort
(
in_ys
,
ys
)
nds_y
=
gof
.
graph
.
io_toposort
(
in_ys
,
ys
)
if
len
(
nds_x
)
!=
len
(
nds_y
):
if
len
(
nds_x
)
!=
len
(
nds_y
):
return
False
return
False
common
=
set
(
zip
(
in_xs
,
in_ys
))
common
=
set
(
zip
(
in_xs
,
in_ys
))
n_nodes
=
len
(
nds_x
)
n_nodes
=
len
(
nds_x
)
cont
=
True
cont
=
True
idx
=
0
idx
=
0
for
dx
,
dy
in
zip
(
xs
,
ys
):
for
dx
,
dy
in
zip
(
xs
,
ys
):
if
not
dx
.
owner
or
not
dy
.
owner
:
if
not
dx
.
owner
or
not
dy
.
owner
:
if
dy
.
owner
or
dx
.
owner
:
if
dy
.
owner
or
dx
.
owner
:
return
False
return
False
elif
(
isinstance
(
dx
,
tensor
.
Constant
)
and
elif
(
isinstance
(
dx
,
tensor
.
Constant
)
and
isinstance
(
dy
,
tensor
.
Constant
)):
isinstance
(
dy
,
tensor
.
Constant
)):
if
not
(
numpy
.
all
(
dx
.
data
==
dy
.
data
)
and
if
not
(
numpy
.
all
(
dx
.
data
==
dy
.
data
)
and
dx
.
dtype
==
dy
.
dtype
and
dx
.
dtype
==
dy
.
dtype
and
dx
.
data
.
shape
==
dy
.
data
.
shape
):
dx
.
data
.
shape
==
dy
.
data
.
shape
):
return
False
return
False
...
@@ -329,7 +328,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
...
@@ -329,7 +328,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if
dx
.
type
!=
dy
.
type
:
if
dx
.
type
!=
dy
.
type
:
return
False
return
False
else
:
else
:
if
(
dx
,
dy
)
not
in
common
:
if
(
dx
,
dy
)
not
in
common
:
return
False
return
False
while
cont
and
idx
<
n_nodes
:
while
cont
and
idx
<
n_nodes
:
...
@@ -342,9 +341,9 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
...
@@ -342,9 +341,9 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
elif
len
(
nd_x
.
outputs
)
!=
len
(
nd_y
.
outputs
):
elif
len
(
nd_x
.
outputs
)
!=
len
(
nd_y
.
outputs
):
cont
=
False
cont
=
False
else
:
else
:
for
dx
,
dy
in
zip
(
nd_x
.
inputs
,
nd_y
.
inputs
):
for
dx
,
dy
in
zip
(
nd_x
.
inputs
,
nd_y
.
inputs
):
if
(
dx
,
dy
)
not
in
common
:
if
(
dx
,
dy
)
not
in
common
:
if
strict
and
dx
!=
dy
:
if
strict
and
dx
!=
dy
:
if
(
isinstance
(
dx
,
tensor
.
Constant
)
and
if
(
isinstance
(
dx
,
tensor
.
Constant
)
and
isinstance
(
dy
,
tensor
.
Constant
)):
isinstance
(
dy
,
tensor
.
Constant
)):
if
not
(
numpy
.
all
(
dx
.
data
==
dy
.
data
)
and
if
not
(
numpy
.
all
(
dx
.
data
==
dy
.
data
)
and
...
@@ -359,16 +358,13 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
...
@@ -359,16 +358,13 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
cont
=
cont
and
(
dx
.
type
==
dy
.
type
)
cont
=
cont
and
(
dx
.
type
==
dy
.
type
)
if
cont
:
if
cont
:
for
dx
,
dy
in
zip
(
nd_x
.
outputs
,
nd_y
.
outputs
):
for
dx
,
dy
in
zip
(
nd_x
.
outputs
,
nd_y
.
outputs
):
common
.
add
((
dx
,
dy
))
common
.
add
((
dx
,
dy
))
idx
+=
1
idx
+=
1
return
cont
return
cont
def
infer_shape
(
outs
,
inputs
,
input_shapes
):
def
infer_shape
(
outs
,
inputs
,
input_shapes
):
'''
'''
Compute the shape of the outputs given the shape of the inputs
Compute the shape of the outputs given the shape of the inputs
...
@@ -416,6 +412,7 @@ def infer_shape(outs, inputs, input_shapes):
...
@@ -416,6 +412,7 @@ def infer_shape(outs, inputs, input_shapes):
ret
.
append
(
shape_feature
.
shape_of
[
o
])
ret
.
append
(
shape_feature
.
shape_of
[
o
])
return
ret
return
ret
class
Validator
(
object
):
class
Validator
(
object
):
def
__init__
(
self
,
valid
=
[],
invalid
=
[],
valid_equivalent
=
{}):
def
__init__
(
self
,
valid
=
[],
invalid
=
[],
valid_equivalent
=
{}):
'''
'''
...
@@ -494,7 +491,7 @@ def scan_can_remove_outs(op, out_idxs):
...
@@ -494,7 +491,7 @@ def scan_can_remove_outs(op, out_idxs):
the first one with the indices of outs that can be removed, the second
the first one with the indices of outs that can be removed, the second
with the outputs that can not be removed.
with the outputs that can not be removed.
'''
'''
non_removable
=
[
o
for
i
,
o
in
enumerate
(
op
.
outputs
)
if
i
not
in
non_removable
=
[
o
for
i
,
o
in
enumerate
(
op
.
outputs
)
if
i
not
in
out_idxs
]
out_idxs
]
required_inputs
=
gof
.
graph
.
inputs
(
non_removable
)
required_inputs
=
gof
.
graph
.
inputs
(
non_removable
)
...
@@ -503,26 +500,26 @@ def scan_can_remove_outs(op, out_idxs):
...
@@ -503,26 +500,26 @@ def scan_can_remove_outs(op, out_idxs):
lim
=
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
lim
=
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
for
idx
in
range
(
lim
):
for
idx
in
range
(
lim
):
n_ins
=
len
(
op
.
info
[
'tap_array'
][
idx
])
n_ins
=
len
(
op
.
info
[
'tap_array'
][
idx
])
out_ins
+=
[
op
.
inputs
[
offset
:
offset
+
n_ins
]]
out_ins
+=
[
op
.
inputs
[
offset
:
offset
+
n_ins
]]
offset
+=
n_ins
offset
+=
n_ins
out_ins
+=
[
[]
for
k
in
xrange
(
op
.
n_nit_sot
)
]
out_ins
+=
[
[]
for
k
in
xrange
(
op
.
n_nit_sot
)
]
out_ins
+=
[
[
op
.
inputs
[
offset
+
k
]]
for
k
in
xrange
(
op
.
n_shared_outs
)]
out_ins
+=
[
[
op
.
inputs
[
offset
+
k
]]
for
k
in
xrange
(
op
.
n_shared_outs
)]
added
=
True
added
=
True
out_idxs_mask
=
[
1
for
idx
in
out_idxs
]
out_idxs_mask
=
[
1
for
idx
in
out_idxs
]
while
added
:
while
added
:
added
=
False
added
=
False
for
pos
,
idx
in
enumerate
(
out_idxs
):
for
pos
,
idx
in
enumerate
(
out_idxs
):
if
(
out_idxs_mask
[
pos
]
and
if
(
out_idxs_mask
[
pos
]
and
numpy
.
any
([
x
in
required_inputs
for
x
in
out_ins
[
idx
]])
):
numpy
.
any
([
x
in
required_inputs
for
x
in
out_ins
[
idx
]])):
# This output is required ..
# This output is required ..
out_idxs_mask
[
pos
]
=
0
out_idxs_mask
[
pos
]
=
0
required_inputs
+=
gof
.
graph
.
inputs
([
op
.
outputs
[
idx
]])
required_inputs
+=
gof
.
graph
.
inputs
([
op
.
outputs
[
idx
]])
added
=
True
added
=
True
required_outs
=
[
x
for
i
,
x
in
enumerate
(
out_idxs
)
required_outs
=
[
x
for
i
,
x
in
enumerate
(
out_idxs
)
if
out_idxs_mask
[
i
]
==
0
]
if
out_idxs_mask
[
i
]
==
0
]
not_required
=
[
x
for
i
,
x
in
enumerate
(
out_idxs
)
if
out_idxs_mask
[
i
]
==
1
]
not_required
=
[
x
for
i
,
x
in
enumerate
(
out_idxs
)
if
out_idxs_mask
[
i
]
==
1
]
return
(
required_outs
,
not_required
)
return
(
required_outs
,
not_required
)
...
@@ -560,84 +557,84 @@ def compress_outs(op, not_required, inputs):
...
@@ -560,84 +557,84 @@ def compress_outs(op, not_required, inputs):
map_old_new
=
{}
map_old_new
=
{}
offset
=
0
offset
=
0
ni_offset
=
op
.
n_seqs
+
1
ni_offset
=
op
.
n_seqs
+
1
i_offset
=
op
.
n_seqs
i_offset
=
op
.
n_seqs
o_offset
=
0
o_offset
=
0
curr_pos
=
0
curr_pos
=
0
for
idx
in
xrange
(
op
.
info
[
'n_mit_mot'
]):
for
idx
in
xrange
(
op
.
info
[
'n_mit_mot'
]):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
'n_mit_mot'
]
+=
1
info
[
'n_mit_mot'
]
+=
1
info
[
'tap_array'
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
[
'tap_array'
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
[
'mit_mot_out_slices'
]
+=
[
op
.
mit_mot_out_slices
[
offset
+
idx
]]
info
[
'mit_mot_out_slices'
]
+=
[
op
.
mit_mot_out_slices
[
offset
+
idx
]]
# input taps
# input taps
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
i_offset
+=
1
i_offset
+=
1
# output taps
# output taps
for
jdx
in
op
.
mit_mot_out_slices
[
offset
+
idx
]:
for
jdx
in
op
.
mit_mot_out_slices
[
offset
+
idx
]:
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
o_offset
+=
1
# node inputs
# node inputs
node_inputs
+=
[
inputs
[
ni_offset
+
idx
]]
node_inputs
+=
[
inputs
[
ni_offset
+
idx
]]
else
:
else
:
o_offset
+=
len
(
op
.
mit_mot_out_slices
[
offset
+
idx
])
o_offset
+=
len
(
op
.
mit_mot_out_slices
[
offset
+
idx
])
i_offset
+=
len
(
op
.
tap_array
[
offset
+
idx
])
i_offset
+=
len
(
op
.
tap_array
[
offset
+
idx
])
info
[
'n_mit_mot_outs'
]
=
len
(
op_outputs
)
info
[
'n_mit_mot_outs'
]
=
len
(
op_outputs
)
offset
+=
op
.
n_mit_mot
offset
+=
op
.
n_mit_mot
ni_offset
+=
op
.
n_mit_mot
ni_offset
+=
op
.
n_mit_mot
for
idx
in
xrange
(
op
.
info
[
'n_mit_sot'
]):
for
idx
in
xrange
(
op
.
info
[
'n_mit_sot'
]):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
'n_mit_sot'
]
+=
1
info
[
'n_mit_sot'
]
+=
1
info
[
'tap_array'
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
[
'tap_array'
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
#input taps
#input taps
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
i_offset
+=
1
i_offset
+=
1
#output taps
#output taps
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
o_offset
+=
1
#node inputs
#node inputs
node_inputs
+=
[
inputs
[
ni_offset
+
idx
]]
node_inputs
+=
[
inputs
[
ni_offset
+
idx
]]
else
:
else
:
o_offset
+=
1
o_offset
+=
1
i_offset
+=
len
(
op
.
tap_array
[
offset
+
idx
])
i_offset
+=
len
(
op
.
tap_array
[
offset
+
idx
])
offset
+=
op
.
n_mit_sot
offset
+=
op
.
n_mit_sot
ni_offset
+=
op
.
n_mit_sot
ni_offset
+=
op
.
n_mit_sot
for
idx
in
xrange
(
op
.
info
[
'n_sit_sot'
]):
for
idx
in
xrange
(
op
.
info
[
'n_sit_sot'
]):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
'n_sit_sot'
]
+=
1
info
[
'n_sit_sot'
]
+=
1
info
[
'tap_array'
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
[
'tap_array'
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
#input taps
#input taps
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
i_offset
+=
1
i_offset
+=
1
#output taps
#output taps
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
o_offset
+=
1
#node inputs
#node inputs
node_inputs
+=
[
inputs
[
ni_offset
+
idx
]]
node_inputs
+=
[
inputs
[
ni_offset
+
idx
]]
else
:
else
:
o_offset
+=
1
o_offset
+=
1
i_offset
+=
1
i_offset
+=
1
offset
+=
op
.
n_sit_sot
offset
+=
op
.
n_sit_sot
ni_offset
+=
op
.
n_sit_sot
ni_offset
+=
op
.
n_sit_sot
nit_sot_ins
=
[]
nit_sot_ins
=
[]
for
idx
in
xrange
(
op
.
info
[
'n_nit_sot'
]):
for
idx
in
xrange
(
op
.
info
[
'n_nit_sot'
]):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
'n_nit_sot'
]
+=
1
info
[
'n_nit_sot'
]
+=
1
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
o_offset
+=
1
nit_sot_ins
+=
[
inputs
[
ni_offset
+
idx
+
op
.
n_shared_outs
]]
nit_sot_ins
+=
[
inputs
[
ni_offset
+
idx
+
op
.
n_shared_outs
]]
else
:
else
:
o_offset
+=
1
o_offset
+=
1
...
@@ -645,14 +642,14 @@ def compress_outs(op, not_required, inputs):
...
@@ -645,14 +642,14 @@ def compress_outs(op, not_required, inputs):
shared_ins
=
[]
shared_ins
=
[]
for
idx
in
xrange
(
op
.
info
[
'n_shared_outs'
]):
for
idx
in
xrange
(
op
.
info
[
'n_shared_outs'
]):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
'n_shared_outs'
]
+=
1
info
[
'n_shared_outs'
]
+=
1
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
o_offset
+=
1
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
i_offset
+=
1
i_offset
+=
1
shared_ins
+=
[
inputs
[
ni_offset
+
idx
]]
shared_ins
+=
[
inputs
[
ni_offset
+
idx
]]
else
:
else
:
o_offset
+=
1
o_offset
+=
1
i_offset
+=
1
i_offset
+=
1
...
@@ -660,14 +657,15 @@ def compress_outs(op, not_required, inputs):
...
@@ -660,14 +657,15 @@ def compress_outs(op, not_required, inputs):
node_inputs
+=
nit_sot_ins
node_inputs
+=
nit_sot_ins
# other stuff
# other stuff
op_inputs
+=
op
.
inputs
[
i_offset
:]
op_inputs
+=
op
.
inputs
[
i_offset
:]
node_inputs
+=
inputs
[
ni_offset
+
op
.
n_shared_outs
+
op
.
n_nit_sot
:]
node_inputs
+=
inputs
[
ni_offset
+
op
.
n_shared_outs
+
op
.
n_nit_sot
:]
if
op
.
as_while
:
if
op
.
as_while
:
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
map_old_new
[
o_offset
]
=
len
(
op_outputs
)
-
1
map_old_new
[
o_offset
]
=
len
(
op_outputs
)
-
1
#map_old_new[len(op_outputs)-1] = o_offset
#map_old_new[len(op_outputs)-1] = o_offset
return
(
op_inputs
,
op_outputs
,
info
,
node_inputs
,
map_old_new
)
return
(
op_inputs
,
op_outputs
,
info
,
node_inputs
,
map_old_new
)
def
find_up
(
l_node
,
f_node
):
def
find_up
(
l_node
,
f_node
):
r"""
r"""
Goes up in the graph and returns True if a node in nodes is found.
Goes up in the graph and returns True if a node in nodes is found.
...
@@ -680,7 +678,8 @@ def find_up(l_node, f_node):
...
@@ -680,7 +678,8 @@ def find_up(l_node, f_node):
nodes
=
gof
.
graph
.
io_toposort
(
l_ins
,
l_outs
)
nodes
=
gof
.
graph
.
io_toposort
(
l_ins
,
l_outs
)
return
f_node
in
nodes
return
f_node
in
nodes
def
reconstruct_graph
(
inputs
,
outputs
,
tag
=
None
):
def
reconstruct_graph
(
inputs
,
outputs
,
tag
=
None
):
"""
"""
Different interface to clone, that allows you to pass inputs.
Different interface to clone, that allows you to pass inputs.
Compared to clone, this method always replaces the inputs with
Compared to clone, this method always replaces the inputs with
...
@@ -689,7 +688,7 @@ def reconstruct_graph(inputs, outputs, tag = None):
...
@@ -689,7 +688,7 @@ def reconstruct_graph(inputs, outputs, tag = None):
"""
"""
if
tag
is
None
:
if
tag
is
None
:
tag
=
''
tag
=
''
nw_inputs
=
[
safe_new
(
x
,
tag
)
for
x
in
inputs
]
nw_inputs
=
[
safe_new
(
x
,
tag
)
for
x
in
inputs
]
givens
=
{}
givens
=
{}
for
nw_x
,
x
in
zip
(
nw_inputs
,
inputs
):
for
nw_x
,
x
in
zip
(
nw_inputs
,
inputs
):
givens
[
x
]
=
nw_x
givens
[
x
]
=
nw_x
...
@@ -698,9 +697,10 @@ def reconstruct_graph(inputs, outputs, tag = None):
...
@@ -698,9 +697,10 @@ def reconstruct_graph(inputs, outputs, tag = None):
if
isinstance
(
inp
,
theano
.
Constant
):
if
isinstance
(
inp
,
theano
.
Constant
):
givens
[
inp
]
=
inp
.
clone
()
givens
[
inp
]
=
inp
.
clone
()
nw_outputs
=
clone
(
outputs
,
replace
=
givens
)
nw_outputs
=
clone
(
outputs
,
replace
=
givens
)
return
(
nw_inputs
,
nw_outputs
)
return
(
nw_inputs
,
nw_outputs
)
class
scan_args
(
object
):
class
scan_args
(
object
):
"""Parses the inputs and outputs of scan in an easy to manipulate format"""
"""Parses the inputs and outputs of scan in an easy to manipulate format"""
def
__init__
(
self
,
outer_inputs
,
outer_outputs
,
def
__init__
(
self
,
outer_inputs
,
outer_outputs
,
...
@@ -718,8 +718,8 @@ class scan_args(object):
...
@@ -718,8 +718,8 @@ class scan_args(object):
q
=
0
q
=
0
n_seqs
=
info
[
'n_seqs'
]
n_seqs
=
info
[
'n_seqs'
]
self
.
outer_in_seqs
=
outer_inputs
[
p
:
p
+
n_seqs
]
self
.
outer_in_seqs
=
outer_inputs
[
p
:
p
+
n_seqs
]
self
.
inner_in_seqs
=
inner_inputs
[
q
:
q
+
n_seqs
]
self
.
inner_in_seqs
=
inner_inputs
[
q
:
q
+
n_seqs
]
p
+=
n_seqs
p
+=
n_seqs
q
+=
n_seqs
q
+=
n_seqs
...
@@ -727,46 +727,47 @@ class scan_args(object):
...
@@ -727,46 +727,47 @@ class scan_args(object):
n_mit_sot
=
info
[
'n_mit_sot'
]
n_mit_sot
=
info
[
'n_mit_sot'
]
self
.
mit_mot_in_slices
=
info
[
'tap_array'
][:
n_mit_mot
]
self
.
mit_mot_in_slices
=
info
[
'tap_array'
][:
n_mit_mot
]
self
.
mit_sot_in_slices
=
info
[
'tap_array'
][
n_mit_mot
:
n_mit_mot
+
n_mit_sot
]
self
.
mit_sot_in_slices
=
info
[
'tap_array'
][
n_mit_mot
:
n_mit_mot
+
n_mit_sot
]
n_mit_mot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_in_slices
)
n_mit_mot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_in_slices
)
n_mit_sot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_sot_in_slices
)
n_mit_sot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_sot_in_slices
)
iimm
=
inner_inputs
[
q
:
q
+
n_mit_mot_ins
]
iimm
=
inner_inputs
[
q
:
q
+
n_mit_mot_ins
]
self
.
inner_in_mit_mot
=
[]
self
.
inner_in_mit_mot
=
[]
qq
=
0
qq
=
0
for
sl
in
self
.
mit_mot_in_slices
:
for
sl
in
self
.
mit_mot_in_slices
:
self
.
inner_in_mit_mot
.
append
(
iimm
[
qq
:
qq
+
len
(
sl
)])
self
.
inner_in_mit_mot
.
append
(
iimm
[
qq
:
qq
+
len
(
sl
)])
qq
+=
len
(
sl
)
qq
+=
len
(
sl
)
q
+=
n_mit_mot_ins
q
+=
n_mit_mot_ins
iims
=
inner_inputs
[
q
:
q
+
n_mit_sot_ins
]
iims
=
inner_inputs
[
q
:
q
+
n_mit_sot_ins
]
self
.
inner_in_mit_sot
=
[]
self
.
inner_in_mit_sot
=
[]
qq
=
0
qq
=
0
for
sl
in
self
.
mit_sot_in_slices
:
for
sl
in
self
.
mit_sot_in_slices
:
self
.
inner_in_mit_sot
.
append
(
iims
[
qq
:
qq
+
len
(
sl
)])
self
.
inner_in_mit_sot
.
append
(
iims
[
qq
:
qq
+
len
(
sl
)])
qq
+=
len
(
sl
)
qq
+=
len
(
sl
)
q
+=
n_mit_sot_ins
q
+=
n_mit_sot_ins
self
.
outer_in_mit_mot
=
outer_inputs
[
p
:
p
+
n_mit_mot
]
self
.
outer_in_mit_mot
=
outer_inputs
[
p
:
p
+
n_mit_mot
]
p
+=
n_mit_mot
p
+=
n_mit_mot
self
.
outer_in_mit_sot
=
outer_inputs
[
p
:
p
+
n_mit_sot
]
self
.
outer_in_mit_sot
=
outer_inputs
[
p
:
p
+
n_mit_sot
]
p
+=
n_mit_sot
p
+=
n_mit_sot
n_sit_sot
=
info
[
'n_sit_sot'
]
n_sit_sot
=
info
[
'n_sit_sot'
]
self
.
outer_in_sit_sot
=
outer_inputs
[
p
:
p
+
n_sit_sot
]
self
.
outer_in_sit_sot
=
outer_inputs
[
p
:
p
+
n_sit_sot
]
self
.
inner_in_sit_sot
=
inner_inputs
[
q
:
q
+
n_sit_sot
]
self
.
inner_in_sit_sot
=
inner_inputs
[
q
:
q
+
n_sit_sot
]
p
+=
n_sit_sot
p
+=
n_sit_sot
q
+=
n_sit_sot
q
+=
n_sit_sot
n_shared_outs
=
info
[
'n_shared_outs'
]
n_shared_outs
=
info
[
'n_shared_outs'
]
self
.
outer_in_shared
=
outer_inputs
[
p
:
p
+
n_shared_outs
]
self
.
outer_in_shared
=
outer_inputs
[
p
:
p
+
n_shared_outs
]
self
.
inner_in_shared
=
inner_inputs
[
q
:
q
+
n_shared_outs
]
self
.
inner_in_shared
=
inner_inputs
[
q
:
q
+
n_shared_outs
]
p
+=
n_shared_outs
p
+=
n_shared_outs
q
+=
n_shared_outs
q
+=
n_shared_outs
n_nit_sot
=
info
[
'n_nit_sot'
]
n_nit_sot
=
info
[
'n_nit_sot'
]
self
.
outer_in_nit_sot
=
outer_inputs
[
p
:
p
+
n_nit_sot
]
self
.
outer_in_nit_sot
=
outer_inputs
[
p
:
p
+
n_nit_sot
]
p
+=
n_nit_sot
p
+=
n_nit_sot
self
.
outer_in_non_seqs
=
outer_inputs
[
p
:]
self
.
outer_in_non_seqs
=
outer_inputs
[
p
:]
...
@@ -778,40 +779,39 @@ class scan_args(object):
...
@@ -778,40 +779,39 @@ class scan_args(object):
self
.
mit_mot_out_slices
=
info
[
'mit_mot_out_slices'
]
self
.
mit_mot_out_slices
=
info
[
'mit_mot_out_slices'
]
n_mit_mot_outs
=
info
[
'n_mit_mot_outs'
]
n_mit_mot_outs
=
info
[
'n_mit_mot_outs'
]
self
.
outer_out_mit_mot
=
outer_outputs
[
p
:
p
+
n_mit_mot
]
self
.
outer_out_mit_mot
=
outer_outputs
[
p
:
p
+
n_mit_mot
]
iomm
=
inner_outputs
[
q
:
q
+
n_mit_mot_outs
]
iomm
=
inner_outputs
[
q
:
q
+
n_mit_mot_outs
]
self
.
inner_out_mit_mot
=
[]
self
.
inner_out_mit_mot
=
[]
qq
=
0
qq
=
0
for
sl
in
self
.
mit_mot_out_slices
:
for
sl
in
self
.
mit_mot_out_slices
:
self
.
inner_out_mit_mot
.
append
(
iomm
[
qq
:
qq
+
len
(
sl
)])
self
.
inner_out_mit_mot
.
append
(
iomm
[
qq
:
qq
+
len
(
sl
)])
qq
+=
len
(
sl
)
qq
+=
len
(
sl
)
p
+=
n_mit_mot
p
+=
n_mit_mot
q
+=
n_mit_mot_outs
q
+=
n_mit_mot_outs
self
.
outer_out_mit_sot
=
outer_outputs
[
p
:
p
+
n_mit_sot
]
self
.
outer_out_mit_sot
=
outer_outputs
[
p
:
p
+
n_mit_sot
]
self
.
inner_out_mit_sot
=
inner_outputs
[
q
:
q
+
n_mit_sot
]
self
.
inner_out_mit_sot
=
inner_outputs
[
q
:
q
+
n_mit_sot
]
p
+=
n_mit_sot
p
+=
n_mit_sot
q
+=
n_mit_sot
q
+=
n_mit_sot
self
.
outer_out_sit_sot
=
outer_outputs
[
p
:
p
+
n_sit_sot
]
self
.
outer_out_sit_sot
=
outer_outputs
[
p
:
p
+
n_sit_sot
]
self
.
inner_out_sit_sot
=
inner_outputs
[
q
:
q
+
n_sit_sot
]
self
.
inner_out_sit_sot
=
inner_outputs
[
q
:
q
+
n_sit_sot
]
p
+=
n_sit_sot
p
+=
n_sit_sot
q
+=
n_sit_sot
q
+=
n_sit_sot
self
.
outer_out_nit_sot
=
outer_outputs
[
p
:
p
+
n_nit_sot
]
self
.
outer_out_nit_sot
=
outer_outputs
[
p
:
p
+
n_nit_sot
]
self
.
inner_out_nit_sot
=
inner_outputs
[
q
:
q
+
n_nit_sot
]
self
.
inner_out_nit_sot
=
inner_outputs
[
q
:
q
+
n_nit_sot
]
p
+=
n_nit_sot
p
+=
n_nit_sot
q
+=
n_nit_sot
q
+=
n_nit_sot
self
.
outer_out_shared
=
outer_outputs
[
p
:
p
+
n_shared_outs
]
self
.
outer_out_shared
=
outer_outputs
[
p
:
p
+
n_shared_outs
]
self
.
inner_out_shared
=
inner_outputs
[
q
:
q
+
n_shared_outs
]
self
.
inner_out_shared
=
inner_outputs
[
q
:
q
+
n_shared_outs
]
p
+=
n_shared_outs
p
+=
n_shared_outs
q
+=
n_shared_outs
q
+=
n_shared_outs
self
.
other_info
=
dict
()
self
.
other_info
=
dict
()
for
k
in
(
'truncate_gradient'
,
'name'
,
'mode'
,
'inplace'
,
for
k
in
(
'truncate_gradient'
,
'name'
,
'mode'
,
'inplace'
,
'gpu'
,
'as_while'
,
'profile'
):
'gpu'
,
'as_while'
,
'profile'
):
self
.
other_info
[
k
]
=
info
[
k
]
self
.
other_info
[
k
]
=
info
[
k
]
inner_inputs
=
property
(
lambda
self
:
(
self
.
inner_in_seqs
+
inner_inputs
=
property
(
lambda
self
:
(
self
.
inner_in_seqs
+
...
@@ -842,7 +842,8 @@ class scan_args(object):
...
@@ -842,7 +842,8 @@ class scan_args(object):
self
.
outer_out_nit_sot
+
self
.
outer_out_nit_sot
+
self
.
outer_out_shared
))
self
.
outer_out_shared
))
info
=
property
(
lambda
self
:
dict
(
n_seqs
=
len
(
self
.
outer_in_seqs
),
info
=
property
(
lambda
self
:
dict
(
n_seqs
=
len
(
self
.
outer_in_seqs
),
n_mit_mot
=
len
(
self
.
outer_in_mit_mot
),
n_mit_mot
=
len
(
self
.
outer_in_mit_mot
),
n_mit_sot
=
len
(
self
.
outer_in_mit_sot
),
n_mit_sot
=
len
(
self
.
outer_in_mit_sot
),
tap_array
=
(
self
.
mit_mot_in_slices
+
tap_array
=
(
self
.
mit_mot_in_slices
+
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论