Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
dc5617c1
提交
dc5617c1
authored
5月 09, 2013
作者:
lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1341 from pascanur/recent_scan_bugs
Recent scan bugs
上级
cc574b89
1aa41c7d
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
117 行增加
和
51 行删除
+117
-51
scan_op.py
theano/scan_module/scan_op.py
+77
-46
scan_opt.py
theano/scan_module/scan_opt.py
+1
-1
scan_utils.py
theano/scan_module/scan_utils.py
+6
-4
test_scan.py
theano/scan_module/tests/test_scan.py
+33
-0
没有找到文件。
theano/scan_module/scan_op.py
浏览文件 @
dc5617c1
...
@@ -192,19 +192,18 @@ class Scan(PureOp):
...
@@ -192,19 +192,18 @@ class Scan(PureOp):
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
"""
"""
Conventions:
Conventions:
inner_
? - the variable corresponding to ?
in the inner function
inner_
X - the variable corresponding to X
in the inner function
of scan (the lambda function executed at every time
of scan (the lambda function executed at every time
step)
step)
outer_
? - the variable corresponding to ?
in the outer graph,
outer_
X - the variable corresponding to X
in the outer graph,
i.e. the main graph (where the scan op lives)
i.e. the main graph (where the scan op lives)
inner_
?_out - the variable representing the new value of ?
after
inner_
X_out - the variable representing the new value of X
after
executing one step of scan (i.e. outputs given by
executing one step of scan (i.e. outputs given by
the inner function)
the inner function)
"""
"""
assert
numpy
.
all
(
isinstance
(
i
,
gof
.
Variable
)
for
i
in
inputs
)
assert
numpy
.
all
(
isinstance
(
i
,
gof
.
Variable
)
for
i
in
inputs
)
# Check that the number of inputs to the Scan node corresponds to
# Check that the number of inputs to the Scan node corresponds to
# the number of inputs of the inner function of scan
# the number of inputs of the inner function of scan
n_outer_ins
=
len
(
inputs
)
-
len
(
self
.
outer_nitsot
(
inputs
))
-
1
n_outer_ins
=
len
(
inputs
)
-
len
(
self
.
outer_nitsot
(
inputs
))
-
1
n_inner_ins
=
(
len
(
self
.
inner_seqs
(
self
.
inputs
))
+
n_inner_ins
=
(
len
(
self
.
inner_seqs
(
self
.
inputs
))
+
len
(
self
.
mitmot_taps
())
+
len
(
self
.
mitmot_taps
())
+
...
@@ -215,7 +214,7 @@ class Scan(PureOp):
...
@@ -215,7 +214,7 @@ class Scan(PureOp):
assert
n_outer_ins
==
n_inner_ins
,
\
assert
n_outer_ins
==
n_inner_ins
,
\
(
"The number of inputs given to the inner function of scan"
(
"The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan."
)
" does not match the number of inputs given to scan."
)
new_inputs
=
[
inputs
[
0
]]
# assert dtype is consistent
# assert dtype is consistent
err_msg1
=
(
'When compiling the inner function of scan the '
err_msg1
=
(
'When compiling the inner function of scan the '
'following error has been encountered: The '
'following error has been encountered: The '
...
@@ -235,42 +234,35 @@ class Scan(PureOp):
...
@@ -235,42 +234,35 @@ class Scan(PureOp):
'and
%
d dimension(s). This could happen if the inner '
'and
%
d dimension(s). This could happen if the inner '
'graph of scan results in an upcast or downcast. '
'graph of scan results in an upcast or downcast. '
'Please make sure that you use dtypes consistently'
)
'Please make sure that you use dtypes consistently'
)
# TODO make the assert exact
# TODO assert the type(dtype, nbdim of self.inputs and
# inputs correspond)
#assert len(inputs) >= len(self.inputs)
#if self.info['as_while']:
# assert len(inputs) == len(self.inputs) + 2 + \
# self.info["n_nit_sot"]
#else:
# assert len(inputs) == len(self.inputs) + 1 + \
# self.info["n_nit_sot"]
# Flags that indicate which inputs are vectors
self
.
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
def
format
(
var
,
as_var
):
inputs
[
1
:
1
+
self
.
n_seqs
]]
""" This functions ensures that ``out`` has the same dtype as
self
.
vector_outs
=
[
arg
.
ndim
==
1
for
arg
in
``inp`` as well as calling filter_variable to make sure they are
inputs
[
1
+
self
.
n_seqs
:
(
1
+
self
.
n_seqs
+
both TensorType or CudaNdarrayType. It internally deals with the
self
.
n_outs
)]]
corner case where inp.ndim + 1 = out.ndim
self
.
vector_outs
+=
[
False
]
*
self
.
n_nit_sot
"""
if
not
hasattr
(
var
,
'dtype'
):
return
var
rval
=
var
if
rval
.
type
.
dtype
!=
as_var
.
type
.
dtype
:
rval
=
rval
.
astype
(
as_var
.
type
.
dtype
)
if
rval
.
ndim
==
as_var
.
ndim
:
rval
=
as_var
.
type
.
filter_variable
(
rval
)
else
:
tmp
=
as_var
.
type
.
__class__
(
broadcastable
=
tuple
(
var
.
broadcastable
[:
1
])
+
\
tuple
(
as_var
.
broadcastable
),
dtype
=
as_var
.
dtype
)
rval
=
tmp
.
filter_variable
(
rval
)
return
rval
# Check if input sequences and variables representing a slice of
# Check if input sequences and variables representing a slice of
# them have the same dtype
# them have the same dtype
argoffset
=
0
argoffset
=
0
for
idx
,
(
inner_seq
,
outer_seq
)
in
enumerate
(
for
inner_seq
,
outer_seq
in
zip
(
self
.
inner_seqs
(
self
.
inputs
),
zip
(
self
.
inner_seqs
(
self
.
inputs
),
self
.
outer_seqs
(
inputs
)):
self
.
outer_seqs
(
inputs
))):
new_inputs
.
append
(
format
(
outer_seq
,
as_var
=
inner_seq
))
if
inner_seq
.
type
.
dtype
!=
outer_seq
[
0
]
.
type
.
dtype
:
assert
isinstance
(
idx
,
int
)
raise
ValueError
(
err_msg1
%
(
'sequence'
,
str
(
outer_seq
),
idx
,
outer_seq
.
type
.
dtype
,
outer_seq
.
ndim
,
str
(
inner_seq
),
inner_seq
.
type
.
dtype
,
inner_seq
.
ndim
))
argoffset
+=
len
(
self
.
outer_seqs
(
inputs
))
argoffset
+=
len
(
self
.
outer_seqs
(
inputs
))
# Check that this 3 things have the same dtype for mit_mot:
# Check that this 3 things have the same dtype for mit_mot:
# - initial state of the output
# - initial state of the output
...
@@ -280,10 +272,12 @@ class Scan(PureOp):
...
@@ -280,10 +272,12 @@ class Scan(PureOp):
opos
=
0
opos
=
0
inner_mitmot
=
self
.
inner_mitmot
(
self
.
inputs
)
inner_mitmot
=
self
.
inner_mitmot
(
self
.
inputs
)
inner_mitmot_outs
=
self
.
inner_mitmot_outs
(
self
.
outputs
)
inner_mitmot_outs
=
self
.
inner_mitmot_outs
(
self
.
outputs
)
for
idx
,
(
itaps
,
otaps
,
outer_mitmot
)
in
enumerate
(
for
idx
,
(
itaps
,
otaps
,
_
outer_mitmot
)
in
enumerate
(
zip
(
self
.
mitmot_taps
(),
zip
(
self
.
mitmot_taps
(),
self
.
mitmot_out_taps
(),
self
.
mitmot_out_taps
(),
self
.
outer_mitmot
(
inputs
))):
self
.
outer_mitmot
(
inputs
))):
outer_mitmot
=
format
(
_outer_mitmot
,
as_var
=
inner_mitmot
[
ipos
])
new_inputs
.
append
(
outer_mitmot
)
for
k
in
xrange
(
len
(
itaps
)):
for
k
in
xrange
(
len
(
itaps
)):
if
(
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
!=
if
(
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
!=
outer_mitmot
.
type
.
dtype
or
outer_mitmot
.
type
.
dtype
or
...
@@ -316,10 +310,13 @@ class Scan(PureOp):
...
@@ -316,10 +310,13 @@ class Scan(PureOp):
# Same checks as above but for outputs of type mit_sot
# Same checks as above but for outputs of type mit_sot
ipos
=
0
ipos
=
0
inner_mitsots
=
self
.
inner_mitsot
(
self
.
inputs
)
inner_mitsots
=
self
.
inner_mitsot
(
self
.
inputs
)
for
idx
,
(
itaps
,
outer_mitsot
,
inner_mitsot_out
)
in
enumerate
(
for
idx
,
(
itaps
,
_
outer_mitsot
,
inner_mitsot_out
)
in
enumerate
(
zip
(
self
.
mitsot_taps
(),
zip
(
self
.
mitsot_taps
(),
self
.
outer_mitsot
(
inputs
),
self
.
outer_mitsot
(
inputs
),
self
.
inner_mitsot_outs
(
self
.
outputs
))):
self
.
inner_mitsot_outs
(
self
.
outputs
))):
outer_mitsot
=
format
(
_outer_mitsot
,
as_var
=
inner_mitsots
[
ipos
])
new_inputs
.
append
(
outer_mitsot
)
for
k
in
xrange
(
len
(
itaps
)):
for
k
in
xrange
(
len
(
itaps
)):
if
(
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
!=
\
if
(
inner_mitsots
[
ipos
+
k
]
.
type
.
dtype
!=
\
outer_mitsot
.
type
.
dtype
or
outer_mitsot
.
type
.
dtype
or
...
@@ -346,12 +343,13 @@ class Scan(PureOp):
...
@@ -346,12 +343,13 @@ class Scan(PureOp):
argoffset
+=
len
(
self
.
outer_mitsot
(
inputs
))
argoffset
+=
len
(
self
.
outer_mitsot
(
inputs
))
# Same checks as above but for outputs of type sit_sot
# Same checks as above but for outputs of type sit_sot
for
idx
,
(
inner_sitsot
,
outer_sitsot
,
inner_sitsot_out
)
in
enumerate
(
for
idx
,
(
inner_sitsot
,
_
outer_sitsot
,
inner_sitsot_out
)
in
enumerate
(
zip
(
self
.
inner_sitsot
(
self
.
inputs
),
zip
(
self
.
inner_sitsot
(
self
.
inputs
),
self
.
outer_sitsot
(
inputs
),
self
.
outer_sitsot
(
inputs
),
self
.
inner_sitsot_outs
(
self
.
outputs
))):
self
.
inner_sitsot_outs
(
self
.
outputs
))):
if
(
inner_sitsot
.
type
.
dtype
!=
outer_sitsot
.
type
.
dtype
or
outer_sitsot
=
format
(
_outer_sitsot
,
as_var
=
inner_sitsot
)
inner_sitsot
.
ndim
!=
outer_sitsot
.
ndim
-
1
):
new_inputs
.
append
(
outer_sitsot
)
if
(
inner_sitsot
.
ndim
!=
outer_sitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
' in scan nomenclature) '
,
str
(
outer_sitsot
),
str
(
outer_sitsot
),
...
@@ -374,10 +372,12 @@ class Scan(PureOp):
...
@@ -374,10 +372,12 @@ class Scan(PureOp):
argoffset
+=
len
(
self
.
outer_sitsot
(
inputs
))
argoffset
+=
len
(
self
.
outer_sitsot
(
inputs
))
# Check that the shared variable and their update rule have the same
# Check that the shared variable and their update rule have the same
# dtype. Maybe even same type ?!
# dtype. Maybe even same type ?!
for
idx
,
(
inner_shared
,
inner_shared_out
,
outer_shared
)
in
enumerate
(
for
idx
,
(
inner_shared
,
inner_shared_out
,
_
outer_shared
)
in
enumerate
(
zip
(
self
.
inner_shared
(
self
.
inputs
),
zip
(
self
.
inner_shared
(
self
.
inputs
),
self
.
inner_shared_outs
(
self
.
outputs
),
self
.
inner_shared_outs
(
self
.
outputs
),
self
.
outer_shared
(
inputs
))):
self
.
outer_shared
(
inputs
))):
outer_shared
=
format
(
_outer_shared
,
as_var
=
inner_shared
)
new_inputs
.
append
(
outer_shared
)
if
(
hasattr
(
outer_shared
,
'dtype'
)
and
if
(
hasattr
(
outer_shared
,
'dtype'
)
and
(
outer_shared
.
dtype
!=
inner_shared_out
.
dtype
or
(
outer_shared
.
dtype
!=
inner_shared_out
.
dtype
or
outer_shared
.
ndim
!=
inner_shared_out
.
ndim
)):
outer_shared
.
ndim
!=
inner_shared_out
.
ndim
)):
...
@@ -400,13 +400,25 @@ class Scan(PureOp):
...
@@ -400,13 +400,25 @@ class Scan(PureOp):
str
(
inner_shared
),
str
(
inner_shared
),
inner_shared
.
dtype
,
inner_shared
.
dtype
,
inner_shared
.
ndim
))
inner_shared
.
ndim
))
for
inner_nonseq
,
outer_nonseq
in
zip
(
# We do not need to call `format` on outer_nisot arguments.
# outer_nitsot stands for no input tap single output tap. This means
# these are states that do not feed anything back in the recurrent
# computation, and hence they do not have an initial state. The scan
# node however receives an input for each such argument, the input
# in this case is just a int saying how many steps of this output we
# need to store. This input does not have the same dtype, nor is it the same
# type of tensor as the output, it is always a scalar int.
new_inputs
+=
self
.
outer_nitsot
(
inputs
)
for
inner_nonseq
,
_outer_nonseq
in
zip
(
self
.
inner_non_seqs
(
self
.
inputs
),
self
.
inner_non_seqs
(
self
.
inputs
),
self
.
outer_non_seqs
(
inputs
)):
self
.
outer_non_seqs
(
inputs
)):
outer_nonseq
=
format
(
_outer_nonseq
,
as_var
=
inner_nonseq
)
new_inputs
.
append
(
outer_nonseq
)
if
inner_nonseq
.
type
!=
outer_nonseq
.
type
:
if
inner_nonseq
.
type
!=
outer_nonseq
.
type
:
raise
ValueError
((
'Argument
%
s given to scan node does not'
raise
ValueError
((
'Argument
%
s given to scan node does not'
' match its correspondance
%
s'
)
%
' match its correspondance
%
s'
)
%
(
str
(
outer_nonseq
),
str
(
inner_nonseq
)))
(
str
(
outer_nonseq
),
str
(
inner_nonseq
)))
for
outer_nitsot
in
self
.
outer_nitsot
(
inputs
):
for
outer_nitsot
in
self
.
outer_nitsot
(
inputs
):
# For every nit_sot input we get as input a int/uint that
# For every nit_sot input we get as input a int/uint that
# depicts the size in memory for that sequence. This feature is
# depicts the size in memory for that sequence. This feature is
...
@@ -415,9 +427,16 @@ class Scan(PureOp):
...
@@ -415,9 +427,16 @@ class Scan(PureOp):
outer_nitsot
.
ndim
!=
0
):
outer_nitsot
.
ndim
!=
0
):
raise
ValueError
(
'For output
%
s you need to provide a '
raise
ValueError
(
'For output
%
s you need to provide a '
'scalar int !'
,
str
(
outer_nitsot
))
'scalar int !'
,
str
(
outer_nitsot
))
assert
len
(
new_inputs
)
==
len
(
inputs
)
self
.
vector_seqs
=
[
seq
.
ndim
==
1
for
seq
in
new_inputs
[
1
:
1
+
self
.
n_seqs
]]
self
.
vector_outs
=
[
arg
.
ndim
==
1
for
arg
in
new_inputs
[
1
+
self
.
n_seqs
:
(
1
+
self
.
n_seqs
+
self
.
n_outs
)]]
self
.
vector_outs
+=
[
False
]
*
self
.
n_nit_sot
apply_node
=
Apply
(
self
,
apply_node
=
Apply
(
self
,
inputs
,
new_
inputs
,
[
t
()
for
t
in
self
.
output_types
])
[
t
()
for
t
in
self
.
output_types
])
return
apply_node
return
apply_node
...
@@ -1199,6 +1218,9 @@ class Scan(PureOp):
...
@@ -1199,6 +1218,9 @@ class Scan(PureOp):
return
scan_outs
return
scan_outs
def
get_input_pos
(
self
,
output_index
):
def
get_input_pos
(
self
,
output_index
):
""" For a given ``output_index``, an index in the inner outputs of
scan, find a corresponding first index in the inner inputs of scan
"""
ipos
=
self
.
n_seqs
ipos
=
self
.
n_seqs
opos
=
output_index
opos
=
output_index
for
otaps
,
itaps
in
zip
(
self
.
mitmot_out_taps
(),
self
.
mitmot_taps
()):
for
otaps
,
itaps
in
zip
(
self
.
mitmot_out_taps
(),
self
.
mitmot_taps
()):
...
@@ -1219,6 +1241,9 @@ class Scan(PureOp):
...
@@ -1219,6 +1241,9 @@ class Scan(PureOp):
return
-
1
return
-
1
def
get_output_pos
(
self
,
input_index
):
def
get_output_pos
(
self
,
input_index
):
""" For a given ``input_index``, an index in the inner inputs of
scan, find a corresponding first index in the inner outputs of scan
"""
ipos
=
input_index
ipos
=
input_index
opos
=
0
opos
=
0
for
otaps
,
itaps
in
zip
(
self
.
mitmot_out_taps
(),
self
.
mitmot_taps
()):
for
otaps
,
itaps
in
zip
(
self
.
mitmot_out_taps
(),
self
.
mitmot_taps
()):
...
@@ -1239,6 +1264,9 @@ class Scan(PureOp):
...
@@ -1239,6 +1264,9 @@ class Scan(PureOp):
return
-
1
return
-
1
def
get_output_slice_idx
(
self
,
output_index
):
def
get_output_slice_idx
(
self
,
output_index
):
""" For an ``output_index``, an index in the outter ouputs of scan,
find a corresponding index in the inner outputs of scan.
"""
ipos
=
0
ipos
=
0
opos
=
output_index
opos
=
output_index
for
otaps
in
zip
(
self
.
mitmot_out_taps
()):
for
otaps
in
zip
(
self
.
mitmot_out_taps
()):
...
@@ -1339,6 +1367,7 @@ class Scan(PureOp):
...
@@ -1339,6 +1367,7 @@ class Scan(PureOp):
# Applying Floyd-Warshall to find all paths connecting inputs to
# Applying Floyd-Warshall to find all paths connecting inputs to
# outputs. Note that if `x` is an input to `y_t` and `y_tm1` is an
# outputs. Note that if `x` is an input to `y_t` and `y_tm1` is an
# input to `z_t` then `x` is an input to `z_t`.
# input to `z_t` then `x` is an input to `z_t`.
n_outs
=
len
(
node
.
outputs
)
n_outs
=
len
(
node
.
outputs
)
for
steps
in
xrange
(
n_outs
):
for
steps
in
xrange
(
n_outs
):
for
iidx
in
xrange
(
n_outs
):
for
iidx
in
xrange
(
n_outs
):
...
@@ -1429,11 +1458,13 @@ class Scan(PureOp):
...
@@ -1429,11 +1458,13 @@ class Scan(PureOp):
odx
=
get_out_idx
(
self_outputs
.
index
(
y
))
odx
=
get_out_idx
(
self_outputs
.
index
(
y
))
wrt
=
[
x
for
x
in
theano
.
gof
.
graph
.
inputs
([
y
])
wrt
=
[
x
for
x
in
theano
.
gof
.
graph
.
inputs
([
y
])
if
(
x
in
diff_inputs
)
and
if
(
x
in
diff_inputs
)
and
connection_pattern
[
get_inp_idx
(
self_inputs
.
index
(
x
))][
odx
]]
(
connection_pattern
[
get_inp_idx
(
self_inputs
.
index
(
x
))][
odx
])]
grads
=
gradient
.
grad
(
grads
=
gradient
.
grad
(
cost
=
None
,
cost
=
None
,
known_grads
=
{
y
:
g_y
},
known_grads
=
{
y
:
g_y
},
wrt
=
wrt
,
consider_constant
=
wrt
,
wrt
=
wrt
,
consider_constant
=
wrt
,
disconnected_inputs
=
'ignore'
,
disconnected_inputs
=
'ignore'
,
return_disconnected
=
'None'
)
return_disconnected
=
'None'
)
gmp
=
dict
(
zip
(
wrt
,
grads
))
gmp
=
dict
(
zip
(
wrt
,
grads
))
...
...
theano/scan_module/scan_opt.py
浏览文件 @
dc5617c1
...
@@ -1159,7 +1159,7 @@ class ScanMerge(gof.Optimizer):
...
@@ -1159,7 +1159,7 @@ class ScanMerge(gof.Optimizer):
Questionable, we should also consider profile ?
Questionable, we should also consider profile ?
"""
"""
rep
=
set_nodes
[
0
]
rep
=
set_nodes
[
0
]
if
not
rep
.
op
.
as_while
and
node
.
op
.
as_while
:
if
rep
.
op
.
as_while
!=
node
.
op
.
as_while
:
return
False
return
False
nsteps
=
node
.
inputs
[
0
]
nsteps
=
node
.
inputs
[
0
]
...
...
theano/scan_module/scan_utils.py
浏览文件 @
dc5617c1
...
@@ -48,7 +48,10 @@ def safe_new(x, tag='', dtype=None):
...
@@ -48,7 +48,10 @@ def safe_new(x, tag='', dtype=None):
nw_name
=
None
nw_name
=
None
if
isinstance
(
x
,
theano
.
Constant
):
if
isinstance
(
x
,
theano
.
Constant
):
if
dtype
and
x
.
dtype
!=
dtype
:
if
dtype
and
x
.
dtype
!=
dtype
:
return
x
.
clone
()
.
astype
(
dtype
)
casted_x
=
x
.
astype
(
dtype
)
nwx
=
x
.
__class__
(
casted_x
.
type
,
x
.
data
,
x
.
name
)
nwx
.
tag
=
copy
(
x
.
tag
)
return
nwx
else
:
else
:
return
x
.
clone
()
return
x
.
clone
()
# Note, as_tensor_variable will convert the Scalar into a
# Note, as_tensor_variable will convert the Scalar into a
...
@@ -70,6 +73,8 @@ def safe_new(x, tag='', dtype=None):
...
@@ -70,6 +73,8 @@ def safe_new(x, tag='', dtype=None):
# ndarrays
# ndarrays
pass
pass
nw_x
=
x
.
type
()
nw_x
=
x
.
type
()
if
dtype
and
nw_x
.
dtype
!=
dtype
:
nw_x
=
nw_x
.
astype
(
dtype
)
.
type
()
nw_x
.
name
=
nw_name
nw_x
.
name
=
nw_name
# Preserve test values so that the 'compute_test_value' option can be used.
# Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions
# The test value is deep-copied to ensure there can be no interactions
...
@@ -82,9 +87,6 @@ def safe_new(x, tag='', dtype=None):
...
@@ -82,9 +87,6 @@ def safe_new(x, tag='', dtype=None):
# This means `x` has no test value.
# This means `x` has no test value.
pass
pass
if
dtype
and
nw_x
.
dtype
!=
dtype
:
nw_x
=
nw_x
.
astype
(
dtype
)
return
nw_x
return
nw_x
...
...
theano/scan_module/tests/test_scan.py
浏览文件 @
dc5617c1
...
@@ -3452,6 +3452,39 @@ class T_Scan(unittest.TestCase):
...
@@ -3452,6 +3452,39 @@ class T_Scan(unittest.TestCase):
assert
numpy
.
allclose
(
test
(
x
,
tensor
.
sum
((
x
+
1
)
**
2
),
mention_y
=
True
),
assert
numpy
.
allclose
(
test
(
x
,
tensor
.
sum
((
x
+
1
)
**
2
),
mention_y
=
True
),
1.21000003815
)
1.21000003815
)
def
test_grad_find_input
(
self
):
w
=
theano
.
shared
(
numpy
.
array
(
0
,
dtype
=
'float32'
),
name
=
'w'
)
init
=
tensor
.
fscalar
(
'init'
)
out
,
_
=
theano
.
scan
(
fn
=
lambda
prev
:
w
,
outputs_info
=
init
,
n_steps
=
2
,
)
tensor
.
grad
(
out
[
-
1
],
w
)
def
test_scan_merge_nodes
(
self
):
inps
=
tensor
.
vector
()
state
=
tensor
.
scalar
()
y1
,
_
=
theano
.
scan
(
lambda
x
,
y
:
x
*
y
,
sequences
=
inps
,
outputs_info
=
state
,
n_steps
=
5
)
y2
,
_
=
theano
.
scan
(
lambda
x
,
y
:
(
x
+
y
,
theano
.
scan_module
.
until
(
x
>
0
)),
sequences
=
inps
,
outputs_info
=
state
,
n_steps
=
5
)
scan_node1
=
y1
.
owner
.
inputs
[
0
]
.
owner
assert
isinstance
(
scan_node1
.
op
,
theano
.
scan_module
.
scan_op
.
Scan
)
scan_node2
=
y2
.
owner
.
inputs
[
0
]
.
owner
assert
isinstance
(
scan_node2
.
op
,
theano
.
scan_module
.
scan_op
.
Scan
)
opt_obj
=
theano
.
scan_module
.
scan_opt
.
ScanMerge
()
# Test the method belongs_to of this class. Specifically see if it
# detects the two scan_nodes as not being similar
assert
not
opt_obj
.
belongs_to_set
(
scan_node1
,
[
scan_node2
])
assert
not
opt_obj
.
belongs_to_set
(
scan_node2
,
[
scan_node1
])
def
test_speed
():
def
test_speed
():
#
#
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论