Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7cb2384c
提交
7cb2384c
authored
1月 23, 2012
作者:
nouiz
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #371 from pascanur/fixed_grad_of_grad_of_scan
Fixed grad of grad of scan
上级
70c008b8
15812503
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
431 行增加
和
383 行删除
+431
-383
__init__.py
theano/scan_module/__init__.py
+5
-5
scan_op.py
theano/scan_module/scan_op.py
+337
-311
scan_opt.py
theano/scan_module/scan_opt.py
+5
-3
scan_views.py
theano/scan_module/scan_views.py
+62
-62
test_scan.py
theano/scan_module/tests/test_scan.py
+22
-2
没有找到文件。
theano/scan_module/__init__.py
浏览文件 @
7cb2384c
...
@@ -30,11 +30,11 @@ functions: ``scan()``, ``map()``, ``reduce()``, ``foldl()``,
...
@@ -30,11 +30,11 @@ functions: ``scan()``, ``map()``, ``reduce()``, ``foldl()``,
__docformat__
=
'restructedtext en'
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"Frederic Bastien "
"James Bergstra "
"James Bergstra "
"Pascal Lamblin "
"Pascal Lamblin "
"Arnaud Bergeron "
)
"Arnaud Bergeron "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
...
...
theano/scan_module/scan_op.py
浏览文件 @
7cb2384c
...
@@ -179,26 +179,30 @@ class Scan(PureOp):
...
@@ -179,26 +179,30 @@ class Scan(PureOp):
err_msg1
=
(
'When compiling the inner function of scan the '
err_msg1
=
(
'When compiling the inner function of scan the '
'following error has been encountered: The '
'following error has been encountered: The '
'
%
s
%
s (argument number
%
d) has dtype '
'
%
s
%
s (argument number
%
d) has dtype '
'
%
s. The corresponding slice
%
s however has'
'
%
s and
%
d dimension(s). The corresponding slice
%
s '
' dtype
%
s. This should never happen, please '
'however has dtype
%
s and
%
d dimension(s). This '
'should never happen, please '
'report to theano-dev mailing list'
'report to theano-dev mailing list'
)
)
err_msg2
=
(
'When compiling the inner function of scan the '
err_msg2
=
(
'When compiling the inner function of scan the '
'following error has been encountered: The '
'following error has been encountered: The '
'initial state (outputs_info in scan nomenclature)'
'initial state (outputs_info in scan nomenclature)'
'of variable
%
s (argument number
%
d)'
'of variable
%
s (argument number
%
d)'
' has dtype
%
s
, while the result of the
'
' has dtype
%
s
and
%
d dimension(s), while the result
'
'
inner function for this output has dtype
%
s. Thi
s '
'
of the inner function for this output has dtype
%
s '
'
could happen if the inner graph of scan results in
'
'
and
%
d dimension(s). This could happen if the inner
'
'
an upcast or downcast. Please make sure that you use
'
'
graph of scan results in an upcast or downcast.
'
'dtypes consistently'
)
'
Please make sure that you use
dtypes consistently'
)
# TODO make the assert exact
# TODO make the assert exact
# TODO assert the type(dtype, nbdim of self.inputs and inputs correspond)
# TODO assert the type(dtype, nbdim of self.inputs and
# inputs correspond)
#assert len(inputs) >= len(self.inputs)
#assert len(inputs) >= len(self.inputs)
# if self.info['as_while']:
#if self.info['as_while']:
# assert len(inputs) == len(self.inputs) + 2 + self.info["n_nit_sot"]
# assert len(inputs) == len(self.inputs) + 2 + \
# else:
# self.info["n_nit_sot"]
# assert len(inputs) == len(self.inputs) + 1 + self.info["n_nit_sot"]
#else:
# assert len(inputs) == len(self.inputs) + 1 + \
# self.info["n_nit_sot"]
# Flags that indicate which inputs are vectors
# Flags that indicate which inputs are vectors
self
.
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
self
.
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
...
@@ -235,26 +239,27 @@ class Scan(PureOp):
...
@@ -235,26 +239,27 @@ class Scan(PureOp):
self
.
mitmot_out_taps
(),
self
.
mitmot_out_taps
(),
self
.
outer_mitmot
(
inputs
))):
self
.
outer_mitmot
(
inputs
))):
for
k
in
xrange
(
len
(
itaps
)):
for
k
in
xrange
(
len
(
itaps
)):
if
(
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
!=
if
(
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
!=
outer_mitmot
.
type
.
dtype
or
outer_mitmot
.
type
.
dtype
or
inner_mitmot
[
ipos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
inner_mitmot
[
ipos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
' in scan nomenclature) '
,
str
(
outer_mitmot
),
str
(
outer_mitmot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_mitmot
.
type
.
dtype
,
outer_mitmot
.
type
.
dtype
,
str
(
inner_mitmot
[
ipos
+
k
]),
str
(
inner_mitmot
[
ipos
+
k
]),
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
))
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
))
ipos
+=
len
(
itaps
)
ipos
+=
len
(
itaps
)
for
k
in
xrange
(
len
(
otaps
)):
for
k
in
xrange
(
len
(
otaps
)):
if
(
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
!=
if
(
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
!=
\
outer_mitmot
.
type
.
dtype
or
outer_mitmot
.
type
.
dtype
or
inner_mitmot_outs
[
opos
+
k
]
.
ndim
!=
outer_mitmot
.
ndim
-
1
):
inner_mitmot_outs
[
opos
+
k
]
.
ndim
!=
\
raise
ValueError
(
err_msg2
%
outer_mitmot
.
ndim
-
1
):
raise
ValueError
(
err_msg2
%
(
str
(
outer_mitmot
),
(
str
(
outer_mitmot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_mitmot
.
type
.
dtype
,
outer_mitmot
.
type
.
dtype
,
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
))
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
))
opos
+=
len
(
otaps
)
opos
+=
len
(
otaps
)
argoffset
+=
len
(
self
.
outer_mitmot
(
inputs
))
argoffset
+=
len
(
self
.
outer_mitmot
(
inputs
))
# Same checks as above but for outputs of type mit_sot
# Same checks as above but for outputs of type mit_sot
...
@@ -265,24 +270,28 @@ class Scan(PureOp):
...
@@ -265,24 +270,28 @@ class Scan(PureOp):
self
.
outer_mitsot
(
inputs
),
self
.
outer_mitsot
(
inputs
),
self
.
inner_mitsot_outs
(
self
.
outputs
))):
self
.
inner_mitsot_outs
(
self
.
outputs
))):
for
k
in
xrange
(
len
(
itaps
)):
for
k
in
xrange
(
len
(
itaps
)):
if
(
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
!=
if
(
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
!=
\
outer_mitsot
.
type
.
dtype
or
outer_mitsot
.
type
.
dtype
or
inner_mitsots
[
ipos
+
k
]
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
inner_mitsots
[
ipos
+
k
]
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
' in scan nomenclature) '
,
str
(
outer_mitsot
),
str
(
outer_mitsot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_mitsot
.
type
.
dtype
,
outer_mitsot
.
type
.
dtype
,
str
(
inner_mitsot
[
ipos
+
k
]),
otuer_mitsot
.
type
.
ndim
,
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
))
str
(
inner_mitsot
[
ipos
+
k
]),
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
,
inner_mitsots
[
ipos
+
k
]
.
type
.
ndim
))
ipos
+=
len
(
itaps
)
ipos
+=
len
(
itaps
)
if
(
inner_mitsot_out
.
type
.
dtype
!=
outer_mitsot
.
type
.
dtype
or
if
(
inner_mitsot_out
.
type
.
dtype
!=
outer_mitsot
.
type
.
dtype
or
inner_mitsot_out
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
inner_mitsot_out
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg2
%
raise
ValueError
(
err_msg2
%
(
str
(
outer_mitsot
),
(
str
(
outer_mitsot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_mitsot
.
type
.
dtype
,
outer_mitsot
.
type
.
dtype
,
inner_mitsot_out
.
type
.
dtype
))
outer_mitsot
.
type
.
ndim
,
inner_mitsot_out
.
type
.
dtype
,
inner_mitsot_out
.
type
.
ndim
))
argoffset
+=
len
(
self
.
outer_mitsot
(
inputs
))
argoffset
+=
len
(
self
.
outer_mitsot
(
inputs
))
# Same checks as above but for outputs of type sit_sot
# Same checks as above but for outputs of type sit_sot
...
@@ -297,15 +306,19 @@ class Scan(PureOp):
...
@@ -297,15 +306,19 @@ class Scan(PureOp):
str
(
outer_sitsot
),
str
(
outer_sitsot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_sitsot
.
type
.
dtype
,
outer_sitsot
.
type
.
dtype
,
outer_sitsot
.
type
.
ndim
,
str
(
inner_sitsot
),
str
(
inner_sitsot
),
inner_sitsot
.
type
.
dtype
))
inner_sitsot
.
type
.
dtype
,
inner_sitsot
.
type
.
ndim
))
if
(
inner_sitsot_out
.
type
.
dtype
!=
outer_sitsot
.
type
.
dtype
or
if
(
inner_sitsot_out
.
type
.
dtype
!=
outer_sitsot
.
type
.
dtype
or
inner_sitsot_out
.
ndim
!=
outer_sitsot
.
ndim
-
1
):
inner_sitsot_out
.
ndim
!=
outer_sitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg2
%
raise
ValueError
(
err_msg2
%
(
str
(
outer_sitsot
),
(
str
(
outer_sitsot
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_sitsot
.
type
.
dtype
,
outer_sitsot
.
type
.
dtype
,
inner_sitsot_out
.
type
.
dtype
))
outer_sitsot
.
type
.
ndim
,
inner_sitsot_out
.
type
.
dtype
,
inner_sitsot_out
.
type
.
ndim
))
argoffset
+=
len
(
self
.
outer_sitsot
(
inputs
))
argoffset
+=
len
(
self
.
outer_sitsot
(
inputs
))
# Check that the shared variable and their update rule have the same
# Check that the shared variable and their update rule have the same
...
@@ -320,7 +333,9 @@ class Scan(PureOp):
...
@@ -320,7 +333,9 @@ class Scan(PureOp):
raise
ValueError
(
err_msg2
%
(
str
(
outer_shared
),
raise
ValueError
(
err_msg2
%
(
str
(
outer_shared
),
idx
+
argoffset
,
idx
+
argoffset
,
outer_shared
.
dtype
,
outer_shared
.
dtype
,
inner_shared_out
.
dtype
))
outer_shared
.
ndim
,
inner_shared_out
.
dtype
,
inner_shared_out
.
ndim
))
if
(
hasattr
(
outer_shared
,
'dtype'
)
and
if
(
hasattr
(
outer_shared
,
'dtype'
)
and
(
outer_shared
.
dtype
!=
inner_shared
.
dtype
or
(
outer_shared
.
dtype
!=
inner_shared
.
dtype
or
...
@@ -330,8 +345,10 @@ class Scan(PureOp):
...
@@ -330,8 +345,10 @@ class Scan(PureOp):
str
(
outer_shared
),
str
(
outer_shared
),
argoffset
+
idx
,
argoffset
+
idx
,
outer_shared
.
dtype
,
outer_shared
.
dtype
,
outer_shared
.
ndim
,
str
(
inner_shared
),
str
(
inner_shared
),
inner_shared
.
dtype
))
inner_shared
.
dtype
,
inner_shared
.
ndim
))
for
inner_nonseq
,
outer_nonseq
in
zip
(
for
inner_nonseq
,
outer_nonseq
in
zip
(
self
.
inner_non_seqs
(
self
.
inputs
),
self
.
inner_non_seqs
(
self
.
inputs
),
self
.
outer_non_seqs
(
inputs
)):
self
.
outer_non_seqs
(
inputs
)):
...
@@ -339,7 +356,7 @@ class Scan(PureOp):
...
@@ -339,7 +356,7 @@ class Scan(PureOp):
inner_nonseq
.
type
.
ndim
!=
outer_nonseq
.
type
.
ndim
):
inner_nonseq
.
type
.
ndim
!=
outer_nonseq
.
type
.
ndim
):
raise
ValueError
((
'Argument
%
s given to scan node does not'
raise
ValueError
((
'Argument
%
s given to scan node does not'
' match its correspondance
%
s'
)
%
' match its correspondance
%
s'
)
%
(
str
(
outer_nonseq
),
str
(
inner_nonseq
)))
(
str
(
outer_nonseq
),
str
(
inner_nonseq
)))
for
outer_nitsot
in
self
.
outer_nitsot
(
inputs
):
for
outer_nitsot
in
self
.
outer_nitsot
(
inputs
):
# For every nit_sot input we get as input a int/uint that
# For every nit_sot input we get as input a int/uint that
...
@@ -1073,8 +1090,9 @@ class Scan(PureOp):
...
@@ -1073,8 +1090,9 @@ class Scan(PureOp):
offset
=
1
+
self
.
n_seqs
offset
=
1
+
self
.
n_seqs
scan_outs
=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
n_outs
]]
scan_outs
=
[
x
for
x
in
input_shapes
[
offset
:
offset
+
n_outs
]]
offset
+=
n_outs
offset
+=
n_outs
outs_shape_n
=
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_sit_sot
for
x
in
xrange
(
self
.
n_nit_sot
):
for
x
in
xrange
(
self
.
n_nit_sot
):
out_shape_x
=
outs_shape
[
n_outs
+
x
]
out_shape_x
=
outs_shape
[
outs_shape_n
+
x
]
if
out_shape_x
is
None
:
if
out_shape_x
is
None
:
# This output is not a tensor, and has no shape
# This output is not a tensor, and has no shape
scan_outs
.
append
(
None
)
scan_outs
.
append
(
None
)
...
@@ -1106,18 +1124,12 @@ class Scan(PureOp):
...
@@ -1106,18 +1124,12 @@ class Scan(PureOp):
# if we are dealing with a repeat-until, then we do not know the
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
# leading dimension so we replace it for every entry with Shape_i
if
self
.
as_while
:
if
self
.
as_while
:
scan_outs
=
[(
Shape_i
(
0
)(
o
),)
+
x
[
1
:]
scan_outs
=
[(
Shape_i
(
0
)(
o
),)
+
x
[
1
:]
for
o
,
x
in
izip
(
node
.
outputs
,
scan_outs
)]
for
o
,
x
in
izip
(
node
.
outputs
,
scan_outs
)]
return
scan_outs
return
scan_outs
### GRAD FUNCTION
### GRAD FUNCTION
def
grad
(
self
,
args
,
g_outs
):
def
grad
(
self
,
args
,
g_outs
):
if
'computed_grad'
in
self
.
info
:
raise
ValueError
((
'Computing gradients through the gradients '
'of a scan node can be wrong. For now Theano '
'will not allow you to do so, until the '
'possible bug is fixed'
))
# 1. forward pass - get the outputs after applying scan
# 1. forward pass - get the outputs after applying scan
scan_outputs
=
self
(
*
args
)
scan_outputs
=
self
(
*
args
)
# 2. make sure they are given as a list
# 2. make sure they are given as a list
...
@@ -1139,79 +1151,80 @@ class Scan(PureOp):
...
@@ -1139,79 +1151,80 @@ class Scan(PureOp):
in
xrange
(
self
.
n_mit_mot
)])
in
xrange
(
self
.
n_mit_mot
)])
outs_mit_mot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_mot
]
outs_mit_mot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_mot
]
offset
+=
n_ins_mit_mot
offset
+=
n_ins_mit_mot
n_ins_mit_sot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
n_ins_mit_sot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
in
xrange
(
self
.
n_mit_mot
in
xrange
(
self
.
n_mit_mot
,
,
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
outs_mit_sot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_sot
]
outs_mit_sot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_sot
]
offset
+=
n_ins_mit_sot
offset
+=
n_ins_mit_sot
outs_sit_sot
=
self_inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
outs_sit_sot
=
self_inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
offset
+=
self
.
n_sit_sot
offset
+=
self
.
n_sit_sot
old_scan_shared_ins
=
self_inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
old_scan_shared_ins
=
self_inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
out_offset
=
(
self
.
n_mit_mot_outs
out_offset
=
(
self
.
n_mit_mot_outs
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_nit_sot
self
.
n_nit_sot
+
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
# shared variables as well as the condition
# shared variables as well as the condition
old_scan_shared_outs
=
self_outputs
[
out_offset
:]
old_scan_shared_outs
=
self_outputs
[
out_offset
:]
arg_offset
=
(
1
arg_offset
=
(
1
+
+
self
.
n_seqs
self
.
n_seqs
+
+
self
.
n_mit_mot
self
.
n_mit_mot
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
old_scan_init
=
args
[
arg_offset
:
arg_offset
+
self
.
n_shared_outs
]
old_scan_init
=
args
[
arg_offset
:
arg_offset
+
self
.
n_shared_outs
]
offset
+=
self
.
n_shared_outs
offset
+=
self
.
n_shared_outs
other_args
=
self_inputs
[
offset
:]
other_args
=
self_inputs
[
offset
:]
# 4. Collect (possibly) differentiable inputs
# 4. Collect (possibly) differentiable inputs
diff_inputs
=
(
seqs
+
diff_inputs
=
(
seqs
+
outs_mit_mot
+
outs_mit_mot
+
outs_mit_sot
+
outs_mit_sot
+
outs_sit_sot
+
outs_sit_sot
+
other_args
)
other_args
)
#args[-len(other_args):] )
#args[-len(other_args):] )
# 5. construct the function that computes the gradient (we sum over
# 5. construct the function that computes the gradient (we sum over
# the gradients with respect to all outputs)
# the gradients with respect to all outputs)
def
compute_gradient
(
y
,
g_y
):
def
compute_gradient
(
y
,
g_y
):
gmp
=
gradient
.
grad_sources_inputs
(
gmp
=
gradient
.
grad_sources_inputs
(
[(
y
,
g_y
)],
diff_inputs
,
False
)
[(
y
,
g_y
)],
diff_inputs
,
False
)
return
[
gmp
.
get
(
p
,
None
)
for
p
in
diff_inputs
]
return
[
gmp
.
get
(
p
,
None
)
for
p
in
diff_inputs
]
# 6. clean the outputs (i.e. remove update rules)
# 6. clean the outputs (i.e. remove update rules)
end
=
(
self
.
n_mit_mot_outs
end
=
(
self
.
n_mit_mot_outs
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_sit_sot
self
.
n_sit_sot
+
+
self
.
n_nit_sot
)
self
.
n_nit_sot
)
clean_outputs
=
self_outputs
[:
end
]
clean_outputs
=
self_outputs
[:
end
]
g_outs_no_shared
=
g_outs
[:
end
]
g_outs_no_shared
=
g_outs
[:
end
]
# 7.1. empty lists to hold gradients
# 7.1. empty lists to hold gradients
# List of slices from outputs (used to compute the gradients)
# List of slices from outputs (used to compute the gradients)
inner_g_outs
=
[]
inner_g_outs
=
[]
g_out_slices
=
[]
g_out_slices
=
[]
# List of outputs of the gradient function
# List of outputs of the gradient function
inner_gfn_outs
=
[]
inner_gfn_outs
=
[]
# slices of the input
# slices of the input
prev_inner_gfn_outs
=
[]
prev_inner_gfn_outs
=
[]
zeros_like_diff_ins
=
[]
zeros_like_diff_ins
=
[]
pos
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
pos
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
offset
=
len
(
args
)
-
len
(
other_args
)
-
pos
offset
=
len
(
args
)
-
len
(
other_args
)
-
pos
# 7.2. generate variables to represent previous steps of g_outs
# 7.2. generate variables to represent previous steps of g_outs
for
idx
,
diff_in
in
enumerate
(
diff_inputs
):
for
idx
,
diff_in
in
enumerate
(
diff_inputs
):
prev_gfn_out
=
safe_new
(
diff_in
)
prev_gfn_out
=
safe_new
(
diff_in
)
if
hasattr
(
diff_in
,
'name'
)
and
diff_in
.
name
:
if
hasattr
(
diff_in
,
'name'
)
and
diff_in
.
name
:
prev_gfn_out
.
name
=
'g_prev_'
+
diff_in
.
name
prev_gfn_out
.
name
=
'g_prev_'
+
diff_in
.
name
else
:
else
:
prev_gfn_out
.
name
=
'g_prev_'
+
str
(
idx
)
prev_gfn_out
.
name
=
'g_prev_'
+
str
(
idx
)
prev_inner_gfn_outs
.
append
(
prev_gfn_out
)
prev_inner_gfn_outs
.
append
(
prev_gfn_out
)
if
idx
<
pos
:
if
idx
<
pos
:
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
diff_in
))
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
diff_in
))
else
:
else
:
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
args
[
idx
+
offset
]))
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
args
[
idx
+
offset
]))
# 7.3. compute gradients of the inputs given one output
# 7.3. compute gradients of the inputs given one output
for
dx
,
out
in
enumerate
(
clean_outputs
):
for
dx
,
out
in
enumerate
(
clean_outputs
):
...
@@ -1219,32 +1232,30 @@ class Scan(PureOp):
...
@@ -1219,32 +1232,30 @@ class Scan(PureOp):
###
###
#### I need to clip the gradient HERE !!
#### I need to clip the gradient HERE !!
if
g_outs_no_shared
[
dx
]:
if
g_outs_no_shared
[
dx
]:
g_out_slices
.
append
(
g_outs_no_shared
[
dx
][
0
])
g_out_slices
.
append
(
g_outs_no_shared
[
dx
][
0
])
else
:
else
:
g_out_slices
.
append
(
None
)
g_out_slices
.
append
(
None
)
if
getattr
(
out
,
'name'
,
None
)
is
not
None
:
if
getattr
(
out
,
'name'
,
None
)
is
not
None
:
inner_g_out
.
name
=
'g_'
+
out
.
name
inner_g_out
.
name
=
'g_'
+
out
.
name
else
:
else
:
inner_g_out
.
name
=
'g_'
+
str
(
dx
)
inner_g_out
.
name
=
'g_'
+
str
(
dx
)
inner_g_outs
.
append
(
inner_g_out
)
inner_g_outs
.
append
(
inner_g_out
)
_g_out
=
inner_g_out
_g_out
=
inner_g_out
grad_outs
=
compute_gradient
(
out
,
_g_out
)
grad_outs
=
compute_gradient
(
out
,
_g_out
)
if
not
inner_gfn_outs
:
if
not
inner_gfn_outs
:
for
idx
,
gfn_out
in
enumerate
(
grad_outs
):
for
idx
,
gfn_out
in
enumerate
(
grad_outs
):
if
idx
>=
self
.
n_seqs
:
if
idx
>=
self
.
n_seqs
:
inner_gfn_outs
.
append
(
prev_inner_gfn_outs
[
idx
]
)
inner_gfn_outs
.
append
(
prev_inner_gfn_outs
[
idx
]
)
else
:
else
:
inner_gfn_outs
.
append
(
None
)
inner_gfn_outs
.
append
(
None
)
# 7.4 Sum the gradients
# 7.4 Sum the gradients
# safety check, some of this inputs might still not be
# safety check, some of this inputs might still not be
# differentiable, for those we don't add them to the mix
# differentiable, for those we don't add them to the mix
# (assume their gradient is 0)
# (assume their gradient is 0)
for
i
,
(
x
,
y
)
in
enumerate
(
zip
(
grad_outs
,
inner_gfn_outs
)):
for
i
,
(
x
,
y
)
in
enumerate
(
zip
(
grad_outs
,
inner_gfn_outs
)):
if
x
and
y
:
if
x
and
y
:
inner_gfn_outs
[
i
]
=
x
+
y
inner_gfn_outs
[
i
]
=
x
+
y
elif
y
:
elif
y
:
inner_gfn_outs
[
i
]
=
y
inner_gfn_outs
[
i
]
=
y
else
:
else
:
...
@@ -1268,28 +1279,27 @@ class Scan(PureOp):
...
@@ -1268,28 +1279,27 @@ class Scan(PureOp):
g_outs
[
i
]
=
theano
.
tensor
.
constant
(
g_outs
[
i
]
=
theano
.
tensor
.
constant
(
numpy
.
array
(
0
,
theano
.
config
.
floatX
))
numpy
.
array
(
0
,
theano
.
config
.
floatX
))
## 10. Get your sequence in order for the scan:
## 10. Get your sequence in order for the scan:
n_seqs
=
(
self
.
n_seqs
+
n_seqs
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_mot
+
n_ins_mit_sot
+
n_ins_mit_sot
+
self
.
n_sit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
self
.
n_nit_sot
)
offset
=
(
self
.
n_mit_mot_outs
+
offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
inner_seqs
=
(
seqs
+
inner_seqs
=
(
seqs
+
outs_mit_mot
+
outs_mit_mot
+
outs_mit_sot
+
outs_mit_sot
+
outs_sit_sot
+
outs_sit_sot
+
inner_g_outs
[
offset
:
offset
+
self
.
n_nit_sot
])
inner_g_outs
[
offset
:
offset
+
self
.
n_nit_sot
])
scan_seqs
=
[
x
[::
-
1
]
for
x
in
args
[
1
:
self
.
n_seqs
+
1
]]
scan_seqs
=
[
x
[::
-
1
]
for
x
in
args
[
1
:
self
.
n_seqs
+
1
]]
offset
=
0
offset
=
0
for
idx
in
xrange
(
self
.
n_mit_mot
+
self
.
n_mit_sot
):
for
idx
in
xrange
(
self
.
n_mit_mot
+
self
.
n_mit_sot
):
mintap
=
numpy
.
min
(
self
.
tap_array
[
idx
])
mintap
=
numpy
.
min
(
self
.
tap_array
[
idx
])
maxtap
=
numpy
.
max
(
self
.
tap_array
[
idx
])
maxtap
=
numpy
.
max
(
self
.
tap_array
[
idx
])
seq
=
scan_outputs
[
offset
+
idx
]
seq
=
scan_outputs
[
offset
+
idx
]
for
k
in
self
.
tap_array
[
idx
]:
for
k
in
self
.
tap_array
[
idx
]:
# We cut the sequence such that seq[i] to correspond to
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]
# seq[i-k]
...
@@ -1299,199 +1309,205 @@ class Scan(PureOp):
...
@@ -1299,199 +1309,205 @@ class Scan(PureOp):
dim_offset
=
0
dim_offset
=
0
if
maxtap
==
mintap
and
maxtap
!=
0
:
if
maxtap
==
mintap
and
maxtap
!=
0
:
nw_seq
=
seq
[:
abs
(
maxtap
)]
nw_seq
=
seq
[:
abs
(
maxtap
)]
elif
maxtap
-
k
!=
0
:
elif
maxtap
-
k
!=
0
:
tmp
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
-
(
maxtap
-
k
+
1
)]
nw_seq
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
\
nw_seq
=
tmp
[::
-
1
]
-
(
maxtap
-
k
+
1
)]
[::
-
1
]
else
:
else
:
nw_seq
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
-
1
][::
-
1
]
nw_seq
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
-
1
][::
-
1
]
if
getattr
(
seq
,
'name'
,
None
)
is
not
None
:
if
getattr
(
seq
,
'name'
,
None
)
is
not
None
:
nw_seq
.
name
=
seq
.
name
+
'[
%
d:]'
%
k
nw_seq
.
name
=
seq
.
name
+
'[
%
d:]'
%
k
scan_seqs
.
append
(
nw_seq
)
scan_seqs
.
append
(
nw_seq
)
offset
+=
self
.
n_mit_sot
offset
+=
self
.
n_mit_sot
for
idx
in
xrange
(
self
.
n_sit_sot
):
for
idx
in
xrange
(
self
.
n_sit_sot
):
seq
=
scan_outputs
[
offset
+
idx
][:
-
1
]
seq
=
scan_outputs
[
offset
+
idx
][:
-
1
]
scan_seqs
.
append
(
seq
[::
-
1
])
scan_seqs
.
append
(
seq
[::
-
1
])
offset
=
(
self
.
n_mit_mot_outs
+
offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
scan_seqs
+=
[
x
[::
-
1
]
for
x
in
scan_seqs
+=
[
x
[::
-
1
]
for
x
in
g_outs
[
offset
:
offset
+
self
.
n_nit_sot
]]
g_outs
[
offset
:
offset
+
self
.
n_nit_sot
]]
scan_mit_mot
=
[]
scan_mit_mot
=
[]
inner_mit_mot
=
[]
inner_mit_mot
=
[]
scan_mit_mot_outs
=
[]
scan_mit_mot_outs
=
[]
mit_mot_taps
=
[]
mit_mot_taps
=
[]
mit_mot_out_slices
=
[]
mit_mot_out_slices
=
[]
out_pos
=
0
out_pos
=
0
ins_pos
=
n_seqs
ins_pos
=
n_seqs
n_mit_mot_outs
=
0
n_mit_mot_outs
=
0
n_mit_mot_ins
=
0
n_mit_mot_ins
=
0
ins_pos
=
self
.
n_seqs
ins_pos
=
self
.
n_seqs
for
idx
in
xrange
(
self
.
n_mit_mot
):
for
idx
in
xrange
(
self
.
n_mit_mot
):
scan_mit_mot
.
append
(
g_outs
[
idx
][::
-
1
]
)
scan_mit_mot
.
append
(
g_outs
[
idx
][::
-
1
]
)
mit_mot_taps
.
append
([])
mit_mot_taps
.
append
([])
mit_mot_out_slices
.
append
([])
mit_mot_out_slices
.
append
([])
for
jdx
in
xrange
(
len
(
self
.
mit_mot_out_slices
[
idx
])):
for
jdx
in
xrange
(
len
(
self
.
mit_mot_out_slices
[
idx
])):
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
]
)
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
]
)
mit_mot_taps
[
idx
]
.
append
(
mit_mot_taps
[
idx
]
.
append
(
\
-
self
.
mit_mot_out_slices
[
idx
][
jdx
])
-
self
.
mit_mot_out_slices
[
idx
][
jdx
])
n_mit_mot_ins
+=
1
n_mit_mot_ins
+=
1
out_pos
+=
1
out_pos
+=
1
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx
])):
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx
])):
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
]
)
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
]
)
scan_mit_mot_outs
.
append
(
scan_mit_mot_outs
.
append
(
\
inner_gfn_outs
[
ins_pos
]
)
inner_gfn_outs
[
ins_pos
]
)
n_mit_mot_ins
+=
1
n_mit_mot_ins
+=
1
ins_pos
+=
1
ins_pos
+=
1
n_mit_mot_outs
+=
1
n_mit_mot_outs
+=
1
mit_mot_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
])
mit_mot_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
])
mit_mot_out_slices
[
idx
]
.
append
(
mit_mot_out_slices
[
idx
]
.
append
(
\
-
self
.
tap_array
[
idx
][
jdx
]
)
-
self
.
tap_array
[
idx
][
jdx
])
offset
=
self
.
n_mit_mot
offset
=
self
.
n_mit_mot
for
idx
in
xrange
(
self
.
n_mit_sot
):
for
idx
in
xrange
(
self
.
n_mit_sot
):
mit_mot_taps
.
append
([])
mit_mot_taps
.
append
([])
mit_mot_out_slices
.
append
([])
mit_mot_out_slices
.
append
([])
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
]
)
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
]
)
idx_tap
=
idx
+
self
.
n_mit_mot
idx_tap
=
idx
+
self
.
n_mit_mot
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx_tap
])):
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx_tap
])):
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
]
)
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
]
)
mit_mot_taps
[
idx
+
offset
]
.
append
(
mit_mot_taps
[
idx
+
offset
]
.
append
(
\
-
self
.
tap_array
[
idx_tap
][
jdx
]
)
-
self
.
tap_array
[
idx_tap
][
jdx
])
mit_mot_out_slices
[
idx
]
.
append
(
mit_mot_out_slices
[
idx
]
.
append
(
\
-
self
.
tap_array
[
idx_tap
][
jdx
]
)
-
self
.
tap_array
[
idx_tap
][
jdx
])
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
]
)
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
]
)
n_mit_mot_ins
+=
1
n_mit_mot_ins
+=
1
ins_pos
+=
1
ins_pos
+=
1
n_mit_mot_outs
+=
1
n_mit_mot_outs
+=
1
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
]
)
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
]
)
out_pos
+=
1
out_pos
+=
1
n_mit_mot_ins
+=
1
n_mit_mot_ins
+=
1
mit_mot_taps
[
idx
+
offset
]
.
append
(
0
)
mit_mot_taps
[
idx
+
offset
]
.
append
(
0
)
offset
+=
self
.
n_mit_sot
offset
+=
self
.
n_mit_sot
for
idx
in
xrange
(
self
.
n_sit_sot
):
for
idx
in
xrange
(
self
.
n_sit_sot
):
mit_mot_taps
.
append
([
0
,
1
])
mit_mot_taps
.
append
([
0
,
1
])
mit_mot_out_slices
.
append
([
1
])
mit_mot_out_slices
.
append
([
1
])
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
]
)
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
]
)
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
])
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
])
inner_mit_mot
+=
[
inner_g_outs
[
out_pos
]
inner_mit_mot
+=
[
inner_g_outs
[
out_pos
],
,
prev_inner_gfn_outs
[
ins_pos
]
]
prev_inner_gfn_outs
[
ins_pos
]
]
n_mit_mot_outs
+=
1
n_mit_mot_outs
+=
1
out_pos
+=
1
out_pos
+=
1
ins_pos
+=
1
ins_pos
+=
1
n_mit_mot_ins
+=
2
n_mit_mot_ins
+=
2
n_nit_sot
=
self
.
n_seqs
n_nit_sot
=
self
.
n_seqs
scan_nit_sot_outs
=
inner_gfn_outs
[:
self
.
n_seqs
]
scan_nit_sot_outs
=
inner_gfn_outs
[:
self
.
n_seqs
]
offset
=
(
self
.
n_seqs
if
self
.
truncate_gradient
!=
-
1
:
+
n_ins_mit_sot
do_steps
=
tensor
.
minimum
(
args
[
0
],
self
.
truncate_gradient
)
+
n_ins_mit_mot
else
:
+
self
.
n_sit_sot
)
do_steps
=
args
[
0
]
n_shared_outs
=
len
(
prev_inner_gfn_outs
[
offset
:])
offset
=
(
self
.
n_seqs
+
scan_shared_ins
=
prev_inner_gfn_outs
[
offset
:]
n_ins_mit_sot
+
scan_shared_init
=
zeros_like_diff_ins
[
offset
:]
n_ins_mit_mot
+
scan_shared_outs
=
inner_gfn_outs
[
offset
:]
self
.
n_sit_sot
)
tap_array
=
mit_mot_taps
# Instead of shared outs use sit_sot
n_sitsot_outs
=
len
(
prev_inner_gfn_outs
[
offset
:])
scan_sitsot_ins
=
prev_inner_gfn_outs
[
offset
:]
scan_sitsot_init
=
[]
for
x
in
zeros_like_diff_ins
[
offset
:]:
shapes
=
[
x
.
shape
[
i
]
for
i
in
xrange
(
x
.
ndim
)]
empty
=
tensor
.
zeros
([
do_steps
+
1
]
+
shapes
,
dtype
=
x
.
dtype
)
scan_sitsot_init
.
append
(
empty
)
scan_sitsot_outs
=
inner_gfn_outs
[
offset
:]
tap_array
=
mit_mot_taps
+
[[
-
1
]
for
k
in
xrange
(
n_sitsot_outs
)]
info
=
{}
info
=
{}
info
[
'n_seqs'
]
=
n_seqs
info
[
'n_seqs'
]
=
n_seqs
info
[
'n_mit_sot'
]
=
0
info
[
'n_mit_sot'
]
=
0
info
[
'tap_array'
]
=
tap_array
info
[
'tap_array'
]
=
tap_array
info
[
'gpu'
]
=
False
info
[
'gpu'
]
=
False
n_mit_mot
=
(
self
.
n_mit_mot
n_mit_mot
=
(
self
.
n_mit_mot
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_sit_sot
)
self
.
n_sit_sot
)
info
[
'n_mit_mot'
]
=
n_mit_mot
info
[
'n_mit_mot'
]
=
n_mit_mot
info
[
'n_mit_mot_outs'
]
=
n_mit_mot_outs
info
[
'n_mit_mot_outs'
]
=
n_mit_mot_outs
info
[
'mit_mot_out_slices'
]
=
mit_mot_out_slices
info
[
'mit_mot_out_slices'
]
=
mit_mot_out_slices
info
[
'truncate_gradient'
]
=
self
.
truncate_gradient
info
[
'truncate_gradient'
]
=
self
.
truncate_gradient
info
[
'n_sit_sot'
]
=
0
info
[
'n_sit_sot'
]
=
n_sitsot_outs
info
[
'n_shared_outs'
]
=
n_shared_outs
+
self
.
n_shared_outs
info
[
'n_shared_outs'
]
=
self
.
n_shared_outs
info
[
'n_nit_sot'
]
=
n_nit_sot
info
[
'n_nit_sot'
]
=
n_nit_sot
info
[
'as_while'
]
=
self
.
as_while
info
[
'as_while'
]
=
self
.
as_while
info
[
'profile'
]
=
self
.
profile
info
[
'profile'
]
=
self
.
profile
if
self
.
name
:
if
self
.
name
:
info
[
'name'
]
=
'grad_of_'
+
self
.
name
info
[
'name'
]
=
'grad_of_'
+
self
.
name
else
:
else
:
info
[
'name'
]
=
None
info
[
'name'
]
=
None
info
[
'mode'
]
=
self
.
mode
info
[
'mode'
]
=
self
.
mode
info
[
'inplace'
]
=
False
info
[
'inplace'
]
=
False
info
[
'computed_grad'
]
=
True
n_mit_sot
=
0
n_mit_sot
=
0
n_sit_sot
=
0
n_sit_sot
=
0
if
self
.
truncate_gradient
!=
-
1
:
do_steps
=
tensor
.
minimum
(
args
[
0
],
self
.
truncate_gradient
)
else
:
do_steps
=
args
[
0
]
offset
=
(
1
offset
=
(
1
+
+
self
.
n_seqs
self
.
n_seqs
+
+
self
.
n_mit_mot
self
.
n_mit_mot
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_sit_sot
self
.
n_sit_sot
+
+
self
.
n_nit_sot
self
.
n_nit_sot
+
+
self
.
n_shared_outs
)
self
.
n_shared_outs
)
scan_inputs
=
(
[
do_steps
]
+
scan_inputs
=
(
[
do_steps
]
+
scan_seqs
+
scan_seqs
+
scan_mit_mot
+
scan_mit_mot
+
scan_s
hared_init
+
scan_s
itsot_init
+
old_scan_init
+
old_scan_init
+
[
args
[
0
]
for
x
in
xrange
(
n_nit_sot
)
]
+
[
args
[
0
]
for
x
in
xrange
(
n_nit_sot
)
]
+
args
[
offset
:]
)
args
[
offset
:])
offset
=
(
self
.
n_seqs
offset
=
(
self
.
n_seqs
+
+
n_ins_mit_mot
n_ins_mit_mot
+
+
n_ins_mit_sot
n_ins_mit_sot
+
+
self
.
n_sit_sot
self
.
n_sit_sot
+
+
self
.
n_shared_outs
)
self
.
n_shared_outs
)
inner_other_args
=
self_inputs
[
offset
:]
inner_other_args
=
self_inputs
[
offset
:]
inner_gfn_ins
=
(
inner_seqs
+
inner_gfn_ins
=
(
inner_seqs
+
inner_mit_mot
+
inner_mit_mot
+
scan_shared_ins
+
scan_sitsot_ins
+
old_scan_shared_ins
+
old_scan_shared_ins
+
inner_other_args
)
inner_other_args
)
inner_gfn_outs
=
(
scan_mit_mot_outs
+
inner_gfn_outs
=
(
scan_mit_mot_outs
+
scan_nit_sot_outs
+
scan_sitsot_outs
+
scan_shared_outs
+
scan_nit_sot_outs
+
old_scan_shared_outs
)
old_scan_shared_outs
)
local_op
=
Scan
(
inner_gfn_ins
,
inner_gfn_outs
,
info
)
local_op
=
Scan
(
inner_gfn_ins
,
inner_gfn_outs
,
info
)
outputs
=
local_op
(
*
scan_inputs
)
outputs
=
local_op
(
*
scan_inputs
)
if
type
(
outputs
)
not
in
(
list
,
tuple
):
if
type
(
outputs
)
not
in
(
list
,
tuple
):
outputs
=
[
outputs
]
outputs
=
[
outputs
]
# Re-order the gradients correctly
# Re-order the gradients correctly
gradients
=
[
None
]
gradients
=
[
None
]
offset
=
(
self
.
n_mit_mot
offset
=
(
self
.
n_mit_mot
+
+
self
.
n_mit_sot
self
.
n_mit_sot
+
+
self
.
n_sit_sot
)
self
.
n_sit_sot
+
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[
offset
:
offset
+
self
.
n_seqs
]]
n_sitsot_outs
)
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[
offset
:
offset
+
self
.
n_seqs
]]
end
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
end
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[:
end
]]
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[:
end
]]
gradients
+=
[
None
for
x
in
xrange
(
self
.
n_shared_outs
)]
gradients
+=
[
None
for
x
in
xrange
(
self
.
n_shared_outs
)]
gradients
+=
[
None
for
x
in
xrange
(
self
.
n_nit_sot
)
]
gradients
+=
[
None
for
x
in
xrange
(
self
.
n_nit_sot
)
]
begin
=
end
+
self
.
n_seqs
begin
=
end
end
=
begin
+
n_shared
_outs
end
=
begin
+
n_sitsot
_outs
gradients
+=
outputs
[
begin
:
end
]
gradients
+=
[
x
[
-
1
]
for
x
in
outputs
[
begin
:
end
]
]
return
gradients
return
gradients
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
# Step 0. Don't work on the orignal tensor variables
# Step 0. Don't work on the orignal tensor variables
rval
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
rval
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
self
.
outputs
,
'_rop'
)
self
.
outputs
,
'_rop'
)
self_inputs
=
rval
[
0
]
self_inputs
=
rval
[
0
]
self_outputs
=
rval
[
1
]
self_outputs
=
rval
[
1
]
# Step 1. Compute the R_op of the inner function
# Step 1. Compute the R_op of the inner function
inner_eval_points
=
[
scan_utils
.
safe_new
(
x
,
'_evalpoint'
)
for
x
in
self_inputs
]
inner_eval_points
=
[
scan_utils
.
safe_new
(
x
,
'_evalpoint'
)
for
x
in
self_inputs
]
if
self
.
as_while
:
if
self
.
as_while
:
rop_self_outputs
=
self_outputs
[:
-
1
]
rop_self_outputs
=
self_outputs
[:
-
1
]
else
:
else
:
...
@@ -1510,82 +1526,82 @@ class Scan(PureOp):
...
@@ -1510,82 +1526,82 @@ class Scan(PureOp):
# evan point for the number of nit_sot which I think should just be
# evan point for the number of nit_sot which I think should just be
# ignored (?)
# ignored (?)
info
=
{}
info
=
{}
info
[
'n_seqs'
]
=
self
.
n_seqs
*
2
info
[
'n_seqs'
]
=
self
.
n_seqs
*
2
info
[
'n_mit_sot'
]
=
self
.
n_mit_sot
*
2
info
[
'n_mit_sot'
]
=
self
.
n_mit_sot
*
2
info
[
'n_sit_sot'
]
=
self
.
n_sit_sot
*
2
info
[
'n_sit_sot'
]
=
self
.
n_sit_sot
*
2
info
[
'n_mit_mot'
]
=
self
.
n_mit_mot
*
2
info
[
'n_mit_mot'
]
=
self
.
n_mit_mot
*
2
info
[
'n_nit_sot'
]
=
self
.
n_nit_sot
*
2
info
[
'n_nit_sot'
]
=
self
.
n_nit_sot
*
2
info
[
'n_shared_outs'
]
=
self
.
n_shared_outs
*
2
info
[
'n_shared_outs'
]
=
self
.
n_shared_outs
*
2
info
[
'gpu'
]
=
False
info
[
'gpu'
]
=
False
info
[
'as_while'
]
=
self
.
as_while
info
[
'as_while'
]
=
self
.
as_while
info
[
'profile'
]
=
self
.
profile
info
[
'profile'
]
=
self
.
profile
info
[
'truncate_gradient'
]
=
self
.
truncate_gradient
info
[
'truncate_gradient'
]
=
self
.
truncate_gradient
if
self
.
name
:
if
self
.
name
:
info
[
'name'
]
=
'rop_of_'
+
self
.
name
info
[
'name'
]
=
'rop_of_'
+
self
.
name
else
:
else
:
info
[
'name'
]
=
None
info
[
'name'
]
=
None
info
[
'mode'
]
=
self
.
mode
info
[
'mode'
]
=
self
.
mode
info
[
'inplace'
]
=
False
info
[
'inplace'
]
=
False
info
[
'mit_mot_out_slices'
]
=
self
.
mit_mot_out_slices
*
2
info
[
'mit_mot_out_slices'
]
=
self
.
mit_mot_out_slices
*
2
new_tap_array
=
[]
new_tap_array
=
[]
b
=
0
b
=
0
e
=
self
.
n_mit_mot
e
=
self
.
n_mit_mot
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
b
=
e
b
=
e
e
+=
self
.
n_mit_sot
e
+=
self
.
n_mit_sot
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
b
=
e
b
=
e
e
+=
self
.
n_sit_sot
e
+=
self
.
n_sit_sot
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
new_tap_array
+=
self
.
tap_array
[
b
:
e
]
*
2
info
[
'tap_array'
]
=
new_tap_array
info
[
'tap_array'
]
=
new_tap_array
# Sequences ...
# Sequences ...
b
=
1
b
=
1
ib
=
0
ib
=
0
e
=
1
+
self
.
n_seqs
e
=
1
+
self
.
n_seqs
ie
=
self
.
n_seqs
ie
=
self
.
n_seqs
scan_seqs
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
scan_seqs
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
inner_seqs
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
inner_seqs
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
# MIT_MOT sequences ...
# MIT_MOT sequences ...
b
=
e
b
=
e
e
=
e
+
self
.
n_mit_mot
e
=
e
+
self
.
n_mit_mot
ib
=
ie
ib
=
ie
ie
=
ie
+
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
ie
=
ie
+
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
]]))
self
.
tap_array
[:
self
.
n_mit_mot
]]))
scan_mit_mot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
scan_mit_mot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
inner_mit_mot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
inner_mit_mot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
# MIT_SOT sequences ...
# MIT_SOT sequences ...
b
=
e
b
=
e
e
=
e
+
self
.
n_mit_sot
e
=
e
+
self
.
n_mit_sot
ib
=
ie
ib
=
ie
ie
=
ie
+
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
ie
=
ie
+
int
(
numpy
.
sum
([
len
(
x
)
for
x
in
self
.
tap_array
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]]))
self
.
tap_array
[
self
.
n_mit_mot
:
\
scan_mit_sot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
self
.
n_mit_mot
+
self
.
n_mit_sot
]]))
scan_mit_sot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
inner_mit_sot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
inner_mit_sot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
#SIT_SOT sequences ...
#SIT_SOT sequences ...
b
=
e
b
=
e
e
=
e
+
self
.
n_sit_sot
e
=
e
+
self
.
n_sit_sot
ib
=
ie
ib
=
ie
ie
=
ie
+
self
.
n_sit_sot
ie
=
ie
+
self
.
n_sit_sot
scan_sit_sot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
scan_sit_sot
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
inner_sit_sot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
inner_sit_sot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
#Shared outs ...
#Shared outs ...
b
=
e
b
=
e
e
=
e
+
self
.
n_shared_outs
e
=
e
+
self
.
n_shared_outs
ib
=
ie
ib
=
ie
ie
=
ie
+
self
.
n_shared_outs
ie
=
ie
+
self
.
n_shared_outs
scan_shared
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
scan_shared
=
inputs
[
b
:
e
]
+
eval_points
[
b
:
e
]
inner_shared
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
inner_shared
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
# NIT_SOT sequences
# NIT_SOT sequences
b
=
e
b
=
e
e
=
e
+
self
.
n_nit_sot
e
=
e
+
self
.
n_nit_sot
scan_nit_sot
=
inputs
[
b
:
e
]
*
2
scan_nit_sot
=
inputs
[
b
:
e
]
*
2
# All other arguments
# All other arguments
scan_other
=
inputs
[
e
:]
+
eval_points
[
e
:]
scan_other
=
inputs
[
e
:]
+
eval_points
[
e
:]
...
@@ -1611,13 +1627,13 @@ class Scan(PureOp):
...
@@ -1611,13 +1627,13 @@ class Scan(PureOp):
e
=
e
+
self
.
n_shared_outs
e
=
e
+
self
.
n_shared_outs
inner_out_shared
=
self_outputs
[
b
:
e
]
+
rop_outs
[
b
:
e
]
inner_out_shared
=
self_outputs
[
b
:
e
]
+
rop_outs
[
b
:
e
]
inner_ins
=
(
inner_seqs
+
inner_ins
=
(
inner_seqs
+
inner_mit_mot
+
inner_mit_mot
+
inner_mit_sot
+
inner_mit_sot
+
inner_sit_sot
+
inner_sit_sot
+
inner_shared
+
inner_shared
+
inner_other
)
inner_other
)
inner_outs
=
(
inner_out_mit_mot
+
inner_outs
=
(
inner_out_mit_mot
+
inner_out_mit_sot
+
inner_out_mit_sot
+
inner_out_sit_sot
+
inner_out_sit_sot
+
inner_out_nit_sot
+
inner_out_nit_sot
+
...
@@ -1625,35 +1641,35 @@ class Scan(PureOp):
...
@@ -1625,35 +1641,35 @@ class Scan(PureOp):
if
self
.
as_while
:
if
self
.
as_while
:
inner_outs
+=
[
self_outputs
[
-
1
]]
inner_outs
+=
[
self_outputs
[
-
1
]]
scan_inputs
=
(
[
inputs
[
0
]]
+
scan_inputs
=
(
[
inputs
[
0
]]
+
scan_seqs
+
scan_seqs
+
scan_mit_mot
+
scan_mit_mot
+
scan_mit_sot
+
scan_mit_sot
+
scan_sit_sot
+
scan_sit_sot
+
scan_shared
+
scan_shared
+
scan_nit_sot
+
scan_nit_sot
+
scan_other
)
scan_other
)
local_op
=
Scan
(
inner_ins
,
inner_outs
,
info
)
local_op
=
Scan
(
inner_ins
,
inner_outs
,
info
)
outputs
=
local_op
(
*
scan_inputs
)
outputs
=
local_op
(
*
scan_inputs
)
if
type
(
outputs
)
not
in
(
list
,
tuple
):
if
type
(
outputs
)
not
in
(
list
,
tuple
):
outputs
=
[
outputs
]
outputs
=
[
outputs
]
# Select only the result of the R_op results
# Select only the result of the R_op results
final_outs
=
[]
final_outs
=
[]
b
=
self
.
n_mit_mot
b
=
self
.
n_mit_mot
e
=
self
.
n_mit_mot
*
2
e
=
self
.
n_mit_mot
*
2
final_outs
+=
outputs
[
b
:
e
]
final_outs
+=
outputs
[
b
:
e
]
b
=
e
+
self
.
n_mit_sot
b
=
e
+
self
.
n_mit_sot
e
=
e
+
self
.
n_mit_sot
*
2
e
=
e
+
self
.
n_mit_sot
*
2
final_outs
+=
outputs
[
b
:
e
]
final_outs
+=
outputs
[
b
:
e
]
b
=
e
+
self
.
n_sit_sot
b
=
e
+
self
.
n_sit_sot
e
=
e
+
self
.
n_sit_sot
*
2
e
=
e
+
self
.
n_sit_sot
*
2
final_outs
+=
outputs
[
b
:
e
]
final_outs
+=
outputs
[
b
:
e
]
b
=
e
+
self
.
n_nit_sot
b
=
e
+
self
.
n_nit_sot
e
=
e
+
self
.
n_nit_sot
*
2
e
=
e
+
self
.
n_nit_sot
*
2
final_outs
+=
outputs
[
b
:
e
]
final_outs
+=
outputs
[
b
:
e
]
b
=
e
+
self
.
n_shared_outs
b
=
e
+
self
.
n_shared_outs
e
=
e
+
self
.
n_shared_outs
*
2
e
=
e
+
self
.
n_shared_outs
*
2
final_outs
+=
outputs
[
b
:
e
]
final_outs
+=
outputs
[
b
:
e
]
return
final_outs
return
final_outs
...
@@ -1664,26 +1680,36 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
...
@@ -1664,26 +1680,36 @@ def profile_printer(fct_name, compile_time, fct_call_time, fct_call,
apply_time
,
apply_cimpl
,
message
,
outputs_size
,
apply_time
,
apply_cimpl
,
message
,
outputs_size
,
other_time
):
other_time
):
# Scan overhead profile
# Scan overhead profile
if
any
([
isinstance
(
node
.
op
,
Scan
)
and
v
>
0
for
(
_
,
node
),
v
in
if
any
([
isinstance
(
node
.
op
,
Scan
)
and
v
>
0
for
(
_
,
node
),
v
in
apply_time
.
items
()]):
apply_time
.
items
()]):
print
print
print
'Scan overhead:'
print
'Scan overhead:'
print
'<Scan op time(s)> <sub scan fct time(s)> <sub scan op time(s)> <sub scan fct time(
%
scan op time)> <sub scan op time(
%
scan op time)> <node>'
print
(
'<Scan op time(s)> <sub scan fct time(s)> <sub scan op '
'time(s)> <sub scan fct time(
%
scan op time)> <sub scan '
'op time(
%
scan op time)> <node>'
)
total_super_scan_time
=
0
total_super_scan_time
=
0
total_scan_fct_time
=
0
total_scan_fct_time
=
0
total_scan_op_time
=
0
total_scan_op_time
=
0
for
(
_
,
node
),
v
in
apply_time
.
items
():
for
(
_
,
node
),
v
in
apply_time
.
items
():
if
isinstance
(
node
.
op
,
Scan
):
if
isinstance
(
node
.
op
,
Scan
):
if
v
>
0
:
if
v
>
0
:
scan_fct_time
=
node
.
op
.
mode_instance
.
fn_time
scan_fct_time
=
node
.
op
.
mode_instance
.
fn_time
scan_op_time
=
node
.
op
.
mode_instance
.
local_time
scan_op_time
=
node
.
op
.
mode_instance
.
local_time
total_super_scan_time
+=
v
total_super_scan_time
+=
v
total_scan_fct_time
+=
scan_fct_time
total_scan_fct_time
+=
scan_fct_time
total_scan_op_time
+=
scan_op_time
total_scan_op_time
+=
scan_op_time
print
'
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
'
%
(
print
'
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
'
%
(
v
,
scan_fct_time
,
scan_op_time
,
scan_fct_time
/
v
*
100
,
v
,
scan_op_time
/
v
*
100
),
node
scan_fct_time
,
scan_op_time
,
scan_fct_time
/
v
*
100
,
scan_op_time
/
v
*
100
),
node
else
:
else
:
print
' The node took 0s, so we can not compute the overhead'
,
node
print
(
' The node took 0s, so we can not '
print
' total
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
'
%
(
'compute the overhead'
),
node
total_super_scan_time
,
total_scan_fct_time
,
total_scan_op_time
,
total_scan_fct_time
/
total_super_scan_time
*
100
,
total_scan_op_time
/
total_super_scan_time
*
100
)
print
' total
%5.1
fs
%5.1
fs
%5.1
fs
%5.1
f
%% %5.1
f
%%
'
%
(
total_super_scan_time
,
total_scan_fct_time
,
total_scan_op_time
,
total_scan_fct_time
/
total_super_scan_time
*
100
,
total_scan_op_time
/
total_super_scan_time
*
100
)
theano/scan_module/scan_opt.py
浏览文件 @
7cb2384c
...
@@ -417,7 +417,8 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -417,7 +417,8 @@ class ScanSaveMem(gof.Optimizer):
# change the number of steps in that case. To do this we set
# change the number of steps in that case. To do this we set
# global_nsteps to None which is seen as a flag that nothing needs
# global_nsteps to None which is seen as a flag that nothing needs
# to be done
# to be done
if
len
(
node
.
outputs
)
<=
c_outs
:
assert
len
(
node
.
outputs
)
>=
c_outs
if
len
(
node
.
outputs
)
==
c_outs
:
global_nsteps
=
{
'real'
:
-
1
,
'sym'
:
[]}
global_nsteps
=
{
'real'
:
-
1
,
'sym'
:
[]}
else
:
else
:
global_nsteps
=
None
global_nsteps
=
None
...
@@ -474,7 +475,7 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -474,7 +475,7 @@ class ScanSaveMem(gof.Optimizer):
break
break
# 2.3.2 extract the begin/end of the first dimension
# 2.3.2 extract the begin/end of the first dimension
if
i
>
op
.
n_mit_mot
:
if
i
>
=
op
.
n_mit_mot
:
try
:
try
:
length
=
shape_of
[
out
][
0
]
length
=
shape_of
[
out
][
0
]
except
KeyError
:
except
KeyError
:
...
@@ -650,7 +651,8 @@ class ScanSaveMem(gof.Optimizer):
...
@@ -650,7 +651,8 @@ class ScanSaveMem(gof.Optimizer):
tmp
=
tensor
.
as_tensor_variable
(
val
)
tmp
=
tensor
.
as_tensor_variable
(
val
)
initl
=
tensor
.
as_tensor_variable
(
init_l
[
i
])
initl
=
tensor
.
as_tensor_variable
(
init_l
[
i
])
tmp
=
tensor
.
maximum
(
tmp
,
initl
)
tmp
=
tensor
.
maximum
(
tmp
,
initl
)
tmp
=
pre_greedy_local_optimizer
(
list_opt_slice
,
tmp
)
tmp
=
pre_greedy_local_optimizer
(
list_opt_slice
,
tmp
)
tmp
=
pre_constant_merge
([
tmp
])[
0
]
tmp
=
pre_constant_merge
([
tmp
])[
0
]
nw_input
=
nw_inputs
[
offset
+
idx
][:
tmp
]
nw_input
=
nw_inputs
[
offset
+
idx
][:
tmp
]
...
...
theano/scan_module/scan_views.py
浏览文件 @
7cb2384c
...
@@ -5,10 +5,10 @@ See scan.py for details on scan
...
@@ -5,10 +5,10 @@ See scan.py for details on scan
"""
"""
__docformat__
=
'restructedtext en'
__docformat__
=
'restructedtext en'
__authors__
=
(
"Razvan Pascanu "
__authors__
=
(
"Razvan Pascanu "
"Frederic Bastien "
"Frederic Bastien "
"James Bergstra "
"James Bergstra "
"Pascal Lamblin "
)
"Pascal Lamblin "
)
__copyright__
=
"(c) 2010, Universite de Montreal"
__copyright__
=
"(c) 2010, Universite de Montreal"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
__contact__
=
"Razvan Pascanu <r.pascanu@gmail>"
...
@@ -27,13 +27,13 @@ _logger = logging.getLogger('theano.scan_module.scan_views')
...
@@ -27,13 +27,13 @@ _logger = logging.getLogger('theano.scan_module.scan_views')
# The ``map`` view of Scan Op.
# The ``map`` view of Scan Op.
def
map
(
fn
def
map
(
fn
,
,
sequences
sequences
,
,
non_sequences
=
None
non_sequences
=
None
,
,
truncate_gradient
=
-
1
truncate_gradient
=-
1
,
,
go_backwards
=
False
go_backwards
=
False
,
,
mode
=
None
mode
=
None
,
,
name
=
None
):
name
=
None
):
"""
"""
Similar behaviour as python's map.
Similar behaviour as python's map.
...
@@ -58,24 +58,24 @@ def map( fn
...
@@ -58,24 +58,24 @@ def map( fn
:param name: See ``scan``.
:param name: See ``scan``.
"""
"""
return
scan
.
scan
(
fn
=
fn
return
scan
.
scan
(
fn
=
fn
,
,
sequences
=
sequences
sequences
=
sequences
,
,
outputs_info
=
[]
outputs_info
=
[],
,
non_sequences
=
non_sequences
non_sequences
=
non_sequences
,
,
truncate_gradient
=
truncate_gradient
truncate_gradient
=
truncate_gradient
,
,
go_backwards
=
go_backwards
go_backwards
=
go_backwards
,
,
mode
=
mode
mode
=
mode
,
,
name
=
name
)
name
=
name
)
# The ``reduce`` view of Scan Op.
# The ``reduce`` view of Scan Op.
def
reduce
(
fn
def
reduce
(
fn
,
,
sequences
sequences
,
,
outputs_info
outputs_info
,
,
non_sequences
=
None
non_sequences
=
None
,
,
go_backwards
=
False
go_backwards
=
False
,
,
mode
=
None
mode
=
None
,
,
name
=
None
):
name
=
None
):
"""
"""
Similar behaviour as python's reduce
Similar behaviour as python's reduce
...
@@ -101,27 +101,27 @@ def reduce( fn
...
@@ -101,27 +101,27 @@ def reduce( fn
:param name: See ``scan``.
:param name: See ``scan``.
"""
"""
rval
=
scan
.
scan
(
fn
=
fn
rval
=
scan
.
scan
(
fn
=
fn
,
,
sequences
=
sequences
sequences
=
sequences
,
,
outputs_info
=
outputs_info
outputs_info
=
outputs_info
,
,
non_sequences
=
non_sequences
non_sequences
=
non_sequences
,
,
go_backwards
=
go_backwards
go_backwards
=
go_backwards
,
,
truncate_gradient
=
-
1
truncate_gradient
=-
1
,
,
mode
=
mode
mode
=
mode
,
,
name
=
name
)
name
=
name
)
if
isinstance
(
rval
[
0
],
(
list
,
tuple
)):
if
isinstance
(
rval
[
0
],
(
list
,
tuple
)):
return
[
x
[
-
1
]
for
x
in
rval
[
0
]],
rval
[
1
]
return
[
x
[
-
1
]
for
x
in
rval
[
0
]],
rval
[
1
]
else
:
else
:
return
rval
[
0
][
-
1
],
rval
[
1
]
return
rval
[
0
][
-
1
],
rval
[
1
]
# The ``foldl`` view of Scan Op.
# The ``foldl`` view of Scan Op.
def
foldl
(
fn
def
foldl
(
fn
,
,
sequences
sequences
,
,
outputs_info
outputs_info
,
,
non_sequences
=
None
non_sequences
=
None
,
,
mode
=
None
mode
=
None
,
,
name
=
None
):
name
=
None
):
"""
"""
Similar behaviour as haskell's foldl
Similar behaviour as haskell's foldl
...
@@ -143,22 +143,22 @@ def foldl( fn
...
@@ -143,22 +143,22 @@ def foldl( fn
:param name: See ``scan``.
:param name: See ``scan``.
"""
"""
return
reduce
(
fn
=
fn
return
reduce
(
fn
=
fn
,
,
sequences
=
sequences
sequences
=
sequences
,
,
outputs_info
=
outputs_info
outputs_info
=
outputs_info
,
,
non_sequences
=
non_sequences
non_sequences
=
non_sequences
,
,
go_backwards
=
False
go_backwards
=
False
,
,
mode
=
mode
mode
=
mode
,
,
name
=
name
)
name
=
name
)
# The ``foldl`` view of Scan Op.
# The ``foldl`` view of Scan Op.
def
foldr
(
fn
def
foldr
(
fn
,
,
sequences
sequences
,
,
outputs_info
outputs_info
,
,
non_sequences
=
None
non_sequences
=
None
,
,
mode
=
None
mode
=
None
,
,
name
=
None
):
name
=
None
):
"""
"""
Similar behaviour as haskell' foldr
Similar behaviour as haskell' foldr
...
@@ -180,10 +180,10 @@ def foldr( fn
...
@@ -180,10 +180,10 @@ def foldr( fn
:param name: See ``scan``.
:param name: See ``scan``.
"""
"""
return
reduce
(
fn
=
fn
return
reduce
(
fn
=
fn
,
,
sequences
=
sequences
sequences
=
sequences
,
,
outputs_info
=
outputs_info
outputs_info
=
outputs_info
,
,
non_sequences
=
non_sequences
non_sequences
=
non_sequences
,
,
go_backwards
=
True
go_backwards
=
True
,
,
mode
=
mode
mode
=
mode
,
,
name
=
name
)
name
=
name
)
theano/scan_module/tests/test_scan.py
浏览文件 @
7cb2384c
...
@@ -2585,6 +2585,26 @@ class T_Scan(unittest.TestCase):
...
@@ -2585,6 +2585,26 @@ class T_Scan(unittest.TestCase):
tf
=
theano
.
function
([
c
,
x
],
dP
)
tf
=
theano
.
function
([
c
,
x
],
dP
)
assert
tf
([
1.0
,
2.0
,
-
3.0
,
4.0
],
2.0
)
==
38
assert
tf
([
1.0
,
2.0
,
-
3.0
,
4.0
],
2.0
)
==
38
def
test_grad_of_grad_of_state
(
self
):
# Example provided Michael Forbes
# This tests ensures that we can compute gradients through cost
# defines in terms of gradients of scan
c
=
theano
.
tensor
.
vector
(
'c'
)
x
=
theano
.
tensor
.
scalar
(
'x'
)
_max_coefficients_supported
=
1000
full_range
=
theano
.
tensor
.
arange
(
_max_coefficients_supported
)
components
,
updates
=
theano
.
scan
(
fn
=
lambda
coeff
,
power
,
free_var
:
coeff
*
(
free_var
**
power
),
outputs_info
=
None
,
sequences
=
[
c
,
full_range
],
non_sequences
=
x
)
P
=
components
.
sum
()
dP
=
theano
.
tensor
.
grad
(
P
,
x
)
.
sum
()
ddP
=
theano
.
tensor
.
grad
(
dP
,
x
)
tf
=
theano
.
function
([
c
,
x
],
ddP
)
assert
tf
([
1.0
,
2.0
,
-
3.0
,
4.0
],
2.0
)
==
42
def
test_return_steps
(
self
):
def
test_return_steps
(
self
):
rng
=
numpy
.
random
.
RandomState
(
utt
.
fetch_seed
())
rng
=
numpy
.
random
.
RandomState
(
utt
.
fetch_seed
())
vW_in2
=
asarrayX
(
rng
.
uniform
(
size
=
(
2
,),
low
=
-
5.
,
high
=
5.
))
vW_in2
=
asarrayX
(
rng
.
uniform
(
size
=
(
2
,),
low
=
-
5.
,
high
=
5.
))
...
@@ -2705,9 +2725,9 @@ class T_Scan(unittest.TestCase):
...
@@ -2705,9 +2725,9 @@ class T_Scan(unittest.TestCase):
grad_fn
=
theano
.
function
([
xinit
,
w
],
[
gx
,
gw
],
grad_fn
=
theano
.
function
([
xinit
,
w
],
[
gx
,
gw
],
allow_input_downcast
=
True
)
allow_input_downcast
=
True
)
rng
=
numpy
.
random
.
RandomState
(
utt
.
fetch_seed
())
rng
=
numpy
.
random
.
RandomState
(
utt
.
fetch_seed
())
v_x
=
numpy
.
array
(
rng
.
uniform
(
size
=
(
5
,
2
,
3
),
low
=-
2.
,
high
=
2
.
),
v_x
=
numpy
.
array
(
rng
.
uniform
(
size
=
(
5
,
2
,
3
),
low
=-
3.
,
high
=
3
.
),
dtype
=
theano
.
config
.
floatX
)
dtype
=
theano
.
config
.
floatX
)
v_w
=
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,
2
)),
dtype
=
theano
.
config
.
floatX
)
v_w
=
numpy
.
array
(
rng
.
uniform
(
size
=
(
2
,
2
)
,
low
=-
3.
,
high
=
3.
),
dtype
=
theano
.
config
.
floatX
)
analytic_grad
=
grad_fn
(
v_x
,
v_w
)
analytic_grad
=
grad_fn
(
v_x
,
v_w
)
num_grad
=
multiple_outputs_numeric_grad
(
cost_fn
,
num_grad
=
multiple_outputs_numeric_grad
(
cost_fn
,
[
v_x
,
v_w
])
[
v_x
,
v_w
])
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论