Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d5f65500
提交
d5f65500
authored
1月 03, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
1月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use direct theano.gof imports in theano.scan.op
上级
9c445e0a
显示空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
25 行增加
和
23 行删除
+25
-23
op.py
theano/scan/op.py
+25
-23
没有找到文件。
theano/scan/op.py
浏览文件 @
d5f65500
...
@@ -53,23 +53,27 @@ from collections import OrderedDict
...
@@ -53,23 +53,27 @@ from collections import OrderedDict
import
numpy
as
np
import
numpy
as
np
import
theano
import
theano
from
theano
import
compile
,
gof
,
gradient
,
tensor
from
theano
import
tensor
from
theano.compile.builders
import
infer_shape
from
theano.compile.builders
import
infer_shape
from
theano.compile.function
import
function
from
theano.compile.function
import
function
from
theano.compile.io
import
In
,
Out
from
theano.compile.io
import
In
,
Out
from
theano.compile.mode
import
AddFeatureOptimizer
from
theano.compile.mode
import
AddFeatureOptimizer
,
get_mode
from
theano.compile.profiling
import
ScanProfileStats
from
theano.compile.profiling
import
ScanProfileStats
,
register_profiler_printer
from
theano.configdefaults
import
config
from
theano.configdefaults
import
config
from
theano.gof
import
Apply
,
Op
from
theano.gof.fg
import
MissingInputError
from
theano.gof.graph
import
equal_computations
,
io_connection_pattern
from
theano.gof.graph
import
Apply
,
Variable
,
equal_computations
from
theano.gof.graph
import
inputs
as
graph_inputs
from
theano.gof.graph
import
io_connection_pattern
from
theano.gof.op
import
Op
,
ops_with_inner_function
from
theano.gof.toolbox
import
NoOutputFromInplace
from
theano.gof.toolbox
import
NoOutputFromInplace
from
theano.gradient
import
DisconnectedType
,
NullType
,
grad_undefined
from
theano.gradient
import
DisconnectedType
,
NullType
,
grad
,
grad
_undefined
from
theano.link.c.basic
import
CLinker
from
theano.link.c.basic
import
CLinker
from
theano.link.c.exceptions
import
MissingGXX
from
theano.link.c.exceptions
import
MissingGXX
from
theano.link.utils
import
raise_with_op
from
theano.link.utils
import
raise_with_op
from
theano.scan.utils
import
Validator
,
forced_replace
,
hash_listsDictsTuples
,
safe_new
from
theano.scan.utils
import
Validator
,
forced_replace
,
hash_listsDictsTuples
,
safe_new
from
theano.tensor
import
TensorType
,
as_tensor_variable
from
theano.tensor
.basic
import
as_tensor_variable
from
theano.tensor.opt
import
Shape_i
from
theano.tensor.opt
import
Shape_i
from
theano.tensor.type
import
TensorType
__docformat__
=
"restructedtext en"
__docformat__
=
"restructedtext en"
...
@@ -169,7 +173,7 @@ class Scan(Op):
...
@@ -169,7 +173,7 @@ class Scan(Op):
if
self
.
as_while
:
if
self
.
as_while
:
self
.
output_types
=
self
.
output_types
[:
-
1
]
self
.
output_types
=
self
.
output_types
[:
-
1
]
mode_instance
=
compile
.
mode
.
get_mode
(
self
.
mode
)
mode_instance
=
get_mode
(
self
.
mode
)
# Clone mode_instance, altering "allow_gc" for the linker,
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
# and adding a message if we profile
if
self
.
name
:
if
self
.
name
:
...
@@ -202,11 +206,9 @@ class Scan(Op):
...
@@ -202,11 +206,9 @@ class Scan(Op):
self
.
_hash_inner_graph
=
self
.
info
[
"gpu_hash"
]
self
.
_hash_inner_graph
=
self
.
info
[
"gpu_hash"
]
else
:
else
:
# Do the missing inputs check here to have the error early.
# Do the missing inputs check here to have the error early.
for
var
in
theano
.
gof
.
graph
.
inputs
(
self
.
outputs
,
self
.
inputs
):
for
var
in
graph_
inputs
(
self
.
outputs
,
self
.
inputs
):
if
var
not
in
self
.
inputs
and
not
isinstance
(
var
,
theano
.
Constant
):
if
var
not
in
self
.
inputs
and
not
isinstance
(
var
,
theano
.
Constant
):
raise
theano
.
gof
.
MissingInputError
(
raise
MissingInputError
(
f
"ScanOp is missing an input: {repr(var)}"
)
f
"ScanOp is missing an input: {repr(var)}"
)
self
.
_cmodule_key
=
CLinker
()
.
cmodule_key_variables
(
self
.
_cmodule_key
=
CLinker
()
.
cmodule_key_variables
(
self
.
inputs
,
self
.
outputs
,
[]
self
.
inputs
,
self
.
outputs
,
[]
)
)
...
@@ -317,7 +319,7 @@ class Scan(Op):
...
@@ -317,7 +319,7 @@ class Scan(Op):
the inner function)
the inner function)
"""
"""
assert
np
.
all
(
isinstance
(
i
,
gof
.
Variable
)
for
i
in
inputs
)
assert
np
.
all
(
isinstance
(
i
,
Variable
)
for
i
in
inputs
)
# Check that the number of inputs to the Scan node corresponds to
# Check that the number of inputs to the Scan node corresponds to
# the number of inputs of the inner function of scan
# the number of inputs of the inner function of scan
n_outer_ins
=
len
(
inputs
)
-
len
(
self
.
outer_nitsot
(
inputs
))
-
1
n_outer_ins
=
len
(
inputs
)
-
len
(
self
.
outer_nitsot
(
inputs
))
-
1
...
@@ -2173,7 +2175,7 @@ class Scan(Op):
...
@@ -2173,7 +2175,7 @@ class Scan(Op):
wrt
=
[
wrt
=
[
x
x
for
x
in
theano
.
gof
.
graph
.
inputs
(
y_s
)
for
x
in
graph_
inputs
(
y_s
)
if
(
x
in
diff_inputs
)
if
(
x
in
diff_inputs
)
and
get_inp_idx
(
self_inputs
.
index
(
x
))
in
connected_inputs
and
get_inp_idx
(
self_inputs
.
index
(
x
))
in
connected_inputs
]
]
...
@@ -2188,7 +2190,7 @@ class Scan(Op):
...
@@ -2188,7 +2190,7 @@ class Scan(Op):
# to X.
# to X.
known_grads
=
OrderedDict
([(
k
.
copy
(),
v
)
for
(
k
,
v
)
in
known_grads
.
items
()])
known_grads
=
OrderedDict
([(
k
.
copy
(),
v
)
for
(
k
,
v
)
in
known_grads
.
items
()])
grads
=
grad
ient
.
grad
(
grads
=
grad
(
cost
=
None
,
cost
=
None
,
known_grads
=
known_grads
,
known_grads
=
known_grads
,
wrt
=
wrt
,
wrt
=
wrt
,
...
@@ -2238,7 +2240,7 @@ class Scan(Op):
...
@@ -2238,7 +2240,7 @@ class Scan(Op):
)
)
for
pos
,
inp
in
enumerate
(
states
):
for
pos
,
inp
in
enumerate
(
states
):
if
inp
in
theano
.
gof
.
graph
.
inputs
([
Xt
]):
if
inp
in
graph_
inputs
([
Xt
]):
# Get the index of the outer output that to which
# Get the index of the outer output that to which
# the state variable 'inp' corresponds.
# the state variable 'inp' corresponds.
outer_oidx
=
self
.
var_mappings
[
"outer_out_from_inner_inp"
][
outer_oidx
=
self
.
var_mappings
[
"outer_out_from_inner_inp"
][
...
@@ -2456,7 +2458,7 @@ class Scan(Op):
...
@@ -2456,7 +2458,7 @@ class Scan(Op):
disconnected
=
False
disconnected
=
False
for
_sh
in
self
.
inner_shared
(
self_inputs
):
for
_sh
in
self
.
inner_shared
(
self_inputs
):
if
_sh
in
g
of
.
graph
.
inputs
([
dC_dinps_t
[
ins_pos
]]):
if
_sh
in
g
raph_
inputs
([
dC_dinps_t
[
ins_pos
]]):
through_shared
=
True
through_shared
=
True
ins_pos
+=
1
ins_pos
+=
1
...
@@ -2511,7 +2513,7 @@ class Scan(Op):
...
@@ -2511,7 +2513,7 @@ class Scan(Op):
if
not
disconnected_dC_dinps_t
[
ins_pos
]:
if
not
disconnected_dC_dinps_t
[
ins_pos
]:
disconnected
=
False
disconnected
=
False
for
_sh
in
self
.
inner_shared
(
self_inputs
):
for
_sh
in
self
.
inner_shared
(
self_inputs
):
if
_sh
in
g
of
.
graph
.
inputs
([
dC_dinps_t
[
ins_pos
]]):
if
_sh
in
g
raph_
inputs
([
dC_dinps_t
[
ins_pos
]]):
through_shared
=
True
through_shared
=
True
n_mitmot_inps
+=
1
n_mitmot_inps
+=
1
...
@@ -2559,7 +2561,7 @@ class Scan(Op):
...
@@ -2559,7 +2561,7 @@ class Scan(Op):
inner_out_mitmot
.
append
(
dC_dinps_t
[
ins_pos
])
inner_out_mitmot
.
append
(
dC_dinps_t
[
ins_pos
])
for
_sh
in
self
.
inner_shared
(
self_inputs
):
for
_sh
in
self
.
inner_shared
(
self_inputs
):
if
_sh
in
g
of
.
graph
.
inputs
([
dC_dinps_t
[
ins_pos
]]):
if
_sh
in
g
raph_
inputs
([
dC_dinps_t
[
ins_pos
]]):
through_shared
=
True
through_shared
=
True
if
isinstance
(
dC_dinps_t
[
ins_pos
]
.
type
,
NullType
):
if
isinstance
(
dC_dinps_t
[
ins_pos
]
.
type
,
NullType
):
...
@@ -2583,7 +2585,7 @@ class Scan(Op):
...
@@ -2583,7 +2585,7 @@ class Scan(Op):
for
_p
,
vl
in
enumerate
(
inner_out_sitsot
):
for
_p
,
vl
in
enumerate
(
inner_out_sitsot
):
through_shared
=
False
through_shared
=
False
for
_sh
in
self
.
inner_shared
(
self_inputs
):
for
_sh
in
self
.
inner_shared
(
self_inputs
):
if
_sh
in
g
of
.
graph
.
inputs
([
vl
]):
if
_sh
in
g
raph_
inputs
([
vl
]):
through_shared
=
True
through_shared
=
True
if
isinstance
(
vl
.
type
,
NullType
):
if
isinstance
(
vl
.
type
,
NullType
):
type_outs
.
append
(
vl
.
type
.
why_null
)
type_outs
.
append
(
vl
.
type
.
why_null
)
...
@@ -2602,7 +2604,7 @@ class Scan(Op):
...
@@ -2602,7 +2604,7 @@ class Scan(Op):
for
_p
,
vl
in
enumerate
(
inner_out_nitsot
):
for
_p
,
vl
in
enumerate
(
inner_out_nitsot
):
through_shared
=
False
through_shared
=
False
for
_sh
in
self
.
inner_shared
(
self_inputs
):
for
_sh
in
self
.
inner_shared
(
self_inputs
):
if
_sh
in
g
of
.
graph
.
inputs
([
vl
]):
if
_sh
in
g
raph_
inputs
([
vl
]):
through_shared
=
True
through_shared
=
True
if
isinstance
(
vl
.
type
,
NullType
):
if
isinstance
(
vl
.
type
,
NullType
):
type_outs
.
append
(
vl
.
type
.
why_null
)
type_outs
.
append
(
vl
.
type
.
why_null
)
...
@@ -3043,10 +3045,10 @@ class Scan(Op):
...
@@ -3043,10 +3045,10 @@ class Scan(Op):
# Since Scan is an op that contains a Theano compiled function, it is
# Since Scan is an op that contains a Theano compiled function, it is
# useful to let DebugMode know about it.
# useful to let DebugMode know about it.
gof
.
ops_with_inner_function
[
Scan
]
=
"fn"
ops_with_inner_function
[
Scan
]
=
"fn"
@
theano.compile.profiling.
register_profiler_printer
@register_profiler_printer
def
profile_printer
(
def
profile_printer
(
message
,
compile_time
,
fct_call_time
,
apply_time
,
apply_cimpl
,
outputs_size
,
file
message
,
compile_time
,
fct_call_time
,
apply_time
,
apply_cimpl
,
outputs_size
,
file
):
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论