Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
b3cf9269
提交
b3cf9269
authored
11月 01, 2011
作者:
Razvan Pascanu
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #177 from nouiz/important_fix
Important fix Everything looks good
上级
ba9c7fd4
5ca75d68
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
178 行增加
和
175 行删除
+178
-175
elemwise.py
theano/sandbox/cuda/elemwise.py
+7
-3
basic.py
theano/scalar/basic.py
+2
-2
scan_utils.py
theano/scan_module/scan_utils.py
+169
-170
没有找到文件。
theano/sandbox/cuda/elemwise.py
浏览文件 @
b3cf9269
...
@@ -8,7 +8,7 @@ The elemwise fct are also used with scalar operation! So it can happen that ndim
...
@@ -8,7 +8,7 @@ The elemwise fct are also used with scalar operation! So it can happen that ndim
import
StringIO
,
sys
import
StringIO
,
sys
import
numpy
import
numpy
from
theano
import
Op
,
Type
,
Apply
,
Variable
,
Constant
from
theano
import
Op
,
Type
,
Apply
,
Variable
,
Constant
from
theano
import
tensor
,
scalar
from
theano
import
tensor
,
scalar
,
gof
import
logging
,
copy
import
logging
,
copy
_logger_name
=
'theano.sandbox.cuda.elemwise'
_logger_name
=
'theano.sandbox.cuda.elemwise'
...
@@ -42,8 +42,12 @@ class NaiveAlgo(object):
...
@@ -42,8 +42,12 @@ class NaiveAlgo(object):
:param scalar_op: the scalar operation to execute on each element.
:param scalar_op: the scalar operation to execute on each element.
:param sync: if True, will wait after the kernel launch and check for error call.
:param sync: if True, will wait after the kernel launch and check for error call.
"""
"""
if
scalar_op
.
c_support_code_apply
(
node
=
None
,
nodename
=
"nodename"
):
try
:
raise
SupportCodeError
(
scalar_op
)
code
=
scalar_op
.
c_support_code_apply
(
node
=
None
,
name
=
"nodename"
)
if
code
:
raise
SupportCodeError
(
scalar_op
)
except
gof
.
utils
.
MethodNotDefined
:
pass
self
.
scalar_op
=
scalar_op
self
.
scalar_op
=
scalar_op
self
.
sync
=
sync
self
.
sync
=
sync
self
.
inplace_pattern
=
inplace_pattern
self
.
inplace_pattern
=
inplace_pattern
...
...
theano/scalar/basic.py
浏览文件 @
b3cf9269
...
@@ -2097,14 +2097,14 @@ class Composite(ScalarOp):
...
@@ -2097,14 +2097,14 @@ class Composite(ScalarOp):
return
()
return
()
return
tuple
(
rval
)
return
tuple
(
rval
)
def
c_support_code_apply
(
self
,
node
,
n
oden
ame
):
def
c_support_code_apply
(
self
,
node
,
name
):
rval
=
[]
rval
=
[]
for
subnode
,
subnodename
in
zip
(
self
.
env
.
toposort
(),
self
.
nodenames
):
for
subnode
,
subnodename
in
zip
(
self
.
env
.
toposort
(),
self
.
nodenames
):
try
:
try
:
rval
.
append
(
rval
.
append
(
subnode
.
op
.
c_support_code_apply
(
subnode
.
op
.
c_support_code_apply
(
subnode
,
subnode
,
subnodename
%
dict
(
nodename
=
n
oden
ame
)))
subnodename
%
dict
(
nodename
=
name
)))
except
gof
.
utils
.
MethodNotDefined
:
except
gof
.
utils
.
MethodNotDefined
:
pass
pass
return
"
\n
"
.
join
(
rval
)
return
"
\n
"
.
join
(
rval
)
...
...
theano/scan_module/scan_utils.py
浏览文件 @
b3cf9269
...
@@ -4,11 +4,11 @@ This module provides utility functions for the Scan Op
...
@@ -4,11 +4,11 @@ This module provides utility functions for the Scan Op
See scan.py for details on scan
See scan.py for details on scan
"""
"""
__docformat__
=
'restructedtext en'
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"Frederic Bastien "
"James Bergstra "
"James Bergstra "
"Pascal Lamblin "
"Pascal Lamblin "
"Arnaud Bergeron"
)
"Arnaud Bergeron"
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
...
@@ -16,7 +16,6 @@ import copy
...
@@ -16,7 +16,6 @@ import copy
import
logging
import
logging
import
numpy
import
numpy
from
theano
import
config
from
theano.compile.pfunc
import
rebuild_collect_shared
from
theano.compile.pfunc
import
rebuild_collect_shared
from
theano
import
gof
from
theano
import
gof
from
theano
import
tensor
,
scalar
from
theano
import
tensor
,
scalar
...
@@ -30,7 +29,8 @@ import theano
...
@@ -30,7 +29,8 @@ import theano
# Logging function for sending warning or info
# Logging function for sending warning or info
_logger
=
logging
.
getLogger
(
'theano.scan_utils'
)
_logger
=
logging
.
getLogger
(
'theano.scan_utils'
)
def
safe_new
(
x
,
tag
=
''
):
def
safe_new
(
x
,
tag
=
''
):
"""
"""
Internal function that constructs a new variable from x with the same
Internal function that constructs a new variable from x with the same
type, but with a different name ( old name + tag). This function is used
type, but with a different name ( old name + tag). This function is used
...
@@ -81,7 +81,7 @@ class until(object):
...
@@ -81,7 +81,7 @@ class until(object):
assert
self
.
condition
.
ndim
==
0
assert
self
.
condition
.
ndim
==
0
def
traverse
(
out
,
x
,
x_copy
,
d
):
def
traverse
(
out
,
x
,
x_copy
,
d
):
''' Function used by scan to parse the tree and figure out which nodes
''' Function used by scan to parse the tree and figure out which nodes
it needs to replace. There are two options :
it needs to replace. There are two options :
1) x and x_copy or on host, then you would replace x with x_copy
1) x and x_copy or on host, then you would replace x with x_copy
...
@@ -111,10 +111,10 @@ def traverse(out, x,x_copy, d):
...
@@ -111,10 +111,10 @@ def traverse(out, x,x_copy, d):
def
hash_listsDictsTuples
(
x
):
def
hash_listsDictsTuples
(
x
):
hash_value
=
0
hash_value
=
0
if
isinstance
(
x
,
dict
):
if
isinstance
(
x
,
dict
):
for
k
,
v
in
x
.
iteritems
():
for
k
,
v
in
x
.
iteritems
():
hash_value
^=
hash_listsDictsTuples
(
k
)
hash_value
^=
hash_listsDictsTuples
(
k
)
hash_value
^=
hash_listsDictsTuples
(
v
)
hash_value
^=
hash_listsDictsTuples
(
v
)
elif
isinstance
(
x
,
(
list
,
tuple
)):
elif
isinstance
(
x
,
(
list
,
tuple
)):
for
v
in
x
:
for
v
in
x
:
hash_value
^=
hash_listsDictsTuples
(
v
)
hash_value
^=
hash_listsDictsTuples
(
v
)
else
:
else
:
...
@@ -122,10 +122,10 @@ def hash_listsDictsTuples(x):
...
@@ -122,10 +122,10 @@ def hash_listsDictsTuples(x):
return
hash_value
return
hash_value
def
clone
(
output
def
clone
(
output
,
,
replace
=
None
replace
=
None
,
,
strict
=
True
strict
=
True
,
,
copy_inputs
=
True
):
copy_inputs
=
True
):
"""
"""
Function that allows replacing subgraphs of a computational
Function that allows replacing subgraphs of a computational
graph. It returns a copy of the initial subgraph with the corresponding
graph. It returns a copy of the initial subgraph with the corresponding
...
@@ -140,17 +140,16 @@ def clone( output
...
@@ -140,17 +140,16 @@ def clone( output
replaced by what
replaced by what
"""
"""
inps
,
outs
,
other_stuff
=
rebuild_collect_shared
(
output
inps
,
outs
,
other_stuff
=
rebuild_collect_shared
(
output
,
,
[]
[],
,
replace
replace
,
,
[]
[],
,
strict
strict
,
,
copy_inputs
copy_inputs
)
)
return
outs
return
outs
def
get_updates_and_outputs
(
ls
):
def
get_updates_and_outputs
(
ls
):
"""
"""
This function tries to recognize the updates dictionary, the
This function tries to recognize the updates dictionary, the
...
@@ -160,7 +159,7 @@ def get_updates_and_outputs(ls):
...
@@ -160,7 +159,7 @@ def get_updates_and_outputs(ls):
"""
"""
def
is_outputs
(
elem
):
def
is_outputs
(
elem
):
if
(
isinstance
(
elem
,
(
list
,
tuple
))
and
if
(
isinstance
(
elem
,
(
list
,
tuple
))
and
all
([
isinstance
(
x
,
theano
.
Variable
)
for
x
in
elem
])):
all
([
isinstance
(
x
,
theano
.
Variable
)
for
x
in
elem
])):
return
True
return
True
if
isinstance
(
elem
,
theano
.
Variable
):
if
isinstance
(
elem
,
theano
.
Variable
):
...
@@ -172,7 +171,7 @@ def get_updates_and_outputs(ls):
...
@@ -172,7 +171,7 @@ def get_updates_and_outputs(ls):
return
True
return
True
# Dictionaries can be given as lists of tuples
# Dictionaries can be given as lists of tuples
if
(
isinstance
(
elem
,
(
list
,
tuple
))
and
if
(
isinstance
(
elem
,
(
list
,
tuple
))
and
all
([
isinstance
(
x
,
(
list
,
tuple
))
and
len
(
x
)
==
2
all
([
isinstance
(
x
,
(
list
,
tuple
))
and
len
(
x
)
==
2
for
x
in
elem
])):
for
x
in
elem
])):
return
True
return
True
return
False
return
False
...
@@ -204,13 +203,13 @@ def get_updates_and_outputs(ls):
...
@@ -204,13 +203,13 @@ def get_updates_and_outputs(ls):
if
is_updates
(
ls
[
1
]):
if
is_updates
(
ls
[
1
]):
return
(
None
,
_list
(
ls
[
0
]),
dict
(
ls
[
1
]))
return
(
None
,
_list
(
ls
[
0
]),
dict
(
ls
[
1
]))
elif
is_condition
(
ls
[
1
]):
elif
is_condition
(
ls
[
1
]):
return
(
ls
[
1
]
.
condition
,
_list
(
ls
[
0
]),
{})
return
(
ls
[
1
]
.
condition
,
_list
(
ls
[
0
]),
{})
else
:
else
:
raise
ValueError
(
error_msg
)
raise
ValueError
(
error_msg
)
elif
is_updates
(
ls
[
0
]):
elif
is_updates
(
ls
[
0
]):
if
is_outputs
(
ls
[
1
]):
if
is_outputs
(
ls
[
1
]):
_logger
.
warning
(
deprication_msg
)
_logger
.
warning
(
deprication_msg
)
return
(
None
,
_list
(
ls
[
1
]),
dict
(
ls
[
0
])
)
return
(
None
,
_list
(
ls
[
1
]),
dict
(
ls
[
0
])
)
elif
is_condition
(
ls
[
1
]):
elif
is_condition
(
ls
[
1
]):
return
(
ls
[
1
]
.
condition
,
[],
dict
(
ls
[
0
]))
return
(
ls
[
1
]
.
condition
,
[],
dict
(
ls
[
0
]))
else
:
else
:
...
@@ -251,7 +250,7 @@ def isNaN_or_Inf_or_None(x):
...
@@ -251,7 +250,7 @@ def isNaN_or_Inf_or_None(x):
isStr
=
False
isStr
=
False
if
not
isNaN
and
not
isInf
:
if
not
isNaN
and
not
isInf
:
try
:
try
:
val
=
get_constant_value
(
x
)
val
=
get_constant_value
(
x
)
isInf
=
numpy
.
isinf
(
val
)
isInf
=
numpy
.
isinf
(
val
)
isNaN
=
numpy
.
isnan
(
val
)
isNaN
=
numpy
.
isnan
(
val
)
except
Exception
:
except
Exception
:
...
@@ -264,7 +263,7 @@ def isNaN_or_Inf_or_None(x):
...
@@ -264,7 +263,7 @@ def isNaN_or_Inf_or_None(x):
return
isNone
or
isNaN
or
isInf
or
isStr
return
isNone
or
isNaN
or
isInf
or
isStr
def
expand
(
tensor_var
,
size
):
def
expand
(
tensor_var
,
size
):
'''
'''
Transoforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..)
Transoforms the shape of a tensor from (d1, d2 ... ) to ( d1+size, d2, ..)
by adding 0s at the end of the tensor.
by adding 0s at the end of the tensor.
...
@@ -272,13 +271,14 @@ def expand( tensor_var, size):
...
@@ -272,13 +271,14 @@ def expand( tensor_var, size):
# Corner case that I might use in an optimization
# Corner case that I might use in an optimization
if
size
==
0
:
if
size
==
0
:
return
tensor_var
return
tensor_var
shapes
=
[
tensor_var
.
shape
[
x
]
for
x
in
xrange
(
tensor_var
.
ndim
)
]
shapes
=
[
tensor_var
.
shape
[
x
]
for
x
in
xrange
(
tensor_var
.
ndim
)
]
zeros_shape
=
[
size
+
shapes
[
0
]]
+
shapes
[
1
:]
zeros_shape
=
[
size
+
shapes
[
0
]]
+
shapes
[
1
:]
empty
=
tensor
.
zeros
(
zeros_shape
empty
=
tensor
.
zeros
(
zeros_shape
,
,
dtype
=
tensor_var
.
dtype
)
dtype
=
tensor_var
.
dtype
)
return
tensor
.
set_subtensor
(
empty
[:
shapes
[
0
]],
tensor_var
)
return
tensor
.
set_subtensor
(
empty
[:
shapes
[
0
]],
tensor_var
)
def
equal_computations
(
xs
,
ys
,
in_xs
=
None
,
in_ys
=
None
,
strict
=
True
):
def
equal_computations
(
xs
,
ys
,
in_xs
=
None
,
in_ys
=
None
,
strict
=
True
):
'''
'''
Checks if to theano graphs represent the same computations (with
Checks if to theano graphs represent the same computations (with
equivalence of inputs defined by map). Inputs are always assumed
equivalence of inputs defined by map). Inputs are always assumed
...
@@ -289,8 +289,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
...
@@ -289,8 +289,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if
in_ys
is
None
:
if
in_ys
is
None
:
in_ys
=
[]
in_ys
=
[]
for
x
,
y
in
zip
(
xs
,
ys
):
for
x
,
y
in
zip
(
xs
,
ys
):
if
x
.
owner
and
not
y
.
owner
:
if
x
.
owner
and
not
y
.
owner
:
return
False
return
False
if
y
.
owner
and
not
x
.
owner
:
if
y
.
owner
and
not
x
.
owner
:
...
@@ -300,7 +299,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
...
@@ -300,7 +299,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
return
False
return
False
if
len
(
in_xs
)
!=
len
(
in_ys
):
if
len
(
in_xs
)
!=
len
(
in_ys
):
return
False
return
False
for
_x
,
_y
in
zip
(
in_xs
,
in_ys
):
for
_x
,
_y
in
zip
(
in_xs
,
in_ys
):
if
_x
.
type
!=
_y
.
type
:
if
_x
.
type
!=
_y
.
type
:
return
False
return
False
...
@@ -308,17 +307,17 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
...
@@ -308,17 +307,17 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
nds_y
=
gof
.
graph
.
io_toposort
(
in_ys
,
ys
)
nds_y
=
gof
.
graph
.
io_toposort
(
in_ys
,
ys
)
if
len
(
nds_x
)
!=
len
(
nds_y
):
if
len
(
nds_x
)
!=
len
(
nds_y
):
return
False
return
False
common
=
set
(
zip
(
in_xs
,
in_ys
))
common
=
set
(
zip
(
in_xs
,
in_ys
))
n_nodes
=
len
(
nds_x
)
n_nodes
=
len
(
nds_x
)
cont
=
True
cont
=
True
idx
=
0
idx
=
0
for
dx
,
dy
in
zip
(
xs
,
ys
):
for
dx
,
dy
in
zip
(
xs
,
ys
):
if
not
dx
.
owner
or
not
dy
.
owner
:
if
not
dx
.
owner
or
not
dy
.
owner
:
if
dy
.
owner
or
dx
.
owner
:
if
dy
.
owner
or
dx
.
owner
:
return
False
return
False
elif
(
isinstance
(
dx
,
tensor
.
Constant
)
and
elif
(
isinstance
(
dx
,
tensor
.
Constant
)
and
isinstance
(
dy
,
tensor
.
Constant
)):
isinstance
(
dy
,
tensor
.
Constant
)):
if
not
(
numpy
.
all
(
dx
.
data
==
dy
.
data
)
and
if
not
(
numpy
.
all
(
dx
.
data
==
dy
.
data
)
and
dx
.
dtype
==
dy
.
dtype
and
dx
.
dtype
==
dy
.
dtype
and
dx
.
data
.
shape
==
dy
.
data
.
shape
):
dx
.
data
.
shape
==
dy
.
data
.
shape
):
return
False
return
False
...
@@ -329,7 +328,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
...
@@ -329,7 +328,7 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
if
dx
.
type
!=
dy
.
type
:
if
dx
.
type
!=
dy
.
type
:
return
False
return
False
else
:
else
:
if
(
dx
,
dy
)
not
in
common
:
if
(
dx
,
dy
)
not
in
common
:
return
False
return
False
while
cont
and
idx
<
n_nodes
:
while
cont
and
idx
<
n_nodes
:
...
@@ -342,9 +341,9 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
...
@@ -342,9 +341,9 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
elif
len
(
nd_x
.
outputs
)
!=
len
(
nd_y
.
outputs
):
elif
len
(
nd_x
.
outputs
)
!=
len
(
nd_y
.
outputs
):
cont
=
False
cont
=
False
else
:
else
:
for
dx
,
dy
in
zip
(
nd_x
.
inputs
,
nd_y
.
inputs
):
for
dx
,
dy
in
zip
(
nd_x
.
inputs
,
nd_y
.
inputs
):
if
(
dx
,
dy
)
not
in
common
:
if
(
dx
,
dy
)
not
in
common
:
if
strict
and
dx
!=
dy
:
if
strict
and
dx
!=
dy
:
if
(
isinstance
(
dx
,
tensor
.
Constant
)
and
if
(
isinstance
(
dx
,
tensor
.
Constant
)
and
isinstance
(
dy
,
tensor
.
Constant
)):
isinstance
(
dy
,
tensor
.
Constant
)):
if
not
(
numpy
.
all
(
dx
.
data
==
dy
.
data
)
and
if
not
(
numpy
.
all
(
dx
.
data
==
dy
.
data
)
and
...
@@ -359,32 +358,27 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
...
@@ -359,32 +358,27 @@ def equal_computations(xs,ys, in_xs = None, in_ys = None, strict=True):
cont
=
cont
and
(
dx
.
type
==
dy
.
type
)
cont
=
cont
and
(
dx
.
type
==
dy
.
type
)
if
cont
:
if
cont
:
for
dx
,
dy
in
zip
(
nd_x
.
outputs
,
nd_y
.
outputs
):
for
dx
,
dy
in
zip
(
nd_x
.
outputs
,
nd_y
.
outputs
):
common
.
add
((
dx
,
dy
))
common
.
add
((
dx
,
dy
))
idx
+=
1
idx
+=
1
return
cont
return
cont
def
infer_shape
(
outs
,
inputs
,
input_shapes
):
def
infer_shape
(
outs
,
inputs
,
input_shapes
):
'''
'''
Compute the shape of the outputs given the shape of the inputs
Compute the shape of the outputs given the shape of the inputs
of a theano graph.
of a theano graph.
We do it this way to don't compile the inner function just to get
the shape. Change to ShapeFeature could request change in this function.
'''
'''
# We use a ShapeFeature because it has all the necessary logic inside.
# We use a ShapeFeature because it has all the necessary logic
# We don't use the Feature interface, so we need to initialize some
# inside. We don't use the full ShapeFeature interface, but we
# things by hand.
# let it initialize itself with an empty env, otherwise we will
# need to do it manually
shape_feature
=
tensor
.
opt
.
ShapeFeature
()
shape_feature
=
tensor
.
opt
.
ShapeFeature
()
shape_feature
.
on_attach
(
theano
.
gof
.
Env
([],
[]))
# Variable -> tuple(scalars) or None (All tensor vars map to tuple)
# All keys of shape_of should be either in valid or in invalid
shape_feature
.
shape_of
=
{}
# To avoid merging lots of ones together.
shape_feature
.
lscalar_one
=
tensor
.
constant
(
1
,
dtype
=
'int64'
)
# Initialize shape_of with the input shapes
# Initialize shape_of with the input shapes
for
inp
,
inp_shp
in
zip
(
inputs
,
input_shapes
):
for
inp
,
inp_shp
in
zip
(
inputs
,
input_shapes
):
...
@@ -418,6 +412,7 @@ def infer_shape(outs, inputs, input_shapes):
...
@@ -418,6 +412,7 @@ def infer_shape(outs, inputs, input_shapes):
ret
.
append
(
shape_feature
.
shape_of
[
o
])
ret
.
append
(
shape_feature
.
shape_of
[
o
])
return
ret
return
ret
class
Validator
(
object
):
class
Validator
(
object
):
def
__init__
(
self
,
valid
=
[],
invalid
=
[],
valid_equivalent
=
{}):
def
__init__
(
self
,
valid
=
[],
invalid
=
[],
valid_equivalent
=
{}):
'''
'''
...
@@ -496,35 +491,35 @@ def scan_can_remove_outs(op, out_idxs):
...
@@ -496,35 +491,35 @@ def scan_can_remove_outs(op, out_idxs):
the first one with the indices of outs that can be removed, the second
the first one with the indices of outs that can be removed, the second
with the outputs that can not be removed.
with the outputs that can not be removed.
'''
'''
non_removable
=
[
o
for
i
,
o
in
enumerate
(
op
.
outputs
)
if
i
not
in
non_removable
=
[
o
for
i
,
o
in
enumerate
(
op
.
outputs
)
if
i
not
in
out_idxs
]
out_idxs
]
required_inputs
=
gof
.
graph
.
inputs
(
non_removable
)
required_inputs
=
gof
.
graph
.
inputs
(
non_removable
)
out_ins
=
[]
out_ins
=
[]
offset
=
op
.
n_seqs
offset
=
op
.
n_seqs
lim
=
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
lim
=
op
.
n_mit_mot
+
op
.
n_mit_sot
+
op
.
n_sit_sot
for
idx
in
range
(
lim
):
for
idx
in
range
(
lim
):
n_ins
=
len
(
op
.
info
[
'tap_array'
][
idx
])
n_ins
=
len
(
op
.
info
[
'tap_array'
][
idx
])
out_ins
+=
[
op
.
inputs
[
offset
:
offset
+
n_ins
]]
out_ins
+=
[
op
.
inputs
[
offset
:
offset
+
n_ins
]]
offset
+=
n_ins
offset
+=
n_ins
out_ins
+=
[
[]
for
k
in
xrange
(
op
.
n_nit_sot
)
]
out_ins
+=
[
[]
for
k
in
xrange
(
op
.
n_nit_sot
)
]
out_ins
+=
[
[
op
.
inputs
[
offset
+
k
]]
for
k
in
xrange
(
op
.
n_shared_outs
)]
out_ins
+=
[
[
op
.
inputs
[
offset
+
k
]]
for
k
in
xrange
(
op
.
n_shared_outs
)]
added
=
True
added
=
True
out_idxs_mask
=
[
1
for
idx
in
out_idxs
]
out_idxs_mask
=
[
1
for
idx
in
out_idxs
]
while
added
:
while
added
:
added
=
False
added
=
False
for
pos
,
idx
in
enumerate
(
out_idxs
):
for
pos
,
idx
in
enumerate
(
out_idxs
):
if
(
out_idxs_mask
[
pos
]
and
if
(
out_idxs_mask
[
pos
]
and
numpy
.
any
([
x
in
required_inputs
for
x
in
out_ins
[
idx
]])
):
numpy
.
any
([
x
in
required_inputs
for
x
in
out_ins
[
idx
]])):
# This output is required ..
# This output is required ..
out_idxs_mask
[
pos
]
=
0
out_idxs_mask
[
pos
]
=
0
required_inputs
+=
gof
.
graph
.
inputs
([
op
.
outputs
[
idx
]])
required_inputs
+=
gof
.
graph
.
inputs
([
op
.
outputs
[
idx
]])
added
=
True
added
=
True
required_outs
=
[
x
for
i
,
x
in
enumerate
(
out_idxs
)
required_outs
=
[
x
for
i
,
x
in
enumerate
(
out_idxs
)
if
out_idxs_mask
[
i
]
==
0
]
if
out_idxs_mask
[
i
]
==
0
]
not_required
=
[
x
for
i
,
x
in
enumerate
(
out_idxs
)
if
out_idxs_mask
[
i
]
==
1
]
not_required
=
[
x
for
i
,
x
in
enumerate
(
out_idxs
)
if
out_idxs_mask
[
i
]
==
1
]
return
(
required_outs
,
not_required
)
return
(
required_outs
,
not_required
)
...
@@ -539,107 +534,107 @@ def compress_outs(op, not_required, inputs):
...
@@ -539,107 +534,107 @@ def compress_outs(op, not_required, inputs):
node inputs, and changing the dictionary.
node inputs, and changing the dictionary.
'''
'''
info
=
{}
info
=
{}
info
[
'tap_array'
]
=
[]
info
[
'tap_array'
]
=
[]
info
[
'n_seqs'
]
=
op
.
info
[
'n_seqs'
]
info
[
'n_seqs'
]
=
op
.
info
[
'n_seqs'
]
info
[
'n_mit_mot'
]
=
0
info
[
'n_mit_mot'
]
=
0
info
[
'n_mit_mot_outs'
]
=
0
info
[
'n_mit_mot_outs'
]
=
0
info
[
'mit_mot_out_slices'
]
=
[]
info
[
'mit_mot_out_slices'
]
=
[]
info
[
'n_mit_sot'
]
=
0
info
[
'n_mit_sot'
]
=
0
info
[
'n_sit_sot'
]
=
0
info
[
'n_sit_sot'
]
=
0
info
[
'n_shared_outs'
]
=
0
info
[
'n_shared_outs'
]
=
0
info
[
'n_nit_sot'
]
=
0
info
[
'n_nit_sot'
]
=
0
info
[
'truncate_gradient'
]
=
op
.
info
[
'truncate_gradient'
]
info
[
'truncate_gradient'
]
=
op
.
info
[
'truncate_gradient'
]
info
[
'name'
]
=
op
.
info
[
'name'
]
info
[
'name'
]
=
op
.
info
[
'name'
]
info
[
'inplace'
]
=
op
.
info
[
'inplace'
]
info
[
'inplace'
]
=
op
.
info
[
'inplace'
]
info
[
'gpu'
]
=
op
.
info
[
'gpu'
]
info
[
'gpu'
]
=
op
.
info
[
'gpu'
]
info
[
'mode'
]
=
op
.
info
[
'mode'
]
info
[
'mode'
]
=
op
.
info
[
'mode'
]
info
[
'as_while'
]
=
op
.
info
[
'as_while'
]
info
[
'as_while'
]
=
op
.
info
[
'as_while'
]
info
[
'profile'
]
=
op
.
info
[
'profile'
]
info
[
'profile'
]
=
op
.
info
[
'profile'
]
op_inputs
=
op
.
inputs
[:
op
.
n_seqs
]
op_inputs
=
op
.
inputs
[:
op
.
n_seqs
]
op_outputs
=
[]
op_outputs
=
[]
node_inputs
=
inputs
[:
op
.
n_seqs
+
1
]
node_inputs
=
inputs
[:
op
.
n_seqs
+
1
]
map_old_new
=
{}
map_old_new
=
{}
offset
=
0
offset
=
0
ni_offset
=
op
.
n_seqs
+
1
ni_offset
=
op
.
n_seqs
+
1
i_offset
=
op
.
n_seqs
i_offset
=
op
.
n_seqs
o_offset
=
0
o_offset
=
0
curr_pos
=
0
curr_pos
=
0
for
idx
in
xrange
(
op
.
info
[
'n_mit_mot'
]):
for
idx
in
xrange
(
op
.
info
[
'n_mit_mot'
]):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
'n_mit_mot'
]
+=
1
info
[
'n_mit_mot'
]
+=
1
info
[
'tap_array'
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
[
'tap_array'
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
[
'mit_mot_out_slices'
]
+=
[
op
.
mit_mot_out_slices
[
offset
+
idx
]]
info
[
'mit_mot_out_slices'
]
+=
[
op
.
mit_mot_out_slices
[
offset
+
idx
]]
# input taps
# input taps
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
i_offset
+=
1
i_offset
+=
1
# output taps
# output taps
for
jdx
in
op
.
mit_mot_out_slices
[
offset
+
idx
]:
for
jdx
in
op
.
mit_mot_out_slices
[
offset
+
idx
]:
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
o_offset
+=
1
# node inputs
# node inputs
node_inputs
+=
[
inputs
[
ni_offset
+
idx
]]
node_inputs
+=
[
inputs
[
ni_offset
+
idx
]]
else
:
else
:
o_offset
+=
len
(
op
.
mit_mot_out_slices
[
offset
+
idx
])
o_offset
+=
len
(
op
.
mit_mot_out_slices
[
offset
+
idx
])
i_offset
+=
len
(
op
.
tap_array
[
offset
+
idx
])
i_offset
+=
len
(
op
.
tap_array
[
offset
+
idx
])
info
[
'n_mit_mot_outs'
]
=
len
(
op_outputs
)
info
[
'n_mit_mot_outs'
]
=
len
(
op_outputs
)
offset
+=
op
.
n_mit_mot
offset
+=
op
.
n_mit_mot
ni_offset
+=
op
.
n_mit_mot
ni_offset
+=
op
.
n_mit_mot
for
idx
in
xrange
(
op
.
info
[
'n_mit_sot'
]):
for
idx
in
xrange
(
op
.
info
[
'n_mit_sot'
]):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
'n_mit_sot'
]
+=
1
info
[
'n_mit_sot'
]
+=
1
info
[
'tap_array'
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
[
'tap_array'
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
#input taps
#input taps
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
for
jdx
in
op
.
tap_array
[
offset
+
idx
]:
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
i_offset
+=
1
i_offset
+=
1
#output taps
#output taps
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
o_offset
+=
1
#node inputs
#node inputs
node_inputs
+=
[
inputs
[
ni_offset
+
idx
]]
node_inputs
+=
[
inputs
[
ni_offset
+
idx
]]
else
:
else
:
o_offset
+=
1
o_offset
+=
1
i_offset
+=
len
(
op
.
tap_array
[
offset
+
idx
])
i_offset
+=
len
(
op
.
tap_array
[
offset
+
idx
])
offset
+=
op
.
n_mit_sot
offset
+=
op
.
n_mit_sot
ni_offset
+=
op
.
n_mit_sot
ni_offset
+=
op
.
n_mit_sot
for
idx
in
xrange
(
op
.
info
[
'n_sit_sot'
]):
for
idx
in
xrange
(
op
.
info
[
'n_sit_sot'
]):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
'n_sit_sot'
]
+=
1
info
[
'n_sit_sot'
]
+=
1
info
[
'tap_array'
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
info
[
'tap_array'
]
+=
[
op
.
tap_array
[
offset
+
idx
]]
#input taps
#input taps
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
i_offset
+=
1
i_offset
+=
1
#output taps
#output taps
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
o_offset
+=
1
#node inputs
#node inputs
node_inputs
+=
[
inputs
[
ni_offset
+
idx
]]
node_inputs
+=
[
inputs
[
ni_offset
+
idx
]]
else
:
else
:
o_offset
+=
1
o_offset
+=
1
i_offset
+=
1
i_offset
+=
1
offset
+=
op
.
n_sit_sot
offset
+=
op
.
n_sit_sot
ni_offset
+=
op
.
n_sit_sot
ni_offset
+=
op
.
n_sit_sot
nit_sot_ins
=
[]
nit_sot_ins
=
[]
for
idx
in
xrange
(
op
.
info
[
'n_nit_sot'
]):
for
idx
in
xrange
(
op
.
info
[
'n_nit_sot'
]):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
'n_nit_sot'
]
+=
1
info
[
'n_nit_sot'
]
+=
1
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
o_offset
+=
1
nit_sot_ins
+=
[
inputs
[
ni_offset
+
idx
+
op
.
n_shared_outs
]]
nit_sot_ins
+=
[
inputs
[
ni_offset
+
idx
+
op
.
n_shared_outs
]]
else
:
else
:
o_offset
+=
1
o_offset
+=
1
...
@@ -647,14 +642,14 @@ def compress_outs(op, not_required, inputs):
...
@@ -647,14 +642,14 @@ def compress_outs(op, not_required, inputs):
shared_ins
=
[]
shared_ins
=
[]
for
idx
in
xrange
(
op
.
info
[
'n_shared_outs'
]):
for
idx
in
xrange
(
op
.
info
[
'n_shared_outs'
]):
if
offset
+
idx
not
in
not_required
:
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
curr_pos
+=
1
info
[
'n_shared_outs'
]
+=
1
info
[
'n_shared_outs'
]
+=
1
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
o_offset
+=
1
o_offset
+=
1
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
op_inputs
+=
[
op
.
inputs
[
i_offset
]]
i_offset
+=
1
i_offset
+=
1
shared_ins
+=
[
inputs
[
ni_offset
+
idx
]]
shared_ins
+=
[
inputs
[
ni_offset
+
idx
]]
else
:
else
:
o_offset
+=
1
o_offset
+=
1
i_offset
+=
1
i_offset
+=
1
...
@@ -662,14 +657,15 @@ def compress_outs(op, not_required, inputs):
...
@@ -662,14 +657,15 @@ def compress_outs(op, not_required, inputs):
node_inputs
+=
nit_sot_ins
node_inputs
+=
nit_sot_ins
# other stuff
# other stuff
op_inputs
+=
op
.
inputs
[
i_offset
:]
op_inputs
+=
op
.
inputs
[
i_offset
:]
node_inputs
+=
inputs
[
ni_offset
+
op
.
n_shared_outs
+
op
.
n_nit_sot
:]
node_inputs
+=
inputs
[
ni_offset
+
op
.
n_shared_outs
+
op
.
n_nit_sot
:]
if
op
.
as_while
:
if
op
.
as_while
:
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
op_outputs
+=
[
op
.
outputs
[
o_offset
]]
map_old_new
[
o_offset
]
=
len
(
op_outputs
)
-
1
map_old_new
[
o_offset
]
=
len
(
op_outputs
)
-
1
#map_old_new[len(op_outputs)-1] = o_offset
#map_old_new[len(op_outputs)-1] = o_offset
return
(
op_inputs
,
op_outputs
,
info
,
node_inputs
,
map_old_new
)
return
(
op_inputs
,
op_outputs
,
info
,
node_inputs
,
map_old_new
)
def
find_up
(
l_node
,
f_node
):
def
find_up
(
l_node
,
f_node
):
r"""
r"""
Goes up in the graph and returns True if a node in nodes is found.
Goes up in the graph and returns True if a node in nodes is found.
...
@@ -678,11 +674,12 @@ def find_up(l_node, f_node):
...
@@ -678,11 +674,12 @@ def find_up(l_node, f_node):
l_outs
=
l_node
.
outputs
l_outs
=
l_node
.
outputs
else
:
else
:
l_outs
=
l_node
l_outs
=
l_node
l_ins
=
gof
.
graph
.
inputs
(
l_outs
)
l_ins
=
gof
.
graph
.
inputs
(
l_outs
)
nodes
=
gof
.
graph
.
io_toposort
(
l_ins
,
l_outs
)
nodes
=
gof
.
graph
.
io_toposort
(
l_ins
,
l_outs
)
return
f_node
in
nodes
return
f_node
in
nodes
def
reconstruct_graph
(
inputs
,
outputs
,
tag
=
None
):
def
reconstruct_graph
(
inputs
,
outputs
,
tag
=
None
):
"""
"""
Different interface to clone, that allows you to pass inputs.
Different interface to clone, that allows you to pass inputs.
Compared to clone, this method always replaces the inputs with
Compared to clone, this method always replaces the inputs with
...
@@ -691,7 +688,7 @@ def reconstruct_graph(inputs, outputs, tag = None):
...
@@ -691,7 +688,7 @@ def reconstruct_graph(inputs, outputs, tag = None):
"""
"""
if
tag
is
None
:
if
tag
is
None
:
tag
=
''
tag
=
''
nw_inputs
=
[
safe_new
(
x
,
tag
)
for
x
in
inputs
]
nw_inputs
=
[
safe_new
(
x
,
tag
)
for
x
in
inputs
]
givens
=
{}
givens
=
{}
for
nw_x
,
x
in
zip
(
nw_inputs
,
inputs
):
for
nw_x
,
x
in
zip
(
nw_inputs
,
inputs
):
givens
[
x
]
=
nw_x
givens
[
x
]
=
nw_x
...
@@ -700,9 +697,10 @@ def reconstruct_graph(inputs, outputs, tag = None):
...
@@ -700,9 +697,10 @@ def reconstruct_graph(inputs, outputs, tag = None):
if
isinstance
(
inp
,
theano
.
Constant
):
if
isinstance
(
inp
,
theano
.
Constant
):
givens
[
inp
]
=
inp
.
clone
()
givens
[
inp
]
=
inp
.
clone
()
nw_outputs
=
clone
(
outputs
,
replace
=
givens
)
nw_outputs
=
clone
(
outputs
,
replace
=
givens
)
return
(
nw_inputs
,
nw_outputs
)
return
(
nw_inputs
,
nw_outputs
)
class
scan_args
(
object
):
class
scan_args
(
object
):
"""Parses the inputs and outputs of scan in an easy to manipulate format"""
"""Parses the inputs and outputs of scan in an easy to manipulate format"""
def
__init__
(
self
,
outer_inputs
,
outer_outputs
,
def
__init__
(
self
,
outer_inputs
,
outer_outputs
,
...
@@ -714,14 +712,14 @@ class scan_args(object):
...
@@ -714,14 +712,14 @@ class scan_args(object):
inner_outputs
=
rval
[
1
][:
-
1
]
inner_outputs
=
rval
[
1
][:
-
1
]
else
:
else
:
inner_outputs
=
rval
[
1
]
inner_outputs
=
rval
[
1
]
inner_inputs
=
rval
[
0
]
inner_inputs
=
rval
[
0
]
p
=
1
p
=
1
q
=
0
q
=
0
n_seqs
=
info
[
'n_seqs'
]
n_seqs
=
info
[
'n_seqs'
]
self
.
outer_in_seqs
=
outer_inputs
[
p
:
p
+
n_seqs
]
self
.
outer_in_seqs
=
outer_inputs
[
p
:
p
+
n_seqs
]
self
.
inner_in_seqs
=
inner_inputs
[
q
:
q
+
n_seqs
]
self
.
inner_in_seqs
=
inner_inputs
[
q
:
q
+
n_seqs
]
p
+=
n_seqs
p
+=
n_seqs
q
+=
n_seqs
q
+=
n_seqs
...
@@ -729,46 +727,47 @@ class scan_args(object):
...
@@ -729,46 +727,47 @@ class scan_args(object):
n_mit_sot
=
info
[
'n_mit_sot'
]
n_mit_sot
=
info
[
'n_mit_sot'
]
self
.
mit_mot_in_slices
=
info
[
'tap_array'
][:
n_mit_mot
]
self
.
mit_mot_in_slices
=
info
[
'tap_array'
][:
n_mit_mot
]
self
.
mit_sot_in_slices
=
info
[
'tap_array'
][
n_mit_mot
:
n_mit_mot
+
n_mit_sot
]
self
.
mit_sot_in_slices
=
info
[
'tap_array'
][
n_mit_mot
:
n_mit_mot
+
n_mit_sot
]
n_mit_mot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_in_slices
)
n_mit_mot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_in_slices
)
n_mit_sot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_sot_in_slices
)
n_mit_sot_ins
=
sum
(
len
(
s
)
for
s
in
self
.
mit_sot_in_slices
)
iimm
=
inner_inputs
[
q
:
q
+
n_mit_mot_ins
]
iimm
=
inner_inputs
[
q
:
q
+
n_mit_mot_ins
]
self
.
inner_in_mit_mot
=
[]
self
.
inner_in_mit_mot
=
[]
qq
=
0
qq
=
0
for
sl
in
self
.
mit_mot_in_slices
:
for
sl
in
self
.
mit_mot_in_slices
:
self
.
inner_in_mit_mot
.
append
(
iimm
[
qq
:
qq
+
len
(
sl
)])
self
.
inner_in_mit_mot
.
append
(
iimm
[
qq
:
qq
+
len
(
sl
)])
qq
+=
len
(
sl
)
qq
+=
len
(
sl
)
q
+=
n_mit_mot_ins
q
+=
n_mit_mot_ins
iims
=
inner_inputs
[
q
:
q
+
n_mit_sot_ins
]
iims
=
inner_inputs
[
q
:
q
+
n_mit_sot_ins
]
self
.
inner_in_mit_sot
=
[]
self
.
inner_in_mit_sot
=
[]
qq
=
0
qq
=
0
for
sl
in
self
.
mit_sot_in_slices
:
for
sl
in
self
.
mit_sot_in_slices
:
self
.
inner_in_mit_sot
.
append
(
iims
[
qq
:
qq
+
len
(
sl
)])
self
.
inner_in_mit_sot
.
append
(
iims
[
qq
:
qq
+
len
(
sl
)])
qq
+=
len
(
sl
)
qq
+=
len
(
sl
)
q
+=
n_mit_sot_ins
q
+=
n_mit_sot_ins
self
.
outer_in_mit_mot
=
outer_inputs
[
p
:
p
+
n_mit_mot
]
self
.
outer_in_mit_mot
=
outer_inputs
[
p
:
p
+
n_mit_mot
]
p
+=
n_mit_mot
p
+=
n_mit_mot
self
.
outer_in_mit_sot
=
outer_inputs
[
p
:
p
+
n_mit_sot
]
self
.
outer_in_mit_sot
=
outer_inputs
[
p
:
p
+
n_mit_sot
]
p
+=
n_mit_sot
p
+=
n_mit_sot
n_sit_sot
=
info
[
'n_sit_sot'
]
n_sit_sot
=
info
[
'n_sit_sot'
]
self
.
outer_in_sit_sot
=
outer_inputs
[
p
:
p
+
n_sit_sot
]
self
.
outer_in_sit_sot
=
outer_inputs
[
p
:
p
+
n_sit_sot
]
self
.
inner_in_sit_sot
=
inner_inputs
[
q
:
q
+
n_sit_sot
]
self
.
inner_in_sit_sot
=
inner_inputs
[
q
:
q
+
n_sit_sot
]
p
+=
n_sit_sot
p
+=
n_sit_sot
q
+=
n_sit_sot
q
+=
n_sit_sot
n_shared_outs
=
info
[
'n_shared_outs'
]
n_shared_outs
=
info
[
'n_shared_outs'
]
self
.
outer_in_shared
=
outer_inputs
[
p
:
p
+
n_shared_outs
]
self
.
outer_in_shared
=
outer_inputs
[
p
:
p
+
n_shared_outs
]
self
.
inner_in_shared
=
inner_inputs
[
q
:
q
+
n_shared_outs
]
self
.
inner_in_shared
=
inner_inputs
[
q
:
q
+
n_shared_outs
]
p
+=
n_shared_outs
p
+=
n_shared_outs
q
+=
n_shared_outs
q
+=
n_shared_outs
n_nit_sot
=
info
[
'n_nit_sot'
]
n_nit_sot
=
info
[
'n_nit_sot'
]
self
.
outer_in_nit_sot
=
outer_inputs
[
p
:
p
+
n_nit_sot
]
self
.
outer_in_nit_sot
=
outer_inputs
[
p
:
p
+
n_nit_sot
]
p
+=
n_nit_sot
p
+=
n_nit_sot
self
.
outer_in_non_seqs
=
outer_inputs
[
p
:]
self
.
outer_in_non_seqs
=
outer_inputs
[
p
:]
...
@@ -780,40 +779,39 @@ class scan_args(object):
...
@@ -780,40 +779,39 @@ class scan_args(object):
self
.
mit_mot_out_slices
=
info
[
'mit_mot_out_slices'
]
self
.
mit_mot_out_slices
=
info
[
'mit_mot_out_slices'
]
n_mit_mot_outs
=
info
[
'n_mit_mot_outs'
]
n_mit_mot_outs
=
info
[
'n_mit_mot_outs'
]
self
.
outer_out_mit_mot
=
outer_outputs
[
p
:
p
+
n_mit_mot
]
self
.
outer_out_mit_mot
=
outer_outputs
[
p
:
p
+
n_mit_mot
]
iomm
=
inner_outputs
[
q
:
q
+
n_mit_mot_outs
]
iomm
=
inner_outputs
[
q
:
q
+
n_mit_mot_outs
]
self
.
inner_out_mit_mot
=
[]
self
.
inner_out_mit_mot
=
[]
qq
=
0
qq
=
0
for
sl
in
self
.
mit_mot_out_slices
:
for
sl
in
self
.
mit_mot_out_slices
:
self
.
inner_out_mit_mot
.
append
(
iomm
[
qq
:
qq
+
len
(
sl
)])
self
.
inner_out_mit_mot
.
append
(
iomm
[
qq
:
qq
+
len
(
sl
)])
qq
+=
len
(
sl
)
qq
+=
len
(
sl
)
p
+=
n_mit_mot
p
+=
n_mit_mot
q
+=
n_mit_mot_outs
q
+=
n_mit_mot_outs
self
.
outer_out_mit_sot
=
outer_outputs
[
p
:
p
+
n_mit_sot
]
self
.
outer_out_mit_sot
=
outer_outputs
[
p
:
p
+
n_mit_sot
]
self
.
inner_out_mit_sot
=
inner_outputs
[
q
:
q
+
n_mit_sot
]
self
.
inner_out_mit_sot
=
inner_outputs
[
q
:
q
+
n_mit_sot
]
p
+=
n_mit_sot
p
+=
n_mit_sot
q
+=
n_mit_sot
q
+=
n_mit_sot
self
.
outer_out_sit_sot
=
outer_outputs
[
p
:
p
+
n_sit_sot
]
self
.
outer_out_sit_sot
=
outer_outputs
[
p
:
p
+
n_sit_sot
]
self
.
inner_out_sit_sot
=
inner_outputs
[
q
:
q
+
n_sit_sot
]
self
.
inner_out_sit_sot
=
inner_outputs
[
q
:
q
+
n_sit_sot
]
p
+=
n_sit_sot
p
+=
n_sit_sot
q
+=
n_sit_sot
q
+=
n_sit_sot
self
.
outer_out_nit_sot
=
outer_outputs
[
p
:
p
+
n_nit_sot
]
self
.
outer_out_nit_sot
=
outer_outputs
[
p
:
p
+
n_nit_sot
]
self
.
inner_out_nit_sot
=
inner_outputs
[
q
:
q
+
n_nit_sot
]
self
.
inner_out_nit_sot
=
inner_outputs
[
q
:
q
+
n_nit_sot
]
p
+=
n_nit_sot
p
+=
n_nit_sot
q
+=
n_nit_sot
q
+=
n_nit_sot
self
.
outer_out_shared
=
outer_outputs
[
p
:
p
+
n_shared_outs
]
self
.
outer_out_shared
=
outer_outputs
[
p
:
p
+
n_shared_outs
]
self
.
inner_out_shared
=
inner_outputs
[
q
:
q
+
n_shared_outs
]
self
.
inner_out_shared
=
inner_outputs
[
q
:
q
+
n_shared_outs
]
p
+=
n_shared_outs
p
+=
n_shared_outs
q
+=
n_shared_outs
q
+=
n_shared_outs
self
.
other_info
=
dict
()
self
.
other_info
=
dict
()
for
k
in
(
'truncate_gradient'
,
'name'
,
'mode'
,
'inplace'
,
for
k
in
(
'truncate_gradient'
,
'name'
,
'mode'
,
'inplace'
,
'gpu'
,
'as_while'
,
'profile'
):
'gpu'
,
'as_while'
,
'profile'
):
self
.
other_info
[
k
]
=
info
[
k
]
self
.
other_info
[
k
]
=
info
[
k
]
inner_inputs
=
property
(
lambda
self
:
(
self
.
inner_in_seqs
+
inner_inputs
=
property
(
lambda
self
:
(
self
.
inner_in_seqs
+
...
@@ -844,18 +842,19 @@ class scan_args(object):
...
@@ -844,18 +842,19 @@ class scan_args(object):
self
.
outer_out_nit_sot
+
self
.
outer_out_nit_sot
+
self
.
outer_out_shared
))
self
.
outer_out_shared
))
info
=
property
(
lambda
self
:
dict
(
n_seqs
=
len
(
self
.
outer_in_seqs
),
info
=
property
(
lambda
self
:
dict
(
n_mit_mot
=
len
(
self
.
outer_in_mit_mot
),
n_seqs
=
len
(
self
.
outer_in_seqs
),
n_mit_sot
=
len
(
self
.
outer_in_mit_sot
),
n_mit_mot
=
len
(
self
.
outer_in_mit_mot
),
tap_array
=
(
self
.
mit_mot_in_slices
+
n_mit_sot
=
len
(
self
.
outer_in_mit_sot
),
self
.
mit_sot_in_slices
+
tap_array
=
(
self
.
mit_mot_in_slices
+
[[
-
1
]]
*
len
(
self
.
inner_in_sit_sot
)),
self
.
mit_sot_in_slices
+
n_sit_sot
=
len
(
self
.
outer_in_sit_sot
),
[[
-
1
]]
*
len
(
self
.
inner_in_sit_sot
)),
n_nit_sot
=
len
(
self
.
outer_in_nit_sot
),
n_sit_sot
=
len
(
self
.
outer_in_sit_sot
),
n_shared_outs
=
len
(
self
.
outer_in_shared
),
n_nit_sot
=
len
(
self
.
outer_in_nit_sot
),
n_mit_mot_outs
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_out_slices
),
n_shared_outs
=
len
(
self
.
outer_in_shared
),
mit_mot_out_slices
=
self
.
mit_mot_out_slices
,
n_mit_mot_outs
=
sum
(
len
(
s
)
for
s
in
self
.
mit_mot_out_slices
),
**
self
.
other_info
))
mit_mot_out_slices
=
self
.
mit_mot_out_slices
,
**
self
.
other_info
))
def
__copy__
(
self
):
def
__copy__
(
self
):
res
=
object
.
__new__
(
type
(
self
))
res
=
object
.
__new__
(
type
(
self
))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论