Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
595ec4b2
提交
595ec4b2
authored
11月 02, 2012
作者:
lamblin
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1009 from pascanur/scan_grad_dtype_issue
Scan grad dtype issue
上级
13839a98
096c01f0
隐藏空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
692 行增加
和
383 行删除
+692
-383
ops.py
theano/compile/ops.py
+4
-0
type.py
theano/sandbox/cuda/type.py
+1
-0
scan.py
theano/scan_module/scan.py
+39
-7
scan_op.py
theano/scan_module/scan_op.py
+497
-337
scan_opt.py
theano/scan_module/scan_opt.py
+9
-4
scan_utils.py
theano/scan_module/scan_utils.py
+42
-3
test_scan.py
theano/scan_module/tests/test_scan.py
+90
-29
basic.py
theano/tensor/basic.py
+1
-0
elemwise.py
theano/tensor/elemwise.py
+6
-2
raw_random.py
theano/tensor/raw_random.py
+3
-1
没有找到文件。
theano/compile/ops.py
浏览文件 @
595ec4b2
...
...
@@ -179,3 +179,7 @@ class DeepCopyOp(gof.Op):
deep_copy_op
=
DeepCopyOp
()
# List of Theano Types that one can add an extra dimension and for which
# Scan can deal with.
expandable_types
=
()
theano/sandbox/cuda/type.py
浏览文件 @
595ec4b2
...
...
@@ -411,6 +411,7 @@ class CudaNdarrayType(Type):
def
c_compile_args
(
self
):
return
[]
theano
.
compile
.
ops
.
expandable_types
+=
(
CudaNdarrayType
,)
# Register C code for ViewOp on CudaNdarrayType
theano
.
compile
.
register_view_op_c_code
(
...
...
theano/scan_module/scan.py
浏览文件 @
595ec4b2
...
...
@@ -53,6 +53,7 @@ from theano.tensor import opt
from
theano
import
tensor
from
theano
import
config
from
theano.updates
import
Updates
from
theano.compile
import
ops
import
scan_op
...
...
@@ -843,17 +844,38 @@ def scan(fn,
shared_scan_inputs
=
[]
shared_inner_inputs
=
[]
shared_inner_outputs
=
[]
sit_sot_shared
=
[]
for
input
in
dummy_f
.
maker
.
expanded_inputs
:
if
isinstance
(
input
.
variable
,
SharedVariable
)
and
input
.
update
:
new_var
=
safe_new
(
input
.
variable
)
if
getattr
(
input
.
variable
,
'name'
,
None
)
is
not
None
:
new_var
.
name
=
input
.
variable
.
name
+
'_copy'
shared_inner_inputs
.
append
(
new_var
)
shared_scan_inputs
.
append
(
input
.
variable
)
shared_inner_outputs
.
append
(
input
.
update
)
givens
[
input
.
variable
]
=
new_var
n_shared_outs
+=
1
if
isinstance
(
new_var
.
type
,
ops
.
expandable_types
):
sit_sot_inner_inputs
.
append
(
new_var
)
sit_sot_scan_inputs
.
append
(
scan_utils
.
expand
(
tensor
.
unbroadcast
(
tensor
.
shape_padleft
(
input
.
variable
),
0
),
actual_n_steps
))
sit_sot_inner_outputs
.
append
(
input
.
update
)
# Not that pos is not a negative index. The sign of pos is used
# as a flag to indicate if this output should be part of the
# update rules or part of the standard outputs of scan.
# If `pos` is positive than it corresponds to the standard
# outputs of scan and it refers to output of index `pos`. If `pos`
# is negative that it corresponds to update rules of scan and it
# refers to update rule of index -1 - `pos`.
sit_sot_rightOrder
.
append
(
-
1
-
len
(
sit_sot_shared
))
sit_sot_shared
.
append
(
input
.
variable
)
givens
[
input
.
variable
]
=
new_var
else
:
shared_inner_inputs
.
append
(
new_var
)
shared_scan_inputs
.
append
(
input
.
variable
)
shared_inner_outputs
.
append
(
input
.
update
)
givens
[
input
.
variable
]
=
new_var
n_shared_outs
+=
1
n_sit_sot
=
len
(
sit_sot_inner_inputs
)
## Step 5.4 Outputs with no taps used in the input
n_nit_sot
=
0
nit_sot_inner_outputs
=
[]
...
...
@@ -1041,10 +1063,20 @@ def scan(fn,
nit_sot_rightOrder
)
scan_out_list
=
[
None
]
*
len
(
rightOrder
)
for
idx
,
pos
in
enumerate
(
rightOrder
):
scan_out_list
[
pos
]
=
_scan_out_list
[
idx
]
if
pos
>=
0
:
scan_out_list
[
pos
]
=
_scan_out_list
[
idx
]
else
:
# Not that pos is not a negative index. The sign of pos is used
# as a flag to indicate if this output should be part of the
# update rules or part of the standard outputs of scan.
# If `pos` is positive than it corresponds to the standard
# outputs of scan and it refers to output of index `pos`. If `pos`
# is negative that it corresponds to update rules of scan and it
# refers to update rule of index -1 - `pos`.
update_map
[
sit_sot_shared
[
abs
(
pos
)
-
1
]]
=
_scan_out_list
[
idx
][
-
1
]
scan_out_list
=
[
x
for
x
in
scan_out_list
if
x
is
not
None
]
if
len
(
scan_out_list
)
==
1
:
scan_out_list
=
scan_out_list
[
0
]
elif
len
(
scan_out_list
)
==
0
:
scan_out_list
=
None
return
(
scan_out_list
,
update_map
)
theano/scan_module/scan_op.py
浏览文件 @
595ec4b2
...
...
@@ -34,7 +34,7 @@ from theano.gradient import DisconnectedType
from
theano.compile.profiling
import
ScanProfileStats
import
scan_utils
from
scan_utils
import
safe_new
from
scan_utils
import
safe_new
,
forced_replace
# Logging function for sending warning or info
_logger
=
logging
.
getLogger
(
'theano.scan_module.scan_op'
)
...
...
@@ -259,7 +259,7 @@ class Scan(PureOp):
for
idx
,
(
inner_seq
,
outer_seq
)
in
enumerate
(
zip
(
self
.
inner_seqs
(
self
.
inputs
),
self
.
outer_seqs
(
inputs
))):
if
inner_seq
.
type
.
dtype
!=
outer_seq
[
idx
]
.
type
.
dtype
:
if
inner_seq
.
type
.
dtype
!=
outer_seq
[
0
]
.
type
.
dtype
:
assert
isinstance
(
idx
,
int
)
raise
ValueError
(
err_msg1
%
(
'sequence'
,
...
...
@@ -292,8 +292,11 @@ class Scan(PureOp):
str
(
outer_mitmot
),
argoffset
+
idx
,
outer_mitmot
.
type
.
dtype
,
outer_mitmot
.
type
.
ndim
,
str
(
inner_mitmot
[
ipos
+
k
]),
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
))
inner_mitmot
[
ipos
+
k
]
.
type
.
dtype
,
inner_mitmot
[
ipos
+
k
]
.
type
.
ndim
))
ipos
+=
len
(
itaps
)
for
k
in
xrange
(
len
(
otaps
)):
if
(
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
!=
\
...
...
@@ -304,7 +307,9 @@ class Scan(PureOp):
(
str
(
outer_mitmot
),
argoffset
+
idx
,
outer_mitmot
.
type
.
dtype
,
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
))
outer_mitmot
.
ndim
,
inner_mitmot_outs
[
opos
+
k
]
.
type
.
dtype
,
inner_mitmot_outs
[
opos
+
k
]
.
ndim
))
opos
+=
len
(
otaps
)
argoffset
+=
len
(
self
.
outer_mitmot
(
inputs
))
# Same checks as above but for outputs of type mit_sot
...
...
@@ -329,14 +334,14 @@ class Scan(PureOp):
inner_mitsots
[
ipos
+
k
]
.
type
.
ndim
))
ipos
+=
len
(
itaps
)
if
(
inner_mitsot_out
.
type
.
dtype
!=
outer_mitsot
.
type
.
dtype
or
inner_mitsot_out
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg2
%
(
str
(
outer_mitsot
),
argoffset
+
idx
,
outer_mitsot
.
type
.
dtype
,
outer_mitsot
.
type
.
ndim
,
inner_mitsot_out
.
type
.
dtype
,
inner_mitsot_out
.
type
.
ndim
))
inner_mitsot_out
.
ndim
!=
outer_mitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg2
%
(
str
(
outer_mitsot
),
argoffset
+
idx
,
outer_mitsot
.
type
.
dtype
,
outer_mitsot
.
type
.
ndim
,
inner_mitsot_out
.
type
.
dtype
,
inner_mitsot_out
.
type
.
ndim
))
argoffset
+=
len
(
self
.
outer_mitsot
(
inputs
))
# Same checks as above but for outputs of type sit_sot
...
...
@@ -348,22 +353,22 @@ class Scan(PureOp):
inner_sitsot
.
ndim
!=
outer_sitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg1
%
(
'initial state (outputs_info'
' in scan nomenclature) '
,
str
(
outer_sitsot
),
argoffset
+
idx
,
outer_sitsot
.
type
.
dtype
,
outer_sitsot
.
type
.
ndim
,
str
(
inner_sitsot
),
inner_sitsot
.
type
.
dtype
,
inner_sitsot
.
type
.
ndim
))
str
(
outer_sitsot
),
argoffset
+
idx
,
outer_sitsot
.
type
.
dtype
,
outer_sitsot
.
type
.
ndim
,
str
(
inner_sitsot
),
inner_sitsot
.
type
.
dtype
,
inner_sitsot
.
type
.
ndim
))
if
(
inner_sitsot_out
.
type
.
dtype
!=
outer_sitsot
.
type
.
dtype
or
inner_sitsot_out
.
ndim
!=
outer_sitsot
.
ndim
-
1
):
raise
ValueError
(
err_msg2
%
(
str
(
outer_sitsot
),
argoffset
+
idx
,
outer_sitsot
.
type
.
dtype
,
outer_sitsot
.
type
.
ndim
,
inner_sitsot_out
.
type
.
dtype
,
inner_sitsot_out
.
type
.
ndim
))
raise
ValueError
(
err_msg2
%
(
str
(
outer_sitsot
),
argoffset
+
idx
,
outer_sitsot
.
type
.
dtype
,
outer_sitsot
.
type
.
ndim
,
inner_sitsot_out
.
type
.
dtype
,
inner_sitsot_out
.
type
.
ndim
))
argoffset
+=
len
(
self
.
outer_sitsot
(
inputs
))
# Check that the shared variable and their update rule have the same
...
...
@@ -397,9 +402,7 @@ class Scan(PureOp):
for
inner_nonseq
,
outer_nonseq
in
zip
(
self
.
inner_non_seqs
(
self
.
inputs
),
self
.
outer_non_seqs
(
inputs
)):
if
(
inner_nonseq
.
type
.
dtype
!=
outer_nonseq
.
type
.
dtype
or
inner_nonseq
.
type
.
ndim
!=
outer_nonseq
.
type
.
ndim
):
if
inner_nonseq
.
type
!=
outer_nonseq
.
type
:
raise
ValueError
((
'Argument
%
s given to scan node does not'
' match its correspondance
%
s'
)
%
(
str
(
outer_nonseq
),
str
(
inner_nonseq
)))
...
...
@@ -1194,198 +1197,268 @@ class Scan(PureOp):
for
o
,
x
in
izip
(
node
.
outputs
,
scan_outs
)]
return
scan_outs
### GRAD FUNCTION
def
grad
(
self
,
args
,
g_outs
):
def
get_input_pos
(
self
,
output_index
):
ipos
=
self
.
n_seqs
opos
=
output_index
for
otaps
,
itaps
in
zip
(
self
.
mitmot_out_taps
(),
self
.
mitmot_taps
()):
if
len
(
otaps
)
>
opos
:
return
ipos
else
:
opos
=
opos
-
len
(
otaps
)
ipos
+=
len
(
itaps
)
for
dx
,
taps
in
enumerate
(
self
.
mitsot_taps
()):
if
opos
==
0
:
return
ipos
else
:
opos
=
opos
-
1
ipos
+=
len
(
taps
)
if
opos
<
self
.
info
[
'n_sit_sot'
]:
return
ipos
+
opos
else
:
return
-
1
def
get_output_pos
(
self
,
input_index
):
ipos
=
input_index
opos
=
0
for
otaps
,
itaps
in
zip
(
self
.
mitmot_out_taps
(),
self
.
mitmot_taps
()):
if
len
(
itaps
)
>
ipos
:
return
opos
else
:
opos
+=
len
(
otaps
)
ipos
-=
len
(
itaps
)
for
dx
,
taps
in
enumerate
(
self
.
mitsot_taps
()):
if
len
(
taps
)
>
ipos
:
return
opos
else
:
opos
+=
1
ipos
-=
len
(
taps
)
if
ipos
<
self
.
info
[
'n_sit_sot'
]:
return
ipos
+
opos
else
:
return
-
1
def
get_output_slice_idx
(
self
,
output_index
):
ipos
=
0
opos
=
output_index
for
otaps
in
zip
(
self
.
mitmot_out_taps
()):
if
len
(
otaps
)
>
0
:
return
ipos
else
:
opos
=
opos
-
1
ipos
+=
len
(
otaps
)
return
ipos
+
opos
def
connection_pattern
(
self
,
node
):
# The gradient wrt to n_steps is disconnected
connection_pattern
=
[[
False
for
output
in
node
.
outputs
]]
connection_pattern
+=
[[
False
for
output
in
node
.
outputs
]
for
x
in
node
.
inputs
[
1
:]]
def
compute_gradient
(
y
,
g_y
,
diff_inputs
):
rval
=
[]
gmp
=
{}
consider_inps
=
[
x
for
x
in
theano
.
gof
.
graph
.
inputs
([
y
])
if
x
in
diff_inputs
]
for
x
in
consider_inps
:
try
:
_gmp
=
gradient
.
grad_sources_inputs
(
[(
y
,
g_y
)],
[
x
])
gmp
[
x
]
=
_gmp
[
x
]
except
TypeError
:
# It means the gradient is undefined (which implies
# is connected)
gmp
[
x
]
=
x
return
[
gmp
.
get
(
p
,
None
)
for
p
in
diff_inputs
]
def
_get_inner_outs
(
oidx
):
s
=
0
if
self
.
n_mit_mot
>
0
:
e
=
len
(
self
.
mitmot_out_taps
()[
0
])
else
:
e
=
1
for
p
in
xrange
(
oidx
):
s
=
e
if
p
<
self
.
n_mit_mot
:
e
+=
len
(
self
.
mitmot_out_taps
()[
p
])
else
:
e
+=
1
return
self
.
outputs
[
s
:
e
]
# This discards information about whether incoming gradients are 0
# or disconnected from the cost
# TODO: upgrade scan op to report disconnection correctly
def
strip_disconnected
(
g
):
if
isinstance
(
g
.
type
,
DisconnectedType
):
def
_get_inner_inps
(
iidx
):
s
=
0
if
self
.
n_seqs
>
0
:
e
=
1
else
:
e
=
len
(
self
.
tap_array
[
0
])
p
=
iidx
if
node
.
inputs
[
iidx
+
1
]
in
self
.
outer_nitsot
(
node
):
return
None
return
g
if
node
.
inputs
[
iidx
+
1
]
in
self
.
outer_non_seqs
(
node
):
loc_idx
=
self
.
outer_non_seqs
(
node
)
.
index
(
node
.
inputs
[
iidx
+
1
])
return
[
self
.
inner_non_seqs
(
self
.
inputs
)[
loc_idx
]]
for
p
in
xrange
(
iidx
):
s
=
e
if
p
<
self
.
n_seqs
:
e
+=
1
elif
p
-
self
.
n_seqs
<
len
(
self
.
tap_array
):
e
+=
len
(
self
.
tap_array
[
p
-
self
.
n_seqs
])
else
:
e
+=
1
return
self
.
inputs
[
s
:
e
]
for
oidx
,
out
in
enumerate
(
node
.
outputs
):
for
iidx
,
inp
in
enumerate
(
node
.
inputs
[
1
:]):
ols
=
_get_inner_outs
(
oidx
)
ils
=
_get_inner_inps
(
iidx
)
g_outs
=
[
strip_disconnected
(
g
)
for
g
in
g_outs
]
if
ils
is
None
:
# The gradient should be disconnected
connection_pattern
[
iidx
+
1
][
oidx
]
=
False
else
:
for
inner_out
in
ols
:
# We check for the dtype because inner_out could be
# any Theano type like Generic or RandomState, for
# which we can not impose a dtype
if
hasattr
(
inner_out
,
'dtype'
):
# Note that we do not care about the output of
# this compute gradient. We just care to see if
# it is None or not. (i.e. disconnected or not)
tmp
=
compute_gradient
(
inner_out
,
safe_new
(
inner_out
,
dtype
=
'float64'
),
ils
)
else
:
# It should be undefined not disconnected
tmp
=
ils
if
any
([
x
is
not
None
for
x
in
tmp
]):
connection_pattern
[
iidx
+
1
][
oidx
]
=
True
return
connection_pattern
# 1. forward pass - get the outputs after applying scan
scan_outputs
=
self
(
*
args
)
# 2. make sure they are given as a list
if
not
(
type
(
scan_outputs
)
in
(
list
,
tuple
)):
scan_outputs
=
[
scan_outputs
]
# 3. un-group / unzip the inputs
# Note ! We don't want to use the actual same variable as the ones
# used by the original scan, rather create clones of them
### GRAD FUNCTION
def
grad
(
self
,
inputs
,
dC_douts
):
outs
=
self
(
*
inputs
)
if
not
isinstance
(
outs
,
(
list
,
tuple
)):
outs
=
[
outs
]
# `grad_step` equals the number of steps the original scan node has
# done (if the original scan is a while loop than this number is the
# length of the output sequence)
# We do not know what kind of outputs the original scan has, so we
# try first to see if it has a nit_sot output, then a sit_sot and
# then a mit_sot
if
self
.
n_nit_sot
>
0
:
grad_steps
=
self
.
outer_nitsot_outs
(
outs
)[
0
]
.
shape
[
0
]
elif
self
.
n_sit_sot
>
0
:
grad_steps
=
self
.
outer_sitsot_outs
(
outs
)[
0
]
.
shape
[
0
]
-
1
elif
self
.
n_mit_sot
>
0
:
grad_steps
=
self
.
outer_mitsot_outs
(
outs
)[
0
]
.
shape
[
0
]
+
\
self
.
mintaps
[
self
.
n_mit_mot
]
else
:
grad_steps
=
inputs
[
0
]
rval
=
scan_utils
.
reconstruct_graph
(
self
.
inputs
,
self
.
outputs
,
'_grad'
)
self
.
outputs
)
self_inputs
=
rval
[
0
]
self_outputs
=
rval
[
1
]
#differentiable inputs
diff_inputs
=
(
self
.
inner_seqs
(
self_inputs
)
+
self
.
inner_mitmot
(
self_inputs
)
+
self
.
inner_mitsot
(
self_inputs
)
+
self
.
inner_sitsot
(
self_inputs
)
+
self
.
inner_non_seqs
(
self_inputs
))
diff_outputs
=
(
self
.
inner_mitmot_outs
(
self_outputs
)
+
self
.
inner_mitsot_outs
(
self_outputs
)
+
self
.
inner_sitsot_outs
(
self_outputs
)
+
self
.
inner_nitsot_outs
(
self_outputs
))
seqs
=
self_inputs
[:
self
.
n_seqs
]
offset
=
self
.
n_seqs
n_ins_mit_mot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
in
xrange
(
self
.
n_mit_mot
)])
outs_mit_mot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_mot
]
offset
+=
n_ins_mit_mot
n_ins_mit_sot
=
numpy
.
sum
([
0
]
+
[
len
(
self
.
tap_array
[
x
])
for
x
in
xrange
(
self
.
n_mit_mot
,
self
.
n_mit_mot
+
self
.
n_mit_sot
)])
outs_mit_sot
=
self_inputs
[
offset
:
offset
+
n_ins_mit_sot
]
offset
+=
n_ins_mit_sot
outs_sit_sot
=
self_inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
offset
+=
self
.
n_sit_sot
old_scan_shared_ins
=
self_inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
out_offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_nit_sot
+
self
.
n_sit_sot
)
# shared variables as well as the condition
old_scan_shared_outs
=
self_outputs
[
out_offset
:]
arg_offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
old_scan_init
=
args
[
arg_offset
:
arg_offset
+
self
.
n_shared_outs
]
offset
+=
self
.
n_shared_outs
other_args
=
self_inputs
[
offset
:]
# 4. Collect (possibly) differentiable inputs
diff_inputs
=
(
seqs
+
outs_mit_mot
+
outs_mit_sot
+
outs_sit_sot
+
other_args
)
#args[-len(other_args):] )
# 5. construct the function that computes the gradient (we sum over
# the gradients with respect to all outputs)
def
compute_gradient
(
y
,
g_y
):
gmp
=
gradient
.
grad_sources_inputs
(
[(
y
,
g_y
)],
diff_inputs
)
[(
y
,
g_y
)],
[
x
for
x
in
theano
.
gof
.
graph
.
inputs
([
y
])
if
x
in
diff_inputs
])
return
[
gmp
.
get
(
p
,
None
)
for
p
in
diff_inputs
]
# 6. clean the outputs (i.e. remove update rules)
end
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
clean_outputs
=
self_outputs
[:
end
]
g_outs_no_shared
=
g_outs
[:
end
]
# 7.1. empty lists to hold gradients
# List of slices from outputs (used to compute the gradients)
inner_g_outs
=
[]
g_out_slices
=
[]
# List of outputs of the gradient function
inner_gfn_outs
=
[]
# slices of the input
prev_inner_gfn_outs
=
[]
zeros_like_diff_ins
=
[]
pos
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
self
.
n_sit_sot
)
offset
=
len
(
args
)
-
len
(
other_args
)
-
pos
# 7.2. generate variables to represent previous steps of g_outs
for
idx
,
diff_in
in
enumerate
(
diff_inputs
):
prev_gfn_out
=
safe_new
(
diff_in
)
if
hasattr
(
diff_in
,
'name'
)
and
diff_in
.
name
:
prev_gfn_out
.
name
=
'g_prev_'
+
diff_in
.
name
else
:
prev_gfn_out
.
name
=
'g_prev_'
+
str
(
idx
)
prev_inner_gfn_outs
.
append
(
prev_gfn_out
)
if
idx
<
pos
:
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
diff_in
))
else
:
zeros_like_diff_ins
.
append
(
tensor
.
zeros_like
(
args
[
idx
+
offset
]))
# 7.3. compute gradients of the inputs given one output
for
dx
,
out
in
enumerate
(
clean_outputs
):
if
g_outs
[
dx
]
!=
None
:
inner_g_out
=
safe_new
(
g_outs
[
dx
][
0
])
dC_dinps_t
=
[
None
for
inp
in
diff_inputs
]
disconnected_dC_dinps_t
=
[
True
for
inp
in
diff_inputs
]
dC_dXts
=
[]
Xts
=
[]
for
idx
,
Xt
in
enumerate
(
diff_outputs
):
# We are looking for x[t-1] for a given x[t]
if
idx
>=
self
.
n_mit_mot_outs
:
Xt_placeholder
=
Xt
.
type
()
Xts
.
append
(
Xt_placeholder
)
if
Xt
not
in
self
.
inner_nitsot_outs
(
self_outputs
):
# What we do here is loop through dC_douts and collect all
# those that are connected to the specific one and do an
# upcast on all of their dtypes to get the dtype for this
# specific output. Deciding if the gradient with this
# specific previous step is defined or not is done somewhere
# else.
dtypes
=
[]
states
=
(
self
.
inner_mitmot
(
self_inputs
)
+
self
.
inner_mitsot
(
self_inputs
)
+
self
.
inner_sitsot
(
self_inputs
))
for
pos
,
inp
in
enumerate
(
states
):
if
inp
in
theano
.
gof
.
graph
.
inputs
([
Xt
]):
oidx
=
self
.
get_output_pos
(
pos
)
if
not
isinstance
(
dC_douts
[
oidx
]
.
type
,
DisconnectedType
):
dtypes
.
append
(
dC_douts
[
oidx
]
.
dtype
)
if
dtypes
:
new_dtype
=
theano
.
scalar
.
upcast
(
*
dtypes
)
else
:
new_dtype
=
theano
.
config
.
floatX
dC_dXt
=
safe_new
(
Xt
,
dtype
=
new_dtype
)
else
:
# We do not have a gradient on this output so we need a
# placeholder, which for now has the same dtype as the
# output
inner_g_out
=
safe_new
(
out
)
###
#### I need to clip the gradient HERE !!
if
g_outs_no_shared
[
dx
]:
g_out_slices
.
append
(
g_outs_no_shared
[
dx
][
0
])
if
isinstance
(
dC_douts
[
idx
]
.
type
,
DisconnectedType
):
continue
dC_dXt
=
safe_new
(
dC_douts
[
idx
][
0
])
dC_dXts
.
append
(
dC_dXt
)
_dC_dinps_t
=
compute_gradient
(
Xt
,
dC_dXt
)
for
jdx
in
xrange
(
len
(
_dC_dinps_t
)):
if
dC_dinps_t
[
jdx
]
is
None
:
dC_dinps_t
[
jdx
]
=
_dC_dinps_t
[
jdx
]
elif
_dC_dinps_t
[
jdx
]:
dC_dinps_t
[
jdx
]
+=
_dC_dinps_t
[
jdx
]
# mask inputs that get no gradients
for
dx
in
xrange
(
len
(
dC_dinps_t
)):
if
not
dC_dinps_t
[
dx
]:
dC_dinps_t
[
dx
]
=
tensor
.
zeros_like
(
diff_inputs
[
dx
])
else
:
g_out_slices
.
append
(
None
)
if
getattr
(
out
,
'name'
,
None
)
is
not
None
:
inner_g_out
.
name
=
'g_'
+
out
.
name
disconnected_dC_dinps_t
[
dx
]
=
False
for
Xt
,
Xt_placeholder
in
zip
(
diff_outputs
[
self
.
n_mit_mot_outs
:],
Xts
):
tmp
=
forced_replace
(
dC_dinps_t
[
dx
],
Xt
,
Xt_placeholder
)
dC_dinps_t
[
dx
]
=
tmp
# construct dX_dtm1
dC_dXtm1s
=
[]
for
pos
,
x
in
enumerate
(
dC_dinps_t
[
self
.
n_seqs
:]):
opos
=
self
.
get_output_pos
(
pos
)
if
opos
>=
0
:
dC_dXtm1s
.
append
(
dC_dXts
[
opos
]
.
type
())
if
x
.
dtype
!=
dC_dXts
[
opos
]
.
dtype
:
dC_dinps_t
[
pos
+
self
.
n_seqs
]
=
\
x
.
astype
(
dC_dXts
[
opos
]
.
dtype
)
else
:
inner_g_out
.
name
=
'g_'
+
str
(
dx
)
inner_g_outs
.
append
(
inner_g_out
)
_g_out
=
inner_g_out
grad_outs
=
compute_gradient
(
out
,
_g_out
)
if
not
inner_gfn_outs
:
for
idx
,
gfn_out
in
enumerate
(
grad_outs
):
if
idx
>=
self
.
n_seqs
:
inner_gfn_outs
.
append
(
prev_inner_gfn_outs
[
idx
])
else
:
inner_gfn_outs
.
append
(
None
)
# 7.4 Sum the gradients
# safety check, some of this inputs might still not be
# differentiable, for those we don't add them to the mix
# (assume their gradient is 0)
for
i
,
(
x
,
y
)
in
enumerate
(
zip
(
grad_outs
,
inner_gfn_outs
)):
if
x
and
y
:
inner_gfn_outs
[
i
]
=
x
+
y
elif
y
:
inner_gfn_outs
[
i
]
=
y
else
:
inner_gfn_outs
[
i
]
=
x
## 8. Mask the outputs that are not differentiable
# backwards pass
for
i
in
xrange
(
len
(
inner_gfn_outs
)):
if
inner_gfn_outs
[
i
]
is
None
:
inner_gfn_outs
[
i
]
=
tensor
.
zeros_like
(
diff_inputs
[
i
])
## 9. Mask the g_outs that are Nones :
for
i
,
out
in
enumerate
(
scan_outputs
):
if
g_outs
[
i
]
is
None
:
try
:
# this try is for catching non ndarray inputs (random
# states) it is more of a safety check ( all random
# states should be after n_outs_not_shared ...
g_outs
[
i
]
=
tensor
.
zeros_like
(
scan_outputs
[
i
])
except
Exception
:
g_outs
[
i
]
=
theano
.
tensor
.
constant
(
numpy
.
array
(
0
,
theano
.
config
.
floatX
))
## 10. Get your sequence in order for the scan:
n_seqs
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
)
offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
inner_seqs
=
(
seqs
+
outs_mit_mot
+
outs_mit_sot
+
outs_sit_sot
+
inner_g_outs
[
offset
:
offset
+
self
.
n_nit_sot
])
scan_seqs
=
[
x
[::
-
1
]
for
x
in
args
[
1
:
self
.
n_seqs
+
1
]]
offset
=
0
dC_dXtm1s
.
append
(
x
.
type
())
for
dx
,
dC_dXtm1
in
enumerate
(
dC_dXtm1s
):
dC_dinps_t
[
dx
+
self
.
n_seqs
]
+=
dC_dXtm1
# Construct scan op
# Seqs
outer_inp_seqs
=
[
x
[::
-
1
]
for
x
in
inputs
[
1
:
1
+
self
.
n_seqs
]]
for
idx
in
xrange
(
self
.
n_mit_mot
+
self
.
n_mit_sot
):
mintap
=
numpy
.
min
(
self
.
tap_array
[
idx
])
maxtap
=
numpy
.
max
(
self
.
tap_array
[
idx
])
seq
=
scan_outputs
[
offset
+
idx
]
seq
=
outs
[
idx
]
for
k
in
self
.
tap_array
[
idx
]:
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]
if
maxtap
<
0
:
dim_offset
=
abs
(
maxtap
)
else
:
...
...
@@ -1397,126 +1470,187 @@ class Scan(PureOp):
-
(
maxtap
-
k
+
1
)][::
-
1
]
else
:
nw_seq
=
seq
[
dim_offset
+
k
-
mintap
-
1
:
-
1
][::
-
1
]
if
getattr
(
seq
,
'name'
,
None
)
is
not
None
:
nw_seq
.
name
=
seq
.
name
+
'[
%
d:]'
%
k
scan_seqs
.
append
(
nw_seq
)
offset
+=
self
.
n_mit_sot
for
idx
in
xrange
(
self
.
n_sit_sot
):
seq
=
scan_outputs
[
offset
+
idx
][:
-
1
]
scan_seqs
.
append
(
seq
[::
-
1
])
offset
=
(
self
.
n_mit_mot_outs
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
scan_seqs
+=
[
x
[::
-
1
]
for
x
in
g_outs
[
offset
:
offset
+
self
.
n_nit_sot
]]
scan_mit_mot
=
[]
inner_mit_mot
=
[]
scan_mit_mot_outs
=
[]
mit_mot_taps
=
[]
mit_mot_out_slices
=
[]
outer_inp_seqs
.
append
(
nw_seq
)
outer_inp_seqs
+=
[
x
[:
-
1
][::
-
1
]
for
x
in
self
.
outer_sitsot_outs
(
outs
)]
for
x
in
self
.
outer_nitsot_outs
(
dC_douts
):
if
not
isinstance
(
x
.
type
,
DisconnectedType
):
outer_inp_seqs
.
append
(
x
[::
-
1
])
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_mitsot_outs
(
outs
)]
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_sitsot_outs
(
outs
)]
outer_inp_seqs
+=
[
x
[::
-
1
]
for
x
in
self
.
outer_nitsot_outs
(
outs
)]
inner_inp_seqs
=
self
.
inner_seqs
(
self_inputs
)
inner_inp_seqs
+=
self
.
inner_mitmot
(
self_inputs
)
inner_inp_seqs
+=
self
.
inner_mitsot
(
self_inputs
)
inner_inp_seqs
+=
self
.
inner_sitsot
(
self_inputs
)
inner_inp_seqs
+=
self
.
inner_nitsot_outs
(
dC_dXts
)
inner_inp_seqs
+=
Xts
# mitmot
outer_inp_mitmot
=
[]
outer_out_mitmot
=
[]
inner_inp_mitmot
=
[]
inner_out_mitmot
=
[]
mitmot_inp_taps
=
[]
mitmot_out_taps
=
[]
type_outs
=
[]
out_pos
=
0
ins_pos
=
n_seqs
n_mit_mot_outs
=
0
n_mit_mot_ins
=
0
ins_pos
=
self
.
n_seqs
n_mitmot_outs
=
0
n_mitmot_inps
=
0
for
idx
in
xrange
(
self
.
n_mit_mot
):
scan_mit_mot
.
append
(
g_outs
[
idx
][::
-
1
])
mit_mot_taps
.
append
([])
mit_mot_out_slices
.
append
([])
outer_inp_mitmot
.
append
(
dC_douts
[
idx
][::
-
1
])
mitmot_inp_taps
.
append
([])
mitmot_out_taps
.
append
([])
undefined
=
False
disconnected
=
True
for
jdx
in
xrange
(
len
(
self
.
mit_mot_out_slices
[
idx
])):
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
])
mit_mot_taps
[
idx
]
.
append
(
\
-
self
.
mit_mot_out_slices
[
idx
][
jdx
])
n_mit_mot_ins
+=
1
inner_inp_mitmot
.
append
(
dC_dXts
[
out_pos
])
mitmot_inp_taps
[
idx
]
.
append
(
-
self
.
mit_mot_out_slices
[
idx
][
jdx
])
n_mitmot_inps
+=
1
out_pos
+=
1
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx
])):
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
])
scan_mit_mot_outs
.
append
(
\
inner_gfn_outs
[
ins_pos
])
n_mit_mot_ins
+=
1
inner_inp_mitmot
.
append
(
dC_dXtm1s
[
ins_pos
-
self
.
n_seqs
])
inner_out_mitmot
.
append
(
dC_dinps_t
[
ins_pos
])
if
not
disconnected_dC_dinps_t
[
ins_pos
]:
disconnected
=
False
for
_sh
in
self
.
inner_shared
(
self_inputs
):
if
_sh
in
gof
.
graph
.
inputs
([
dC_dinps_t
[
ins_pos
]]):
undefined
=
True
n_mitmot_inps_
+=
1
ins_pos
+=
1
n_mit_mot_outs
+=
1
mit_mot_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
])
mit_mot_out_slices
[
idx
]
.
append
(
\
-
self
.
tap_array
[
idx
][
jdx
])
n_mitmot_outs
+=
1
mitmot_inp_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
])
mitmot_out_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx
][
jdx
])
if
undefined
:
type_outs
.
append
(
'undefined'
)
elif
disconnected
:
type_outs
.
append
(
'disconnected'
)
else
:
type_outs
.
append
(
'connected'
)
offset
=
self
.
n_mit_mot
for
idx
in
xrange
(
self
.
n_mit_sot
):
mit
_mot
_taps
.
append
([])
mit
_mot_out_slice
s
.
append
([])
scan_mit_mot
.
append
(
g_
outs
[
idx
+
offset
][::
-
1
])
mit
mot_inp
_taps
.
append
([])
mit
mot_out_tap
s
.
append
([])
outer_inp_mitmot
.
append
(
dC_d
outs
[
idx
+
offset
][::
-
1
])
idx_tap
=
idx
+
self
.
n_mit_mot
inner_inp_mitmot
.
append
(
dC_dXts
[
out_pos
])
out_pos
+=
1
n_mitmot_inps
+=
1
undefined
=
False
disconnected
=
True
mitmot_inp_taps
[
idx
+
offset
]
.
append
(
0
)
for
jdx
in
xrange
(
len
(
self
.
tap_array
[
idx_tap
])):
inner_mit_mot
.
append
(
prev_inner_gfn_outs
[
ins_pos
])
mit_mot_taps
[
idx
+
offset
]
.
append
(
\
inner_inp_mitmot
.
append
(
dC_dXtm1s
[
ins_pos
-
self
.
n_seqs
])
inner_out_mitmot
.
append
(
dC_dinps_t
[
ins_pos
])
mitmot_inp_taps
[
idx
+
offset
]
.
append
(
-
self
.
tap_array
[
idx_tap
][
jdx
])
mit
_mot_out_slices
[
idx
]
.
append
(
\
mit
mot_out_taps
[
idx
]
.
append
(
-
self
.
tap_array
[
idx_tap
][
jdx
])
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
])
n_mit_mot_ins
+=
1
if
not
disconnected_dC_dinps_t
[
ins_pos
]:
disconnected
=
False
for
_sh
in
self
.
inner_shared
(
self_inputs
):
if
_sh
in
gof
.
graph
.
inputs
([
dC_dinps_t
[
ins_pos
]]):
undefined
=
True
n_mitmot_inps
+=
1
ins_pos
+=
1
n_mit_mot_outs
+=
1
inner_mit_mot
.
append
(
inner_g_outs
[
out_pos
])
out_pos
+=
1
n_mit_mot_ins
+=
1
mit_mot_taps
[
idx
+
offset
]
.
append
(
0
)
n_mitmot_outs
+=
1
if
undefined
:
type_outs
.
append
(
'undefined'
)
elif
disconnected
:
type_outs
.
append
(
'disconnected'
)
else
:
type_outs
.
append
(
'connected'
)
offset
+=
self
.
n_mit_sot
for
idx
in
xrange
(
self
.
n_sit_sot
):
mit_mot_taps
.
append
([
0
,
1
])
mit_mot_out_slices
.
append
([
1
])
scan_mit_mot
.
append
(
g_outs
[
idx
+
offset
][::
-
1
])
scan_mit_mot_outs
.
append
(
inner_gfn_outs
[
ins_pos
])
inner_mit_mot
+=
[
inner_g_outs
[
out_pos
],
prev_inner_gfn_outs
[
ins_pos
]]
n_mit_mot_outs
+=
1
mitmot_inp_taps
.
append
([
0
,
1
])
mitmot_out_taps
.
append
([
1
])
undefined
=
False
if
not
isinstance
(
dC_douts
[
idx
+
offset
]
.
type
,
DisconnectedType
):
outer_inp_mitmot
.
append
(
dC_douts
[
idx
+
offset
][::
-
1
])
else
:
outer_inp_mitmot
.
append
(
tensor
.
zeros
(
outs
[
idx
+
offset
]
.
shape
,
dtype
=
dC_dinps_t
[
ins_pos
]
.
dtype
))
inner_out_mitmot
.
append
(
dC_dinps_t
[
ins_pos
])
for
_sh
in
self
.
inner_shared
(
self_inputs
):
if
_sh
in
gof
.
graph
.
inputs
([
dC_dinps_t
[
ins_pos
]]):
undefined
=
True
if
undefined
:
type_outs
.
append
(
'undefined'
)
elif
disconnected_dC_dinps_t
[
ins_pos
]:
type_outs
.
append
(
'disconnected'
)
else
:
type_outs
.
append
(
'connected'
)
inner_inp_mitmot
+=
[
dC_dXts
[
out_pos
],
dC_dXtm1s
[
ins_pos
-
self
.
n_seqs
]]
n_mitmot_outs
+=
1
out_pos
+=
1
ins_pos
+=
1
n_mit_mot_ins
+=
2
n_nit_sot
=
self
.
n_seqs
scan_nit_sot_outs
=
inner_gfn_outs
[:
self
.
n_seqs
]
n_mitmot_inps
+=
2
if
self
.
truncate_gradient
!=
-
1
:
do_steps
=
tensor
.
minimum
(
args
[
0
],
self
.
truncate_gradient
)
else
:
do_steps
=
args
[
0
]
offset
=
(
self
.
n_seqs
+
n_ins_mit_sot
+
n_ins_mit_mot
+
self
.
n_sit_sot
)
# Instead of shared outs use sit_sot
n_sitsot_outs
=
len
(
prev_inner_gfn_outs
[
offset
:])
scan_sitsot_ins
=
prev_inner_gfn_outs
[
offset
:]
scan_sitsot_init
=
[]
for
x
in
zeros_like_diff_ins
[
offset
:]:
shapes
=
[
x
.
shape
[
i
]
for
i
in
xrange
(
x
.
ndim
)]
empty
=
tensor
.
zeros
([
do_steps
+
1
]
+
shapes
,
dtype
=
x
.
dtype
)
scan_sitsot_init
.
append
(
empty
)
scan_sitsot_outs
=
inner_gfn_outs
[
offset
:]
tap_array
=
mit_mot_taps
+
[[
-
1
]
for
k
in
grad_steps
=
tensor
.
minimum
(
grad_steps
,
self
.
truncate_gradient
)
n_nit_sot
=
self
.
n_seqs
inner_out_nitsot
=
dC_dinps_t
[:
self
.
n_seqs
]
inner_out_sitsot
=
dC_dinps_t
[
ins_pos
:]
for
_p
,
vl
in
enumerate
(
inner_out_sitsot
):
undefined
=
False
for
_sh
in
self
.
inner_shared
(
self_inputs
):
if
_sh
in
gof
.
graph
.
inputs
([
vl
]):
undefined
=
True
if
undefined
:
type_outs
.
append
(
'undefined'
)
elif
disconnected_dC_dinps_t
[
_p
+
ins_pos
]:
type_outs
.
append
(
'disconnected'
)
else
:
type_outs
.
append
(
'connected'
)
for
_p
,
vl
in
enumerate
(
inner_out_nitsot
):
undefined
=
False
for
_sh
in
self
.
inner_shared
(
self_inputs
):
if
_sh
in
gof
.
graph
.
inputs
([
vl
]):
undefined
=
True
if
undefined
:
type_outs
.
append
(
'undefined'
)
elif
disconnected_dC_dinps_t
[
_p
]:
type_outs
.
append
(
'disconnected'
)
else
:
type_outs
.
append
(
'connected'
)
inner_inp_sitsot
=
dC_dXtm1s
[
ins_pos
-
self
.
n_seqs
:]
outer_inp_sitsot
=
[
tensor
.
zeros
([
grad_steps
+
1
]
+
[
x
.
shape
[
i
]
for
i
in
xrange
(
x
.
ndim
)],
dtype
=
y
.
dtype
)
for
y
,
x
in
zip
(
inner_inp_sitsot
,
self
.
outer_non_seqs
(
inputs
))]
n_sitsot_outs
=
len
(
outer_inp_sitsot
)
new_tap_array
=
mitmot_inp_taps
+
[[
-
1
]
for
k
in
xrange
(
n_sitsot_outs
)]
info
=
{}
info
[
'n_seqs'
]
=
n_seqs
info
[
'n_seqs'
]
=
len
(
outer_inp_seqs
)
info
[
'n_mit_sot'
]
=
0
info
[
'tap_array'
]
=
tap_array
info
[
'tap_array'
]
=
new_
tap_array
info
[
'gpu'
]
=
False
n_mit_mot
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
)
info
[
'n_mit_mot'
]
=
n_mit_mot
info
[
'n_mit_mot_outs'
]
=
n_mit_mot_outs
info
[
'mit_mot_out_slices'
]
=
mit_mot_out_slices
info
[
'n_mit_mot'
]
=
len
(
outer_inp_mitmot
)
info
[
'n_mit_mot_outs'
]
=
n_mitmot_outs
info
[
'mit_mot_out_slices'
]
=
mitmot_out_taps
info
[
'truncate_gradient'
]
=
self
.
truncate_gradient
info
[
'n_sit_sot'
]
=
n_sitsot_outs
info
[
'n_shared_outs'
]
=
self
.
n_shared_outs
info
[
'n_shared_outs'
]
=
0
info
[
'n_nit_sot'
]
=
n_nit_sot
info
[
'as_while'
]
=
self
.
as_whil
e
info
[
'as_while'
]
=
Fals
e
info
[
'profile'
]
=
self
.
profile
info
[
'destroy_map'
]
=
{}
if
self
.
name
:
...
...
@@ -1524,70 +1658,96 @@ class Scan(PureOp):
else
:
info
[
'name'
]
=
None
info
[
'mode'
]
=
self
.
mode
n_mit_sot
=
0
n_sit_sot
=
0
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_shared_outs
)
scan_inputs
=
([
do_steps
]
+
scan_seqs
+
scan_mit_mot
+
scan_sitsot_init
+
old_scan_init
+
[
args
[
0
]
for
x
in
xrange
(
n_nit_sot
)]
+
args
[
offset
:])
offset
=
(
self
.
n_seqs
+
n_ins_mit_mot
+
n_ins_mit_sot
+
self
.
n_sit_sot
+
self
.
n_shared_outs
)
outer_inputs
=
([
grad_steps
]
+
outer_inp_seqs
+
outer_inp_mitmot
+
outer_inp_sitsot
+
[
inputs
[
0
]
for
x
in
xrange
(
n_nit_sot
)]
+
self
.
outer_shared
(
inputs
)
+
self
.
outer_non_seqs
(
inputs
))
inner_other_args
=
self_inputs
[
offset
:]
inner_gfn_ins
=
(
inner_seqs
+
inner_mit_mot
+
scan_sitsot_ins
+
old_scan_shared_ins
+
inner_other_args
)
inner_gfn_outs
=
(
scan_mit_mot_outs
+
scan_sitsot_outs
+
scan_nit_sot_outs
+
old_scan_shared_outs
)
inner_gfn_ins
=
(
inner_inp_seqs
+
inner_inp_mitmot
+
inner_inp_sitsot
+
self
.
inner_shared
(
self_inputs
)
+
self
.
inner_non_seqs
(
self_inputs
))
inner_gfn_outs
=
(
inner_out_mitmot
+
inner_out_sitsot
+
inner_out_nitsot
)
local_op
=
Scan
(
inner_gfn_ins
,
inner_gfn_outs
,
info
)
outputs
=
local_op
(
*
scan
_inputs
)
outputs
=
local_op
(
*
outer
_inputs
)
if
type
(
outputs
)
not
in
(
list
,
tuple
):
outputs
=
[
outputs
]
# Re-order the gradients correctly
gradients
=
[
grad_undefined
(
self
,
0
,
args
[
0
],
'Number of steps'
)]
gradients
=
[
DisconnectedType
()(
)]
offset
=
(
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
n_sitsot_outs
)
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[
offset
:
offset
+
self
.
n_seqs
]]
for
p
,
(
x
,
t
)
in
enumerate
(
zip
(
outputs
[
offset
:
offset
+
self
.
n_seqs
],
type_outs
[
offset
:
offset
+
self
.
n_seqs
])):
if
t
==
'disconnected'
:
gradients
.
append
(
DisconnectedType
()())
elif
t
==
'undefined'
:
gradients
.
append
(
grad_undefined
(
self
,
p
+
1
,
inputs
[
p
+
1
],
'Depends on a shared variable'
))
else
:
gradients
.
append
(
x
[::
-
1
])
end
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
gradients
+=
[
x
[::
-
1
]
for
x
in
outputs
[:
end
]]
for
p
,
(
x
,
t
)
in
enumerate
(
zip
(
outputs
[:
end
],
type_outs
[:
end
])):
if
t
==
'disconnected'
:
gradients
.
append
(
DisconnectedType
()())
elif
t
==
'undefined'
:
gradients
.
append
(
grad_undefined
(
self
,
p
+
1
+
self
.
n_seqs
,
inputs
[
p
+
1
+
self
.
n_seqs
],
'Depends on a shared variable'
))
else
:
gradients
.
append
(
x
[::
-
1
])
start
=
len
(
gradients
)
gradients
+=
[
grad_undefined
(
self
,
x
+
start
,
args
[
x
+
start
],
'Shared Variable with update'
)
for
x
in
xrange
(
self
.
n_shared_outs
)]
node
=
outs
[
0
]
.
owner
for
idx
in
xrange
(
self
.
n_shared_outs
):
disconnected
=
True
connected_flags
=
self
.
connection_pattern
(
node
)[
idx
+
start
]
for
dC_dout
,
connected
in
zip
(
dC_douts
,
connected_flags
):
if
(
not
isinstance
(
dC_dout
.
type
,
DisconnectedType
)
and
connected
):
disconnected
=
False
if
disconnected
:
gradients
.
append
(
DisconnectedType
()())
else
:
gradients
.
append
(
grad_undefined
(
self
,
idx
,
inputs
[
idx
],
'Shared Variable with update'
))
start
=
len
(
gradients
)
gradients
+=
[
grad_undefined
(
self
,
x
+
start
,
args
[
x
+
start
],
'Dimension of memory buffer for output'
)
gradients
+=
[
DisconnectedType
()()
for
x
in
xrange
(
self
.
n_nit_sot
)]
begin
=
end
end
=
begin
+
n_sitsot_outs
gradients
+=
[
x
[
-
1
]
for
x
in
outputs
[
begin
:
end
]]
for
p
,
(
x
,
t
)
in
enumerate
(
zip
(
outputs
[
begin
:
end
],
type_outs
[
begin
:
end
])):
if
t
==
'disconnected'
:
gradients
.
append
(
DisconnectedType
()())
elif
t
==
'undefined'
:
gradients
.
append
(
grad_undefined
(
self
,
p
+
begin
+
1
,
inputs
[
p
+
begin
+
1
],
'Depends on a shared variable'
))
else
:
gradients
.
append
(
x
[
-
1
])
return
gradients
def
R_op
(
self
,
inputs
,
eval_points
):
...
...
theano/scan_module/scan_opt.py
浏览文件 @
595ec4b2
...
...
@@ -23,7 +23,8 @@ from theano import gof
from
theano.gof.python25
import
maxsize
from
theano.gof.opt
import
Optimizer
from
theano.gof
import
toolbox
,
DestroyHandler
,
InconsistencyError
from
theano.compile
import
deep_copy_op
,
optdb
from
theano.compile
import
optdb
from
theano.compile.function_module
import
deep_copy_op
import
scan_op
import
scan_utils
...
...
@@ -221,7 +222,7 @@ class PushOutNonSeqScan(gof.Optimizer):
'to move some computation fron scan '
'which is not allowed to move. Report '
'this on theano-users list'
),
x
)
outside_ins
=
[
x
.
type
.
filter_variable
(
y
)
for
x
,
y
in
outside_ins
=
[
x
.
type
.
filter_variable
(
y
)
for
x
,
y
in
zip
(
nd
.
inputs
,
outside_ins
)]
nw_outer_node
=
nd
.
op
.
make_node
(
*
outside_ins
)
# Step 2. Create variables for replacements
...
...
@@ -681,14 +682,18 @@ class ScanSaveMem(gof.Optimizer):
if
(
nw_inputs
[
offset
+
idx
]
.
owner
and
isinstance
(
nw_inputs
[
offset
+
idx
]
.
owner
.
op
,
tensor
.
IncSubtensor
)
and
isinstance
(
nw_inputs
[
offset
+
idx
]
.
owner
.
op
.
idx_list
[
0
],
slice
)):
isinstance
(
nw_inputs
[
offset
+
idx
]
.
owner
.
op
.
idx_list
[
0
],
slice
)):
_nw_input
=
nw_inputs
[
offset
+
idx
]
.
owner
.
inputs
[
1
]
cval
=
tensor
.
as_tensor_variable
(
val
)
initl
=
tensor
.
as_tensor_variable
(
init_l
[
i
])
tmp_idx
=
tensor
.
switch
(
cval
<
initl
,
cval
+
initl
,
cval
-
initl
)
tmp
=
pre_greedy_local_optimizer
(
list_opt_slice
,
tmp_idx
)
tmp
=
pre_greedy_local_optimizer
(
list_opt_slice
,
tmp_idx
)
tmp
=
pre_constant_merge
([
tmp
])[
0
]
nw_input
=
scan_utils
.
expand
(
_nw_input
,
tmp
)
...
...
theano/scan_module/scan_utils.py
浏览文件 @
595ec4b2
...
...
@@ -33,7 +33,7 @@ from theano.tensor.basic import get_constant_value
_logger
=
logging
.
getLogger
(
'theano.scan_utils'
)
def
safe_new
(
x
,
tag
=
''
):
def
safe_new
(
x
,
tag
=
''
,
dtype
=
None
):
"""
Internal function that constructs a new variable from x with the same
type, but with a different name (old name + tag). This function is used
...
...
@@ -46,12 +46,18 @@ def safe_new(x, tag=''):
else
:
nw_name
=
None
if
isinstance
(
x
,
theano
.
Constant
):
return
x
.
clone
()
if
dtype
and
x
.
dtype
!=
dtype
:
return
x
.
clone
()
.
astype
(
dtype
)
else
:
return
x
.
clone
()
# Note, as_tensor_variable will convert the Scalar into a
# TensorScalar that will require a ScalarFromTensor op,
# making the pushout optimization fail
elif
isinstance
(
x
,
scalar
.
ScalarVariable
):
nw_x
=
x
.
type
()
if
dtype
:
new_x
=
scalar
.
Scalar
(
dtype
=
dtype
)()
else
:
nw_x
=
x
.
type
()
nw_x
.
name
=
nw_name
return
nw_x
else
:
...
...
@@ -63,6 +69,8 @@ def safe_new(x, tag=''):
# ndarrays
pass
nw_x
=
x
.
type
()
if
dtype
and
nw_x
.
dtype
!=
dtype
:
nw_x
=
nw_x
.
astype
(
dtype
)
nw_x
.
name
=
nw_name
# Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions
...
...
@@ -930,3 +938,34 @@ class scan_args(object):
'mit_sot_in_slices'
)):
getattr
(
res
,
attr
)
.
extend
(
getattr
(
other
,
attr
))
return
res
def forced_replace(out, x, y):
    """
    Replace every sub-expression of ``out`` equivalent to ``x`` with ``y``.

    :param out: Theano Variable (``None`` is accepted and returned as is)
    :param x: Theano Variable
    :param y: Theano Variable

    This function checks all internal values of the graph that computes the
    variable ``out`` for occurrences of values identical with ``x``. If such
    occurrences are encountered then they are replaced with variable ``y``.

    For example:
        out := sigmoid(wu)*(1-sigmoid(wu))
        x := sigmoid(wu)
        forced_replace(out, x, y) := y*(1-y)

    :return: the result of ``clone(out, replace=...)`` with every matching
        variable mapped to ``y``, or ``None`` when ``out`` is ``None``.
    """
    if out is None:
        return None

    # Walk the graph iteratively with an explicit stack and remember the
    # variables already inspected.  A variable that is reachable through
    # several paths is then examined only once (the naive recursion revisits
    # it once per path, which is exponential on multiply-connected graphs),
    # and deep graphs cannot overflow Python's recursion limit.
    # NOTE(review): relies on variables being hashable, which the
    # replacement dictionary below already requires.
    to_replace = []
    visited = set()
    pending = [out]
    while pending:
        graph = pending.pop()
        if graph in visited:
            continue
        visited.add(graph)
        if equal_computations([graph], [x]):
            # A match is replaced as a whole, so its inputs are not explored
            # from here.
            to_replace.append(graph)
        elif graph.owner:
            pending.extend(graph.owner.inputs)

    return clone(out, replace=dict((v, y) for v in to_replace))
theano/scan_module/tests/test_scan.py
浏览文件 @
595ec4b2
...
...
@@ -513,7 +513,7 @@ class T_Scan(unittest.TestCase):
def
f_rnn
(
u_t
,
x_tm1
,
W_in
,
W
):
return
(
u_t
*
W_in
+
x_tm1
*
W
,
tensor
.
cast
(
u_t
+
x_tm1
,
'int64'
))
tensor
.
cast
(
u_t
+
x_tm1
,
'int64'
))
u
=
theano
.
tensor
.
fvector
(
'u'
)
x0
=
theano
.
tensor
.
fscalar
(
'x0'
)
...
...
@@ -561,7 +561,6 @@ class T_Scan(unittest.TestCase):
scan_node
=
scan_node
[
0
]
assert
scan_node
.
op
.
gpu
# simple rnn, one input, one state, weights for each; input/state
# are vectors, weights are scalars; using shared variables
def
test_one_sequence_one_output_weights_shared
(
self
):
...
...
@@ -1124,6 +1123,29 @@ class T_Scan(unittest.TestCase):
assert
numpy
.
allclose
(
W1
.
get_value
(),
numpy_W1
)
assert
numpy
.
allclose
(
W2
.
get_value
(),
numpy_W2
)
def test_grad_dtype_change(self):
    """
    Gradient through a scan whose recurrent states have different dtypes.

    The first state is an int32 scalar and the other two are float32
    scalars.  The test only checks that building, compiling and running
    the gradient does not raise.
    """
    x0 = tensor.fscalar('x')
    y0 = tensor.fscalar('y')
    c0 = tensor.iscalar('c')

    def step(flag, x_prev, y_prev):
        # The integer state is recomputed from the float states, so the
        # gradient has to cross a cast to 'int32'.
        next_flag = tensor.cast(tensor.switch(flag, x_prev, y_prev),
                                'int32')
        sig_yx = tensor.nnet.sigmoid(y_prev * x_prev)
        next_x = tensor.switch(flag, sig_yx, x_prev)
        sig_x = tensor.nnet.sigmoid(x_prev)
        next_y = tensor.switch(flag, y_prev, sig_x)
        return next_flag, next_x, next_y

    scan_outs, _ = theano.scan(step,
                               outputs_info=[c0, x0, y0],
                               n_steps=10,
                               truncate_gradient=-1,
                               go_backwards=False)
    cost = scan_outs[1].sum()
    grad_x, grad_y = tensor.grad(cost, [x0, y0])
    grad_fn = theano.function([c0, x0, y0],
                              [grad_x, grad_y],
                              allow_input_downcast=True)
    # Check for runtime errors
    grad_fn(numpy.int32(0), numpy.float32(1.), numpy.float32(.5))
def
test_simple_shared_mrg_random
(
self
):
theano_rng
=
theano
.
sandbox
.
rng_mrg
.
MRG_RandomStreams
(
utt
.
fetch_seed
())
...
...
@@ -1470,8 +1492,11 @@ class T_Scan(unittest.TestCase):
truncate_gradient
=-
1
,
go_backwards
=
False
)
vparams
=
[
v_u1
,
v_u2
,
v_x0
,
v_y0
,
vW_in1
]
# y0 is actually not used in the computation of the cost
params
=
[
u1
,
u2
,
x0
,
y0
,
W_in1
]
gparams
=
theano
.
tensor
.
grad
(
cost
,
params
)
gparams
=
theano
.
grad
(
cost
,
params
,
disconnected_inputs
=
'ignore'
)
grad_fn
=
theano
.
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
gparams
,
updates
=
updates
,
...
...
@@ -1711,8 +1736,8 @@ class T_Scan(unittest.TestCase):
def
f_rnn_cmpl
(
u_t
,
x_tm1
,
W_in
):
trng1
=
theano
.
tensor
.
shared_randomstreams
.
RandomStreams
(
123
)
x_t
=
theano
.
dot
(
u_t
,
W_in
)
+
x_tm1
+
\
trng1
.
uniform
(
low
=-.
1
,
high
=.
1
)
rnd_nb
=
trng1
.
uniform
(
low
=-.
1
,
high
=.
1
)
x_t
=
theano
.
dot
(
u_t
,
W_in
)
+
x_tm1
+
rnd_nb
x_t
=
theano
.
tensor
.
cast
(
x_t
,
dtype
=
theano
.
config
.
floatX
)
return
x_t
...
...
@@ -1874,8 +1899,8 @@ class T_Scan(unittest.TestCase):
def
test_scan_extra_inputs_hessian
(
self
):
x
=
theano
.
tensor
.
vector
(
'x'
)
A
=
theano
.
tensor
.
matrix
(
'A'
)
fc1
=
theano
.
shared
(
0.5
,
name
=
'fc1'
)
fc2
=
theano
.
shared
(
0.9
,
name
=
'fc2'
)
fc1
=
theano
.
shared
(
0.5
,
name
=
'fc1'
)
fc2
=
theano
.
shared
(
0.9
,
name
=
'fc2'
)
y
=
fc1
*
theano
.
dot
(
x
*
x
,
theano
.
dot
(
A
,
x
))
y
.
name
=
'y'
gy
=
theano
.
tensor
.
grad
(
y
,
x
)
...
...
@@ -2316,12 +2341,13 @@ class T_Scan(unittest.TestCase):
allow_input_downcast
=
True
,
mode
=
mode_with_opt
)
self
.
assertTrue
(
numpy
.
allclose
(
f
([
1
,
2
,
3
]),
2.
/
3
))
#theano.printing.debugprint(f, print_type=True)
topo
=
f
.
maker
.
fgraph
.
toposort
()
# this new assert is here to test if scan_merging works ..
nb_scan
=
len
([
n
for
n
in
topo
if
isinstance
(
n
.
op
,
theano
.
scan_module
.
scan_op
.
Scan
)])
self
.
assertTrue
(
nb_scan
==
1
)
# For this to work we need an optimization that will be pushed in
# a new pull request
self
.
assertTrue
(
nb_scan
==
2
)
nb_shape_i
=
len
([
n
for
n
in
topo
if
isinstance
(
n
.
op
,
theano
.
tensor
.
opt
.
Shape_i
)])
if
theano
.
config
.
mode
!=
'FAST_COMPILE'
:
...
...
@@ -2511,10 +2537,10 @@ class T_Scan(unittest.TestCase):
def
rnn_fn
(
_u
,
_y
,
_W
):
srng
=
theano
.
tensor
.
shared_randomstreams
.
RandomStreams
(
seed
)
sl_o
=
theano
.
tensor
.
tanh
(
theano
.
tensor
.
dot
(
_W
,
(
_u
+
_y
+
\
srng
.
uniform
(
size
=
v_h0
.
shape
)
*
numpy
.
asarray
(
1e-6
,
dtype
=
floatX
))
))
return
sl_o
tmp_val
=
_u
+
_y
+
srng
.
uniform
(
size
=
v_h0
.
shape
)
*
\
numpy
.
asarray
(
1e-6
,
dtype
=
floatX
)
sl_o
=
theano
.
tensor
.
tanh
(
theano
.
tensor
.
dot
(
_W
,
tmp_val
))
return
sl_o
,
tmp_val
u
=
theano
.
tensor
.
matrix
(
'U'
)
h0
=
theano
.
tensor
.
vector
(
'h0'
)
...
...
@@ -2527,9 +2553,9 @@ class T_Scan(unittest.TestCase):
_W
=
theano
.
tensor
.
specify_shape
(
W
,
v_W
.
shape
)
_W
.
name
=
'_W'
o
,
_
=
theano
.
scan
(
rnn_fn
,
[
o
,
_
]
,
_
=
theano
.
scan
(
rnn_fn
,
sequences
=
_u
,
outputs_info
=
_h0
,
outputs_info
=
[
_h0
,
None
]
,
non_sequences
=
_W
,
name
=
'rnn_fn'
)
o
=
o
[
-
1
]
...
...
@@ -3110,6 +3136,7 @@ class T_Scan(unittest.TestCase):
loss
,
no_default_updates
=
True
,
allow_input_downcast
=
True
)
gw
,
gx
=
tensor
.
grad
(
loss
,
[
w
,
xinit
])
grad_fn
=
theano
.
function
([
xinit
,
w
],
[
gx
,
gw
],
allow_input_downcast
=
True
)
...
...
@@ -3135,6 +3162,20 @@ class T_Scan(unittest.TestCase):
raise
Exception
(
theano
.
tensor
.
verify_grad
.
E_grad
,
(
max_err
,
1e-2
,
max_err_pos
))
def test_grad_numeric_shared(self):
    """
    Gradient of a shared variable's scan update w.r.t. that variable.

    The scan has no outputs; each of its 10 steps only adds 1 to a float32
    shared variable.  The gradient of the resulting update expression with
    respect to the shared variable is expected to evaluate to 1.
    """
    counter = theano.shared(numpy.float32(1.))

    def increment():
        # No outputs, only an update dictionary for the shared variable.
        new_value = counter + numpy.float32(1.)
        return [], {counter: new_value}

    _, scan_updates = theano.scan(increment,
                                  n_steps=10,
                                  truncate_gradient=-1,
                                  go_backwards=False)
    final_value = scan_updates.values()[0]
    grad_expr = tensor.grad(final_value, counter)
    grad_fn = theano.function([], grad_expr)
    assert grad_fn() == 1
def
test_rop_mitmot
(
self
):
# this test is a copy paste from the script given by Justin Bayer to
# reproduce this bug
...
...
@@ -3188,17 +3229,17 @@ class T_Scan(unittest.TestCase):
Hp
=
tensor
.
Rop
(
d_cost_wrt_pars
,
pars
,
p
)
def
test_seq_tap_bug_jeremiah
(
self
):
inp
=
numpy
.
arange
(
10
)
.
reshape
(
-
1
,
1
)
.
astype
(
theano
.
config
.
floatX
)
exp_out
=
numpy
.
zeros
((
10
,
1
))
.
astype
(
theano
.
config
.
floatX
)
inp
=
numpy
.
arange
(
10
)
.
reshape
(
-
1
,
1
)
.
astype
(
theano
.
config
.
floatX
)
exp_out
=
numpy
.
zeros
((
10
,
1
))
.
astype
(
theano
.
config
.
floatX
)
exp_out
[
4
:]
=
inp
[:
-
4
]
def
onestep
(
x
,
x_tm4
):
return
x
,
x_tm4
seq
=
tensor
.
matrix
()
initial_value
=
theano
.
shared
(
numpy
.
zeros
((
4
,
1
),
initial_value
=
theano
.
shared
(
numpy
.
zeros
((
4
,
1
),
dtype
=
theano
.
config
.
floatX
))
outputs_info
=
[{
'initial'
:
initial_value
,
'taps'
:
[
-
4
]},
None
]
outputs_info
=
[{
'initial'
:
initial_value
,
'taps'
:
[
-
4
]},
None
]
results
,
updates
=
theano
.
scan
(
fn
=
onestep
,
sequences
=
seq
,
outputs_info
=
outputs_info
)
...
...
@@ -3208,27 +3249,49 @@ class T_Scan(unittest.TestCase):
def
test_borrow_bug_jeremiah
(
self
):
# This test fails if scan uses wrongly the borrow flag
inp
=
numpy
.
arange
(
10
)
.
reshape
(
-
1
,
1
)
.
astype
(
theano
.
config
.
floatX
)
exp_out
=
numpy
.
zeros
((
10
,
1
))
.
astype
(
theano
.
config
.
floatX
)
inp
=
numpy
.
arange
(
10
)
.
reshape
(
-
1
,
1
)
.
astype
(
theano
.
config
.
floatX
)
exp_out
=
numpy
.
zeros
((
10
,
1
))
.
astype
(
theano
.
config
.
floatX
)
exp_out
[
4
:]
=
inp
[:
-
4
]
def
onestep
(
x
,
x_tm4
):
return
x
,
x_tm4
seq
=
tensor
.
matrix
()
initial_value
=
theano
.
shared
(
numpy
.
zeros
((
4
,
1
),
initial_value
=
theano
.
shared
(
numpy
.
zeros
((
4
,
1
),
dtype
=
theano
.
config
.
floatX
))
outputs_info
=
[{
'initial'
:
initial_value
,
'taps'
:
[
-
4
]},
None
]
outputs_info
=
[{
'initial'
:
initial_value
,
'taps'
:
[
-
4
]},
None
]
results
,
_
=
theano
.
scan
(
fn
=
onestep
,
sequences
=
seq
,
outputs_info
=
outputs_info
)
sharedvar
=
theano
.
shared
(
numpy
.
zeros
((
1
,
1
),
sharedvar
=
theano
.
shared
(
numpy
.
zeros
((
1
,
1
),
dtype
=
theano
.
config
.
floatX
))
updates
=
{
sharedvar
:
results
[
0
][
-
1
:]}
updates
=
{
sharedvar
:
results
[
0
][
-
1
:]}
f
=
theano
.
function
([
seq
],
results
[
1
],
updates
=
updates
)
assert
numpy
.
all
(
exp_out
==
f
(
inp
))
def test_grad_connectivity_matrix(self):
    """
    Gradients w.r.t. initial states that do not influence the cost.

    In the inner function the previous ``z`` state feeds no output, and
    the previous ``y`` state feeds only the ``y`` output.  Asking for the
    gradient with respect to such an initial state is expected to raise
    ``ValueError``; the connected ones must be computable.
    """
    def step(x_prev, y_prev, z_prev):
        x_prev.name = 'x'
        y_prev.name = 'y'
        z_prev.name = 'z'
        next_x = x_prev ** 2
        next_y = x_prev + y_prev
        next_z = x_prev + 1
        return next_x, next_y, next_z

    x0 = tensor.vector('X')
    y0 = tensor.vector('y0')
    z0 = tensor.vector('Z')
    [xs, ys, zs], _ = theano.scan(step,
                                  outputs_info=[x0, y0, z0],
                                  n_steps=10)

    total = (xs + ys + zs).sum()
    # Both x0 and y0 reach the cost, so these must not raise.
    for wrt in (x0, y0):
        tensor.grad(total, wrt)
    # z0 never reaches any output.
    self.assertRaises(ValueError, tensor.grad, total, z0)

    # With only the x output in the cost, y0 is no longer connected.
    x_only = xs.sum()
    self.assertRaises(ValueError, tensor.grad, x_only, y0)
def
test_speed
():
#
# This function prints out the speed of very simple recurrent
...
...
@@ -3576,9 +3639,7 @@ if __name__ == '__main__':
def
test_compute_test_value
():
"""
Verify that test values can be used with scan.
"""
# Verify that test values can be used with scan.
backup
=
theano
.
config
.
compute_test_value
theano
.
config
.
compute_test_value
=
'raise'
try
:
...
...
@@ -3590,7 +3651,7 @@ def test_compute_test_value():
fn
=
lambda
u
,
v
:
u
+
v
,
sequences
=
[
x
,
y
])
assert
not
_
z
.
name
=
'z'
z
.
name
=
'z'
# The gradient computation used to crash before 6af465e.
g
=
tensor
.
grad
(
z
.
sum
(),
x
)
#f = theano.function([x], g)
...
...
theano/tensor/basic.py
浏览文件 @
595ec4b2
...
...
@@ -1076,6 +1076,7 @@ class TensorType(Type):
"""
return
numpy
.
zeros
(
shape
,
dtype
=
self
.
dtype
)
theano
.
compile
.
ops
.
expandable_types
+=
(
TensorType
,)
# Register TensorType C code for ViewOp.
theano
.
compile
.
register_view_op_c_code
(
...
...
theano/tensor/elemwise.py
浏览文件 @
595ec4b2
...
...
@@ -390,8 +390,12 @@ PyArray_SetBaseObject(%(res)s, (PyObject*)%(basename)s);
# Do not make the DimShuffle inplace as an optimization at the
# canonicalization optimization phase will remove the inplace.
# The inplace will be reintroduced automatically later in the graph.
return
[
DimShuffle
(
gz
.
type
.
broadcastable
,
grad_order
)(
Elemwise
(
scalar
.
identity
)(
gz
))]
if
'int'
in
inp
[
0
]
.
dtype
:
return
[
theano
.
tensor
.
zeros_like
(
inp
[
0
],
dtype
=
theano
.
config
.
floatX
)]
else
:
return
[
DimShuffle
(
gz
.
type
.
broadcastable
,
grad_order
)(
Elemwise
(
scalar
.
identity
)(
gz
))]
class
DimShufflePrinter
:
...
...
theano/tensor/raw_random.py
浏览文件 @
595ec4b2
...
...
@@ -256,7 +256,9 @@ class RandomFunction(gof.Op):
out
[
0
]
=
rval
def
grad
(
self
,
inputs
,
outputs
):
return
[
None
for
i
in
inputs
]
return
[
theano
.
gradient
.
grad_undefined
(
self
,
k
,
inp
,
'No gradient defined through raw random numbers op'
)
for
k
,
inp
in
enumerate
(
inputs
)]
def R_op(self, inputs, eval_points):
    """
    Return a list with one ``None`` per entry of ``eval_points``.

    ``inputs`` is accepted but not used.
    """
    result = []
    for _ in eval_points:
        result.append(None)
    return result
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论