Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6c6d81c6
提交
6c6d81c6
authored
1月 03, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
1月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rename theano.gof.graph.inputs to graph_inputs
上级
27605fae
隐藏空白字符变更
内嵌
并排
正在显示
15 个修改的文件
包含
55 行增加
和
44 行删除
+55
-44
test_graph.py
tests/gof/test_graph.py
+7
-5
test_basic.py
tests/scan/test_basic.py
+6
-6
test_basic.py
tests/tensor/random/test_basic.py
+8
-9
test_gradient.py
tests/test_gradient.py
+1
-1
builders.py
theano/compile/builders.py
+3
-1
debugmode.py
theano/compile/debugmode.py
+1
-1
types.py
theano/compile/function/types.py
+12
-4
formatting.py
theano/d3viz/formatting.py
+1
-1
graph.py
theano/gof/graph.py
+1
-1
toolbox.py
theano/gof/toolbox.py
+4
-4
printing.py
theano/printing.py
+1
-1
basic.py
theano/scan/basic.py
+1
-1
op.py
theano/scan/op.py
+1
-1
opt.py
theano/scan/opt.py
+1
-1
utils.py
theano/scan/utils.py
+7
-7
没有找到文件。
tests/gof/test_graph.py
浏览文件 @
6c6d81c6
...
@@ -14,7 +14,7 @@ from theano.gof.graph import (
...
@@ -14,7 +14,7 @@ from theano.gof.graph import (
clone
,
clone
,
equal_computations
,
equal_computations
,
general_toposort
,
general_toposort
,
inputs
,
graph_
inputs
,
io_toposort
,
io_toposort
,
is_in_ancestors
,
is_in_ancestors
,
list_of_nodes
,
list_of_nodes
,
...
@@ -132,8 +132,10 @@ class TestClone(X):
...
@@ -132,8 +132,10 @@ class TestClone(X):
_
,
new
=
clone
([
r1
,
r2
,
r5
],
node
.
outputs
,
False
)
_
,
new
=
clone
([
r1
,
r2
,
r5
],
node
.
outputs
,
False
)
new_node
=
new
[
0
]
.
owner
new_node
=
new
[
0
]
.
owner
new_node
.
inputs
=
[
MyVariable
(
7
),
MyVariable
(
8
)]
new_node
.
inputs
=
[
MyVariable
(
7
),
MyVariable
(
8
)]
assert
self
.
str
(
inputs
(
new_node
.
outputs
),
new_node
.
outputs
)
==
[
"MyOp(R7, R8)"
]
assert
self
.
str
(
graph_inputs
(
new_node
.
outputs
),
new_node
.
outputs
)
==
[
assert
self
.
str
(
inputs
(
node
.
outputs
),
node
.
outputs
)
==
[
"MyOp(R7, R8)"
]
assert
self
.
str
(
graph_inputs
(
node
.
outputs
),
node
.
outputs
)
==
[
"MyOp(MyOp(R1, R2), R5)"
"MyOp(MyOp(R1, R2), R5)"
]
]
...
@@ -384,7 +386,7 @@ def test_ancestors():
...
@@ -384,7 +386,7 @@ def test_ancestors():
assert
res_list
==
[
o2
,
r3
,
o1
]
assert
res_list
==
[
o2
,
r3
,
o1
]
def
test_inputs
():
def
test_
graph_
inputs
():
r1
,
r2
,
r3
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
3
)
r1
,
r2
,
r3
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
3
)
o1
=
MyOp
(
r1
,
r2
)
o1
=
MyOp
(
r1
,
r2
)
...
@@ -392,7 +394,7 @@ def test_inputs():
...
@@ -392,7 +394,7 @@ def test_inputs():
o2
=
MyOp
(
r3
,
o1
)
o2
=
MyOp
(
r3
,
o1
)
o2
.
name
=
"o2"
o2
.
name
=
"o2"
res
=
inputs
([
o2
],
blockers
=
None
)
res
=
graph_
inputs
([
o2
],
blockers
=
None
)
res_list
=
list
(
res
)
res_list
=
list
(
res
)
assert
res_list
==
[
r3
,
r1
,
r2
]
assert
res_list
==
[
r3
,
r1
,
r2
]
...
...
tests/scan/test_basic.py
浏览文件 @
6c6d81c6
...
@@ -1966,7 +1966,7 @@ class TestScan:
...
@@ -1966,7 +1966,7 @@ class TestScan:
f1
=
z
*
(
x
+
y
)
**
2
+
5
f1
=
z
*
(
x
+
y
)
**
2
+
5
f2
=
theano
.
clone
(
f1
,
replace
=
None
,
strict
=
True
,
share_inputs
=
True
)
f2
=
theano
.
clone
(
f1
,
replace
=
None
,
strict
=
True
,
share_inputs
=
True
)
f2_inp
=
theano
.
gof
.
graph
.
inputs
([
f2
])
f2_inp
=
theano
.
gof
.
graph
.
graph_
inputs
([
f2
])
assert
z
in
f2_inp
assert
z
in
f2_inp
assert
x
in
f2_inp
assert
x
in
f2_inp
...
@@ -1982,7 +1982,7 @@ class TestScan:
...
@@ -1982,7 +1982,7 @@ class TestScan:
f1
=
z
*
(
x
+
y
)
**
2
+
5
f1
=
z
*
(
x
+
y
)
**
2
+
5
f2
=
theano
.
clone
(
f1
,
replace
=
None
,
strict
=
True
,
share_inputs
=
False
)
f2
=
theano
.
clone
(
f1
,
replace
=
None
,
strict
=
True
,
share_inputs
=
False
)
f2_inp
=
theano
.
gof
.
graph
.
inputs
([
f2
])
f2_inp
=
theano
.
gof
.
graph
.
graph_
inputs
([
f2
])
assert
z
not
in
f2_inp
assert
z
not
in
f2_inp
assert
x
not
in
f2_inp
assert
x
not
in
f2_inp
...
@@ -2001,7 +2001,7 @@ class TestScan:
...
@@ -2001,7 +2001,7 @@ class TestScan:
f2
=
theano
.
clone
(
f2
=
theano
.
clone
(
f1
,
replace
=
OrderedDict
([(
y
,
y2
)]),
strict
=
True
,
share_inputs
=
True
f1
,
replace
=
OrderedDict
([(
y
,
y2
)]),
strict
=
True
,
share_inputs
=
True
)
)
f2_inp
=
theano
.
gof
.
graph
.
inputs
([
f2
])
f2_inp
=
theano
.
gof
.
graph
.
graph_
inputs
([
f2
])
assert
z
in
f2_inp
assert
z
in
f2_inp
assert
x
in
f2_inp
assert
x
in
f2_inp
assert
y2
in
f2_inp
assert
y2
in
f2_inp
...
@@ -2019,7 +2019,7 @@ class TestScan:
...
@@ -2019,7 +2019,7 @@ class TestScan:
f2
=
theano
.
clone
(
f2
=
theano
.
clone
(
f1
,
replace
=
OrderedDict
([(
y
,
y2
)]),
strict
=
False
,
share_inputs
=
True
f1
,
replace
=
OrderedDict
([(
y
,
y2
)]),
strict
=
False
,
share_inputs
=
True
)
)
f2_inp
=
theano
.
gof
.
graph
.
inputs
([
f2
])
f2_inp
=
theano
.
gof
.
graph
.
graph_
inputs
([
f2
])
assert
z
in
f2_inp
assert
z
in
f2_inp
assert
x
in
f2_inp
assert
x
in
f2_inp
assert
y2
in
f2_inp
assert
y2
in
f2_inp
...
@@ -2035,7 +2035,7 @@ class TestScan:
...
@@ -2035,7 +2035,7 @@ class TestScan:
f1
=
z
*
(
x
+
y
)
**
2
+
5
f1
=
z
*
(
x
+
y
)
**
2
+
5
f2
=
theano
.
clone
(
f1
,
replace
=
[(
y
,
y2
)],
strict
=
True
,
share_inputs
=
False
)
f2
=
theano
.
clone
(
f1
,
replace
=
[(
y
,
y2
)],
strict
=
True
,
share_inputs
=
False
)
f2_inp
=
theano
.
gof
.
graph
.
inputs
([
f2
])
f2_inp
=
theano
.
gof
.
graph
.
graph_
inputs
([
f2
])
assert
z
not
in
f2_inp
assert
z
not
in
f2_inp
assert
x
not
in
f2_inp
assert
x
not
in
f2_inp
assert
y2
not
in
f2_inp
assert
y2
not
in
f2_inp
...
@@ -2051,7 +2051,7 @@ class TestScan:
...
@@ -2051,7 +2051,7 @@ class TestScan:
f1
=
z
*
(
x
+
y
)
**
2
+
5
f1
=
z
*
(
x
+
y
)
**
2
+
5
f2
=
theano
.
clone
(
f1
,
replace
=
[(
y
,
y2
)],
strict
=
False
,
share_inputs
=
False
)
f2
=
theano
.
clone
(
f1
,
replace
=
[(
y
,
y2
)],
strict
=
False
,
share_inputs
=
False
)
f2_inp
=
theano
.
gof
.
graph
.
inputs
([
f2
])
f2_inp
=
theano
.
gof
.
graph
.
graph_
inputs
([
f2
])
assert
z
not
in
f2_inp
assert
z
not
in
f2_inp
assert
x
not
in
f2_inp
assert
x
not
in
f2_inp
assert
y2
not
in
f2_inp
assert
y2
not
in
f2_inp
...
...
tests/tensor/random/test_basic.py
浏览文件 @
6c6d81c6
...
@@ -8,8 +8,7 @@ from pytest import fixture, importorskip, raises
...
@@ -8,8 +8,7 @@ from pytest import fixture, importorskip, raises
import
theano.tensor
as
tt
import
theano.tensor
as
tt
from
theano
import
change_flags
,
config
from
theano
import
change_flags
,
config
from
theano.gof.fg
import
FunctionGraph
from
theano.gof.fg
import
FunctionGraph
from
theano.gof.graph
import
Variable
from
theano.gof.graph
import
Variable
,
graph_inputs
from
theano.gof.graph
import
inputs
as
tt_inputs
from
theano.gof.op
import
get_test_value
from
theano.gof.op
import
get_test_value
from
theano.tensor.random.basic
import
(
from
theano.tensor.random.basic
import
(
bernoulli
,
bernoulli
,
...
@@ -145,7 +144,7 @@ def test_normal_ShapeFeature():
...
@@ -145,7 +144,7 @@ def test_normal_ShapeFeature():
d_rv
.
tag
.
test_value
d_rv
.
tag
.
test_value
fg
=
FunctionGraph
(
fg
=
FunctionGraph
(
[
i
for
i
in
tt
_inputs
([
d_rv
])
if
not
isinstance
(
i
,
tt
.
Constant
)],
[
i
for
i
in
graph
_inputs
([
d_rv
])
if
not
isinstance
(
i
,
tt
.
Constant
)],
[
d_rv
],
[
d_rv
],
clone
=
False
,
clone
=
False
,
features
=
[
tt
.
opt
.
ShapeFeature
()],
features
=
[
tt
.
opt
.
ShapeFeature
()],
...
@@ -296,7 +295,7 @@ def test_mvnormal_ShapeFeature():
...
@@ -296,7 +295,7 @@ def test_mvnormal_ShapeFeature():
d_rv
=
multivariate_normal
(
tt
.
ones
((
M_tt
,)),
tt
.
eye
(
M_tt
),
size
=
2
)
d_rv
=
multivariate_normal
(
tt
.
ones
((
M_tt
,)),
tt
.
eye
(
M_tt
),
size
=
2
)
fg
=
FunctionGraph
(
fg
=
FunctionGraph
(
[
i
for
i
in
tt
_inputs
([
d_rv
])
if
not
isinstance
(
i
,
tt
.
Constant
)],
[
i
for
i
in
graph
_inputs
([
d_rv
])
if
not
isinstance
(
i
,
tt
.
Constant
)],
[
d_rv
],
[
d_rv
],
clone
=
False
,
clone
=
False
,
features
=
[
tt
.
opt
.
ShapeFeature
()],
features
=
[
tt
.
opt
.
ShapeFeature
()],
...
@@ -305,7 +304,7 @@ def test_mvnormal_ShapeFeature():
...
@@ -305,7 +304,7 @@ def test_mvnormal_ShapeFeature():
s1
,
s2
=
fg
.
shape_feature
.
shape_of
[
d_rv
]
s1
,
s2
=
fg
.
shape_feature
.
shape_of
[
d_rv
]
assert
get_test_value
(
s1
)
==
2
assert
get_test_value
(
s1
)
==
2
assert
M_tt
in
tt
_inputs
([
s2
])
assert
M_tt
in
graph
_inputs
([
s2
])
# Test broadcasted shapes
# Test broadcasted shapes
mean
=
tt
.
tensor
(
config
.
floatX
,
[
True
,
False
])
mean
=
tt
.
tensor
(
config
.
floatX
,
[
True
,
False
])
...
@@ -319,7 +318,7 @@ def test_mvnormal_ShapeFeature():
...
@@ -319,7 +318,7 @@ def test_mvnormal_ShapeFeature():
d_rv
=
multivariate_normal
(
mean
,
cov
,
size
=
[
2
,
3
])
d_rv
=
multivariate_normal
(
mean
,
cov
,
size
=
[
2
,
3
])
fg
=
FunctionGraph
(
fg
=
FunctionGraph
(
[
i
for
i
in
tt
_inputs
([
d_rv
])
if
not
isinstance
(
i
,
tt
.
Constant
)],
[
i
for
i
in
graph
_inputs
([
d_rv
])
if
not
isinstance
(
i
,
tt
.
Constant
)],
[
d_rv
],
[
d_rv
],
clone
=
False
,
clone
=
False
,
features
=
[
tt
.
opt
.
ShapeFeature
()],
features
=
[
tt
.
opt
.
ShapeFeature
()],
...
@@ -392,7 +391,7 @@ def test_dirichlet_ShapeFeature():
...
@@ -392,7 +391,7 @@ def test_dirichlet_ShapeFeature():
d_rv
=
dirichlet
(
tt
.
ones
((
M_tt
,
N_tt
)),
name
=
"Gamma"
)
d_rv
=
dirichlet
(
tt
.
ones
((
M_tt
,
N_tt
)),
name
=
"Gamma"
)
fg
=
FunctionGraph
(
fg
=
FunctionGraph
(
[
i
for
i
in
tt
_inputs
([
d_rv
])
if
not
isinstance
(
i
,
tt
.
Constant
)],
[
i
for
i
in
graph
_inputs
([
d_rv
])
if
not
isinstance
(
i
,
tt
.
Constant
)],
[
d_rv
],
[
d_rv
],
clone
=
False
,
clone
=
False
,
features
=
[
tt
.
opt
.
ShapeFeature
()],
features
=
[
tt
.
opt
.
ShapeFeature
()],
...
@@ -400,8 +399,8 @@ def test_dirichlet_ShapeFeature():
...
@@ -400,8 +399,8 @@ def test_dirichlet_ShapeFeature():
s1
,
s2
=
fg
.
shape_feature
.
shape_of
[
d_rv
]
s1
,
s2
=
fg
.
shape_feature
.
shape_of
[
d_rv
]
assert
M_tt
in
tt
_inputs
([
s1
])
assert
M_tt
in
graph
_inputs
([
s1
])
assert
N_tt
in
tt
_inputs
([
s2
])
assert
N_tt
in
graph
_inputs
([
s2
])
def
test_poisson_samples
():
def
test_poisson_samples
():
...
...
tests/test_gradient.py
浏览文件 @
6c6d81c6
...
@@ -22,7 +22,7 @@ def grad_sources_inputs(sources, inputs):
...
@@ -22,7 +22,7 @@ def grad_sources_inputs(sources, inputs):
the new interface so the tests don't need to be rewritten.
the new interface so the tests don't need to be rewritten.
"""
"""
if
inputs
is
None
:
if
inputs
is
None
:
inputs
=
list
(
theano
.
gof
.
graph
.
inputs
([
source
[
0
]
for
source
in
sources
]))
inputs
=
list
(
theano
.
gof
.
graph
.
graph_
inputs
([
source
[
0
]
for
source
in
sources
]))
return
dict
(
return
dict
(
zip
(
zip
(
inputs
,
inputs
,
...
...
theano/compile/builders.py
浏览文件 @
6c6d81c6
...
@@ -339,7 +339,9 @@ class OpFromGraph(Op):
...
@@ -339,7 +339,9 @@ class OpFromGraph(Op):
# To correctly support shared variables the inner fct should
# To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient.
# not see them. Otherwise there is a problem with the gradient.
self
.
shared_inputs
=
[
self
.
shared_inputs
=
[
var
for
var
in
gof
.
graph
.
inputs
(
outputs
)
if
isinstance
(
var
,
SharedVariable
)
var
for
var
in
gof
.
graph
.
graph_inputs
(
outputs
)
if
isinstance
(
var
,
SharedVariable
)
]
]
shared_vars
=
[
var
.
type
()
for
var
in
self
.
shared_inputs
]
shared_vars
=
[
var
.
type
()
for
var
in
self
.
shared_inputs
]
...
...
theano/compile/debugmode.py
浏览文件 @
6c6d81c6
...
@@ -2416,7 +2416,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
...
@@ -2416,7 +2416,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
outputs
=
[
self
.
wrap_out
(
o
)
for
o
in
outputs
]
outputs
=
[
self
.
wrap_out
(
o
)
for
o
in
outputs
]
_inputs
=
list
(
_inputs
=
list
(
gof
.
graph
.
inputs
(
gof
.
graph
.
graph_
inputs
(
[
o
.
variable
for
o
in
outputs
]
[
o
.
variable
for
o
in
outputs
]
+
[
i
.
update
for
i
in
inputs
if
getattr
(
i
,
"update"
,
False
)]
+
[
i
.
update
for
i
in
inputs
if
getattr
(
i
,
"update"
,
False
)]
)
)
...
...
theano/compile/function/types.py
浏览文件 @
6c6d81c6
...
@@ -1206,7 +1206,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
...
@@ -1206,7 +1206,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
}
}
# We can't use fgraph.inputs as this don't include Constant Value.
# We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs
=
list
(
gof
.
graph
.
inputs
(
fgraph
.
outputs
))
all_graph_inputs
=
list
(
gof
.
graph
.
graph_
inputs
(
fgraph
.
outputs
))
has_destroyers_attr
=
hasattr
(
fgraph
,
"has_destroyers"
)
has_destroyers_attr
=
hasattr
(
fgraph
,
"has_destroyers"
)
for
i
in
range
(
len
(
fgraph
.
outputs
)):
for
i
in
range
(
len
(
fgraph
.
outputs
)):
...
@@ -1454,10 +1454,18 @@ class FunctionMaker:
...
@@ -1454,10 +1454,18 @@ class FunctionMaker:
t2
=
f2
.
outputs
[
i
]
t2
=
f2
.
outputs
[
i
]
givens
=
dict
(
givens
=
dict
(
zip
(
gof
.
graph
.
inputs
([
t1
]),
gof
.
graph
.
inputs
([
t2
]))
zip
(
gof
.
graph
.
graph_inputs
([
t1
]),
gof
.
graph
.
graph_inputs
([
t2
]),
)
)
)
temp
=
dict
(
zip
(
gof
.
graph
.
inputs
([
t1
]),
gof
.
graph
.
inputs
([
t2
])))
temp
=
dict
(
zip
(
gof
.
graph
.
graph_inputs
([
t1
]),
gof
.
graph
.
graph_inputs
([
t2
]),
)
)
# hack to remove inconstent entry in givens
# hack to remove inconstent entry in givens
# seems to work that but source of inconsistency
# seems to work that but source of inconsistency
...
@@ -1554,7 +1562,7 @@ class FunctionMaker:
...
@@ -1554,7 +1562,7 @@ class FunctionMaker:
inputs
=
[
self
.
wrap_in
(
i
)
for
i
in
inputs
]
inputs
=
[
self
.
wrap_in
(
i
)
for
i
in
inputs
]
outputs
=
[
self
.
wrap_out
(
o
)
for
o
in
outputs
]
outputs
=
[
self
.
wrap_out
(
o
)
for
o
in
outputs
]
_inputs
=
list
(
_inputs
=
list
(
gof
.
graph
.
inputs
(
gof
.
graph
.
graph_
inputs
(
[
o
.
variable
for
o
in
outputs
]
[
o
.
variable
for
o
in
outputs
]
+
[
i
.
update
for
i
in
inputs
if
getattr
(
i
,
"update"
,
False
)]
+
[
i
.
update
for
i
in
inputs
if
getattr
(
i
,
"update"
,
False
)]
)
)
...
...
theano/d3viz/formatting.py
浏览文件 @
6c6d81c6
...
@@ -11,7 +11,7 @@ import theano
...
@@ -11,7 +11,7 @@ import theano
from
theano.compile
import
Function
,
builders
from
theano.compile
import
Function
,
builders
from
theano.gof.fg
import
FunctionGraph
from
theano.gof.fg
import
FunctionGraph
from
theano.gof.graph
import
Apply
,
Constant
,
Variable
from
theano.gof.graph
import
Apply
,
Constant
,
Variable
from
theano.gof.graph
import
inputs
as
graph_inputs
from
theano.gof.graph
import
graph_
inputs
as
graph_inputs
from
theano.printing
import
pydot_imported
,
pydot_imported_msg
from
theano.printing
import
pydot_imported
,
pydot_imported_msg
...
...
theano/gof/graph.py
浏览文件 @
6c6d81c6
...
@@ -757,7 +757,7 @@ def ancestors(
...
@@ -757,7 +757,7 @@ def ancestors(
yield
from
walk
(
graphs
,
expand
,
False
)
yield
from
walk
(
graphs
,
expand
,
False
)
def
inputs
(
def
graph_
inputs
(
graphs
:
Iterable
[
Variable
],
blockers
:
Collection
[
Variable
]
=
None
graphs
:
Iterable
[
Variable
],
blockers
:
Collection
[
Variable
]
=
None
)
->
Generator
[
Variable
,
None
,
None
]:
)
->
Generator
[
Variable
,
None
,
None
]:
"""Return the inputs required to compute the given Variables.
"""Return the inputs required to compute the given Variables.
...
...
theano/gof/toolbox.py
浏览文件 @
6c6d81c6
...
@@ -10,7 +10,7 @@ import numpy as np
...
@@ -10,7 +10,7 @@ import numpy as np
import
theano
import
theano
from
theano.configdefaults
import
config
from
theano.configdefaults
import
config
from
theano.gof.graph
import
equal_computations
,
inputs
,
io_toposort
,
vars_between
from
theano.gof.graph
import
equal_computations
,
graph_
inputs
,
io_toposort
,
vars_between
class
AlreadyThere
(
Exception
):
class
AlreadyThere
(
Exception
):
...
@@ -807,10 +807,10 @@ def is_same_graph_with_merge(var1, var2, givens=None):
...
@@ -807,10 +807,10 @@ def is_same_graph_with_merge(var1, var2, givens=None):
vars
=
copied
[
0
:
2
]
vars
=
copied
[
0
:
2
]
givens
=
copied
[
2
]
givens
=
copied
[
2
]
# Create FunctionGraph.
# Create FunctionGraph.
graph_inputs
=
list
(
inputs
(
vars
))
inputs
=
list
(
graph_
inputs
(
vars
))
# The clone isn't needed as we did a deepcopy and we cloning will
# The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens.
# break the mapping in givens.
fgraph
=
theano
.
gof
.
fg
.
FunctionGraph
(
graph_
inputs
,
vars
,
clone
=
False
)
fgraph
=
theano
.
gof
.
fg
.
FunctionGraph
(
inputs
,
vars
,
clone
=
False
)
# Perform Variable substitution.
# Perform Variable substitution.
for
to_replace
,
replace_by
in
givens
.
items
():
for
to_replace
,
replace_by
in
givens
.
items
():
fgraph
.
replace
(
to_replace
,
replace_by
)
fgraph
.
replace
(
to_replace
,
replace_by
)
...
@@ -893,7 +893,7 @@ def is_same_graph(var1, var2, givens=None):
...
@@ -893,7 +893,7 @@ def is_same_graph(var1, var2, givens=None):
in_xs
=
[]
in_xs
=
[]
in_ys
=
[]
in_ys
=
[]
# Compute the sets of all variables found in each computational graph.
# Compute the sets of all variables found in each computational graph.
inputs_var
=
list
(
map
(
inputs
,
([
var1
],
[
var2
])))
inputs_var
=
list
(
map
(
graph_
inputs
,
([
var1
],
[
var2
])))
all_vars
=
[
all_vars
=
[
set
(
vars_between
(
v_i
,
v_o
))
set
(
vars_between
(
v_i
,
v_o
))
for
v_i
,
v_o
in
((
inputs_var
[
0
],
[
var1
]),
(
inputs_var
[
1
],
[
var2
]))
for
v_i
,
v_o
in
((
inputs_var
[
0
],
[
var1
]),
(
inputs_var
[
1
],
[
var2
]))
...
...
theano/printing.py
浏览文件 @
6c6d81c6
...
@@ -820,7 +820,7 @@ def pydotprint(
...
@@ -820,7 +820,7 @@ def pydotprint(
fct
=
fct
.
outputs
fct
=
fct
.
outputs
assert
isinstance
(
fct
,
(
list
,
tuple
))
assert
isinstance
(
fct
,
(
list
,
tuple
))
assert
all
(
isinstance
(
v
,
gof
.
Variable
)
for
v
in
fct
)
assert
all
(
isinstance
(
v
,
gof
.
Variable
)
for
v
in
fct
)
fct
=
gof
.
FunctionGraph
(
inputs
=
list
(
gof
.
graph
.
inputs
(
fct
)),
outputs
=
fct
)
fct
=
gof
.
FunctionGraph
(
inputs
=
list
(
gof
.
graph
.
graph_
inputs
(
fct
)),
outputs
=
fct
)
profile
=
None
profile
=
None
outputs
=
fct
.
outputs
outputs
=
fct
.
outputs
topo
=
fct
.
toposort
()
topo
=
fct
.
toposort
()
...
...
theano/scan/basic.py
浏览文件 @
6c6d81c6
...
@@ -803,7 +803,7 @@ def scan(
...
@@ -803,7 +803,7 @@ def scan(
and
not
isinstance
(
x
,
SharedVariable
)
and
not
isinstance
(
x
,
SharedVariable
)
and
not
isinstance
(
x
,
gof
.
Constant
)
and
not
isinstance
(
x
,
gof
.
Constant
)
),
),
gof
.
graph
.
inputs
(
fake_outputs
),
gof
.
graph
.
graph_
inputs
(
fake_outputs
),
)
)
extra_inputs
=
[
x
for
x
in
all_inputs
if
x
not
in
args
+
fake_nonseqs
]
extra_inputs
=
[
x
for
x
in
all_inputs
if
x
not
in
args
+
fake_nonseqs
]
non_seqs
+=
extra_inputs
non_seqs
+=
extra_inputs
...
...
theano/scan/op.py
浏览文件 @
6c6d81c6
...
@@ -62,7 +62,7 @@ from theano.compile.profiling import ScanProfileStats, register_profiler_printer
...
@@ -62,7 +62,7 @@ from theano.compile.profiling import ScanProfileStats, register_profiler_printer
from
theano.configdefaults
import
config
from
theano.configdefaults
import
config
from
theano.gof.fg
import
MissingInputError
from
theano.gof.fg
import
MissingInputError
from
theano.gof.graph
import
Apply
,
Variable
,
equal_computations
from
theano.gof.graph
import
Apply
,
Variable
,
equal_computations
from
theano.gof.graph
import
inputs
as
graph_inputs
from
theano.gof.graph
import
graph_
inputs
as
graph_inputs
from
theano.gof.graph
import
io_connection_pattern
from
theano.gof.graph
import
io_connection_pattern
from
theano.gof.op
import
Op
,
ops_with_inner_function
from
theano.gof.op
import
Op
,
ops_with_inner_function
from
theano.gof.toolbox
import
NoOutputFromInplace
from
theano.gof.toolbox
import
NoOutputFromInplace
...
...
theano/scan/opt.py
浏览文件 @
6c6d81c6
...
@@ -150,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
...
@@ -150,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
# Same for the outer graph, initialized w/ number of steps
# Same for the outer graph, initialized w/ number of steps
nw_outer
=
[
node
.
inputs
[
0
]]
nw_outer
=
[
node
.
inputs
[
0
]]
all_ins
=
list
(
gof
.
graph
.
inputs
(
op_outs
))
all_ins
=
list
(
gof
.
graph
.
graph_
inputs
(
op_outs
))
for
idx
in
range
(
op
.
n_seqs
):
for
idx
in
range
(
op
.
n_seqs
):
node_inp
=
node
.
inputs
[
idx
+
1
]
node_inp
=
node
.
inputs
[
idx
+
1
]
if
(
if
(
...
...
theano/scan/utils.py
浏览文件 @
6c6d81c6
...
@@ -268,7 +268,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
...
@@ -268,7 +268,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
return
new_graph
return
new_graph
graphs
=
list
(
graphs
)
graphs
=
list
(
graphs
)
inputs_
=
list
(
set
(
list
(
gof
.
graph
.
inputs
(
graphs
))
+
list
(
additional_inputs
)))
inputs_
=
list
(
set
(
list
(
gof
.
graph
.
graph_
inputs
(
graphs
))
+
list
(
additional_inputs
)))
# perform any desired replacement of input variables. these
# perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are
# aren't replaced by the local optimizer approach because they are
...
@@ -280,7 +280,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
...
@@ -280,7 +280,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
if
new_input
is
not
input_
if
new_input
is
not
input_
]
]
graphs
=
clone
(
graphs
,
share_inputs
=
True
,
replace
=
replacements
)
graphs
=
clone
(
graphs
,
share_inputs
=
True
,
replace
=
replacements
)
inputs_
=
list
(
set
(
list
(
gof
.
graph
.
inputs
(
graphs
))
+
list
(
additional_inputs
)))
inputs_
=
list
(
set
(
list
(
gof
.
graph
.
graph_
inputs
(
graphs
))
+
list
(
additional_inputs
)))
fg
=
gof
.
fg
.
FunctionGraph
(
inputs_
,
graphs
,
clone
=
False
)
fg
=
gof
.
fg
.
FunctionGraph
(
inputs_
,
graphs
,
clone
=
False
)
...
@@ -330,7 +330,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
...
@@ -330,7 +330,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
replacements
=
[
wrapped_replacer
(
o
)
for
o
in
node
.
outputs
]
replacements
=
[
wrapped_replacer
(
o
)
for
o
in
node
.
outputs
]
# Add inputs to replacement graphs as inputs to this `fgraph`
# Add inputs to replacement graphs as inputs to this `fgraph`
for
i
in
gof
.
graph
.
inputs
(
replacements
):
for
i
in
gof
.
graph
.
graph_
inputs
(
replacements
):
fgraph
.
add_input
(
i
)
fgraph
.
add_input
(
i
)
return
replacements
return
replacements
...
@@ -370,7 +370,7 @@ def _map_variables_inner(
...
@@ -370,7 +370,7 @@ def _map_variables_inner(
other_inputs
=
[]
other_inputs
=
[]
constants
=
[]
constants
=
[]
for
input_
in
gof
.
graph
.
inputs
([
new_graph
]):
for
input_
in
gof
.
graph
.
graph_
inputs
([
new_graph
]):
if
isinstance
(
input_
,
gof
.
Variable
):
if
isinstance
(
input_
,
gof
.
Variable
):
if
isinstance
(
input_
,
gof
.
Constant
):
if
isinstance
(
input_
,
gof
.
Constant
):
constants
.
append
(
input_
)
constants
.
append
(
input_
)
...
@@ -714,7 +714,7 @@ def scan_can_remove_outs(op, out_idxs):
...
@@ -714,7 +714,7 @@ def scan_can_remove_outs(op, out_idxs):
"""
"""
non_removable
=
[
o
for
i
,
o
in
enumerate
(
op
.
outputs
)
if
i
not
in
out_idxs
]
non_removable
=
[
o
for
i
,
o
in
enumerate
(
op
.
outputs
)
if
i
not
in
out_idxs
]
required_inputs
=
list
(
gof
.
graph
.
inputs
(
non_removable
))
required_inputs
=
list
(
gof
.
graph
.
graph_
inputs
(
non_removable
))
out_ins
=
[]
out_ins
=
[]
offset
=
op
.
n_seqs
offset
=
op
.
n_seqs
...
@@ -734,7 +734,7 @@ def scan_can_remove_outs(op, out_idxs):
...
@@ -734,7 +734,7 @@ def scan_can_remove_outs(op, out_idxs):
if
out_idxs_mask
[
pos
]
and
any
([
x
in
required_inputs
for
x
in
out_ins
[
idx
]]):
if
out_idxs_mask
[
pos
]
and
any
([
x
in
required_inputs
for
x
in
out_ins
[
idx
]]):
# This output is required ..
# This output is required ..
out_idxs_mask
[
pos
]
=
0
out_idxs_mask
[
pos
]
=
0
required_inputs
+=
list
(
gof
.
graph
.
inputs
([
op
.
outputs
[
idx
]]))
required_inputs
+=
list
(
gof
.
graph
.
graph_
inputs
([
op
.
outputs
[
idx
]]))
added
=
True
added
=
True
required_outs
=
[
x
for
i
,
x
in
enumerate
(
out_idxs
)
if
out_idxs_mask
[
i
]
==
0
]
required_outs
=
[
x
for
i
,
x
in
enumerate
(
out_idxs
)
if
out_idxs_mask
[
i
]
==
0
]
...
@@ -900,7 +900,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
...
@@ -900,7 +900,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
givens
=
OrderedDict
()
givens
=
OrderedDict
()
for
nw_x
,
x
in
zip
(
nw_inputs
,
inputs
):
for
nw_x
,
x
in
zip
(
nw_inputs
,
inputs
):
givens
[
x
]
=
nw_x
givens
[
x
]
=
nw_x
allinputs
=
list
(
theano
.
gof
.
graph
.
inputs
(
outputs
))
allinputs
=
list
(
theano
.
gof
.
graph
.
graph_
inputs
(
outputs
))
for
inp
in
allinputs
:
for
inp
in
allinputs
:
if
isinstance
(
inp
,
theano
.
Constant
):
if
isinstance
(
inp
,
theano
.
Constant
):
givens
[
inp
]
=
inp
.
clone
()
givens
[
inp
]
=
inp
.
clone
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论