Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9abca4b8
提交
9abca4b8
authored
1月 02, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
1月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Convert theano.gof.graph graph walking functions to generators
上级
88dfd88f
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
13 个修改的文件
包含
215 行增加
和
74 行删除
+215
-74
test_graph.py
tests/gof/test_graph.py
+176
-42
test_elemwise.py
tests/tensor/test_elemwise.py
+3
-2
test_gradient.py
tests/test_gradient.py
+1
-1
debugmode.py
theano/compile/debugmode.py
+5
-3
types.py
theano/compile/function/types.py
+14
-10
fg.py
theano/gof/fg.py
+1
-1
graph.py
theano/gof/graph.py
+0
-0
opt.py
theano/gof/opt.py
+6
-6
toolbox.py
theano/gof/toolbox.py
+1
-1
basic.py
theano/link/c/basic.py
+1
-1
printing.py
theano/printing.py
+1
-1
opt.py
theano/scan/opt.py
+1
-1
utils.py
theano/scan/utils.py
+5
-5
没有找到文件。
tests/gof/test_graph.py
浏览文件 @
9abca4b8
...
...
@@ -8,22 +8,24 @@ from theano import shared, tensor
from
theano.gof.graph
import
(
Apply
,
Variable
,
ancestors
,
as_string
,
clone
,
equal_computations
,
general_toposort
,
inputs
,
io_toposort
,
is_in_ancestors
,
list_of_nodes
,
ops
,
orphans
,
stack_search
,
variables
,
)
from
theano.gof.op
import
Op
from
theano.gof.type
import
Type
def
as_variable
(
x
):
assert
isinstance
(
x
,
Variable
)
return
x
class
MyType
(
Type
):
def
__init__
(
self
,
thingy
):
self
.
thingy
=
thingy
...
...
@@ -47,32 +49,16 @@ class MyOp(Op):
__props__
=
()
def
make_node
(
self
,
*
inputs
):
inputs
=
list
(
map
(
as_variable
,
inputs
))
for
input
in
inputs
:
if
not
isinstance
(
input
.
type
,
MyType
):
print
(
input
,
input
.
type
,
type
(
input
),
type
(
input
.
type
))
raise
Exception
(
"Error 1"
)
outputs
=
[
MyVariable
(
sum
([
input
.
type
.
thingy
for
input
in
inputs
]))]
return
Apply
(
self
,
inputs
,
outputs
)
assert
isinstance
(
input
,
Variable
)
assert
isinstance
(
input
.
type
,
MyType
)
outputs
=
[
MyVariable
(
sum
(
input
.
type
.
thingy
for
input
in
inputs
))]
return
Apply
(
self
,
list
(
inputs
),
outputs
)
MyOp
=
MyOp
()
class
TestInputs
:
def
test_inputs
(
self
):
r1
,
r2
=
MyVariable
(
1
),
MyVariable
(
2
)
node
=
MyOp
.
make_node
(
r1
,
r2
)
assert
inputs
(
node
.
outputs
)
==
[
r1
,
r2
]
def
test_inputs_deep
(
self
):
r1
,
r2
,
r5
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
5
)
node
=
MyOp
.
make_node
(
r1
,
r2
)
node2
=
MyOp
.
make_node
(
node
.
outputs
[
0
],
r5
)
i
=
inputs
(
node2
.
outputs
)
assert
i
==
[
r1
,
r2
,
r5
],
i
class
X
:
def
leaf_formatter
(
self
,
leaf
):
return
str
(
leaf
.
type
)
...
...
@@ -145,7 +131,7 @@ class TestClone(X):
node
=
MyOp
.
make_node
(
MyOp
.
make_node
(
r1
,
r2
)
.
outputs
[
0
],
r5
)
_
,
new
=
clone
([
r1
,
r2
,
r5
],
node
.
outputs
,
False
)
new_node
=
new
[
0
]
.
owner
new_node
.
inputs
=
MyVariable
(
7
),
MyVariable
(
8
)
new_node
.
inputs
=
[
MyVariable
(
7
),
MyVariable
(
8
)]
assert
self
.
str
(
inputs
(
new_node
.
outputs
),
new_node
.
outputs
)
==
[
"MyOp(R7, R8)"
]
assert
self
.
str
(
inputs
(
node
.
outputs
),
node
.
outputs
)
==
[
"MyOp(MyOp(R1, R2), R5)"
...
...
@@ -156,7 +142,7 @@ class TestClone(X):
node
=
MyOp
.
make_node
(
MyOp
.
make_node
(
r1
,
r2
)
.
outputs
[
0
],
r5
)
_
,
new
=
clone
([
r1
,
r2
,
r5
],
node
.
outputs
,
False
)
new_node
=
new
[
0
]
.
owner
new_node
.
inputs
=
MyVariable
(
7
),
MyVariable
(
8
)
new_node
.
inputs
=
[
MyVariable
(
7
),
MyVariable
(
8
)]
c1
=
tensor
.
constant
(
1.5
)
i
,
o
=
clone
([
c1
],
[
c1
])
...
...
@@ -181,19 +167,36 @@ def prenode(obj):
class
TestToposort
:
def
test_
0
(
self
):
def
test_
simple
(
self
):
# Test a simple graph
r1
,
r2
,
r5
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
5
)
o
=
MyOp
.
make_node
(
r1
,
r2
)
o2
=
MyOp
.
make_node
(
o
.
outputs
[
0
],
r5
)
all
=
general_toposort
(
o2
.
outputs
,
prenode
)
assert
all
==
[
r5
,
r2
,
r1
,
o
,
o
.
outputs
[
0
],
o2
,
o2
.
outputs
[
0
]]
all
=
io_toposort
([
r5
],
o2
.
outputs
)
assert
all
==
[
o
,
o2
]
def
test_1
(
self
):
o
=
MyOp
(
r1
,
r2
)
o
.
name
=
"o1"
o2
=
MyOp
(
o
,
r5
)
o2
.
name
=
"o2"
clients
=
{}
res
=
general_toposort
([
o2
],
prenode
,
clients
=
clients
)
assert
clients
==
{
o2
.
owner
:
[
o2
],
o
:
[
o2
.
owner
],
r5
:
[
o2
.
owner
],
o
.
owner
:
[
o
],
r1
:
[
o
.
owner
],
r2
:
[
o
.
owner
],
}
assert
res
==
[
r5
,
r2
,
r1
,
o
.
owner
,
o
,
o2
.
owner
,
o2
]
with
pytest
.
raises
(
ValueError
):
general_toposort
(
[
o2
],
prenode
,
compute_deps_cache
=
lambda
x
:
None
,
deps_cache
=
None
)
res
=
io_toposort
([
r5
],
[
o2
])
assert
res
==
[
o
.
owner
,
o2
.
owner
]
def
test_double_dependencies
(
self
):
# Test a graph with double dependencies
r1
,
r5
=
MyVariable
(
1
),
MyVariable
(
5
)
o
=
MyOp
.
make_node
(
r1
,
r1
)
...
...
@@ -201,7 +204,7 @@ class TestToposort:
all
=
general_toposort
(
o2
.
outputs
,
prenode
)
assert
all
==
[
r5
,
r1
,
o
,
o
.
outputs
[
0
],
o2
,
o2
.
outputs
[
0
]]
def
test_
2
(
self
):
def
test_
inputs_owners
(
self
):
# Test a graph where the inputs have owners
r1
,
r5
=
MyVariable
(
1
),
MyVariable
(
5
)
o
=
MyOp
.
make_node
(
r1
,
r1
)
...
...
@@ -214,7 +217,7 @@ class TestToposort:
all
=
io_toposort
([
r2b
],
o2
.
outputs
)
assert
all
==
[
o2
]
def
test_
3
(
self
):
def
test_
not_connected
(
self
):
# Test a graph which is not connected
r1
,
r2
,
r3
,
r4
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
3
),
MyVariable
(
4
)
o0
=
MyOp
.
make_node
(
r1
,
r2
)
...
...
@@ -222,7 +225,7 @@ class TestToposort:
all
=
io_toposort
([
r1
,
r2
,
r3
,
r4
],
o0
.
outputs
+
o1
.
outputs
)
assert
all
==
[
o1
,
o0
]
or
all
==
[
o0
,
o1
]
def
test_
4
(
self
):
def
test_
io_chain
(
self
):
# Test inputs and outputs mixed together in a chain graph
r1
,
r2
=
MyVariable
(
1
),
MyVariable
(
2
)
o0
=
MyOp
.
make_node
(
r1
,
r2
)
...
...
@@ -230,7 +233,7 @@ class TestToposort:
all
=
io_toposort
([
r1
,
o0
.
outputs
[
0
]],
[
o0
.
outputs
[
0
],
o1
.
outputs
[
0
]])
assert
all
==
[
o1
]
def
test_
5
(
self
):
def
test_
outputs_clients
(
self
):
# Test when outputs have clients
r1
,
r2
,
r4
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
4
)
o0
=
MyOp
.
make_node
(
r1
,
r2
)
...
...
@@ -326,3 +329,134 @@ def test_equal_computations():
max_argmax1
=
tensor
.
max_and_argmax
(
m
)
max_argmax2
=
tensor
.
max_and_argmax
(
m
)
assert
equal_computations
(
max_argmax1
,
max_argmax2
)
def
test_stack_search
():
r1
,
r2
,
r3
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
3
)
o1
=
MyOp
(
r1
,
r2
)
o1
.
name
=
"o1"
o2
=
MyOp
(
r3
,
o1
)
o2
.
name
=
"o2"
def
expand
(
r
):
if
r
.
owner
:
return
r
.
owner
.
inputs
res
=
stack_search
([
o2
],
expand
,
bfs
=
True
,
return_children
=
False
)
res_list
=
list
(
res
)
assert
res_list
==
[
o2
,
r3
,
o1
,
r1
,
r2
]
res
=
stack_search
([
o2
],
expand
,
bfs
=
False
,
return_children
=
False
)
res_list
=
list
(
res
)
assert
res_list
==
[
o2
,
o1
,
r2
,
r1
,
r3
]
res
=
stack_search
([
o2
],
expand
,
bfs
=
True
,
return_children
=
True
)
res_list
=
list
(
res
)
assert
res_list
==
[
(
o2
,
[
r3
,
o1
]),
(
r3
,
None
),
(
o1
,
[
r1
,
r2
]),
(
r1
,
None
),
(
r2
,
None
),
]
def
test_ancestors
():
r1
,
r2
,
r3
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
3
)
o1
=
MyOp
(
r1
,
r2
)
o1
.
name
=
"o1"
o2
=
MyOp
(
r3
,
o1
)
o2
.
name
=
"o2"
res
=
ancestors
([
o2
],
blockers
=
None
)
res_list
=
list
(
res
)
assert
res_list
==
[
o2
,
r3
,
o1
,
r1
,
r2
]
res
=
ancestors
([
o2
],
blockers
=
None
)
assert
r3
in
res
res_list
=
list
(
res
)
assert
res_list
==
[
o1
,
r1
,
r2
]
res
=
ancestors
([
o2
],
blockers
=
[
o1
])
res_list
=
list
(
res
)
assert
res_list
==
[
o2
,
r3
,
o1
]
def
test_inputs
():
r1
,
r2
,
r3
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
3
)
o1
=
MyOp
(
r1
,
r2
)
o1
.
name
=
"o1"
o2
=
MyOp
(
r3
,
o1
)
o2
.
name
=
"o2"
res
=
inputs
([
o2
],
blockers
=
None
)
res_list
=
list
(
res
)
assert
res_list
==
[
r3
,
r1
,
r2
]
def
test_variables_and_orphans
():
r1
,
r2
,
r3
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
3
)
o1
=
MyOp
(
r1
,
r2
)
o1
.
name
=
"o1"
o2
=
MyOp
(
r3
,
o1
)
o2
.
name
=
"o2"
vars_res
=
variables
([
r1
,
r2
],
[
o2
])
orphans_res
=
orphans
([
r1
,
r2
],
[
o2
])
vars_res_list
=
list
(
vars_res
)
orphans_res_list
=
list
(
orphans_res
)
assert
vars_res_list
==
[
o2
,
o1
,
r3
,
r2
,
r1
]
assert
orphans_res_list
==
[
r3
]
def
test_ops
():
r1
,
r2
,
r3
,
r4
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
3
),
MyVariable
(
4
)
o1
=
MyOp
(
r1
,
r2
)
o1
.
name
=
"o1"
o2
=
MyOp
(
r3
,
r4
)
o2
.
name
=
"o2"
o3
=
MyOp
(
r3
,
o1
,
o2
)
o3
.
name
=
"o3"
res
=
ops
([
r1
,
r2
],
[
o3
])
res_list
=
list
(
res
)
assert
res_list
==
[
o3
.
owner
,
o2
.
owner
,
o1
.
owner
]
def
test_list_of_nodes
():
r1
,
r2
,
r3
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
3
)
o1
=
MyOp
(
r1
,
r2
)
o1
.
name
=
"o1"
o2
=
MyOp
(
r3
,
o1
)
o2
.
name
=
"o2"
res
=
list_of_nodes
([
r1
,
r2
],
[
o2
])
assert
res
==
[
o2
.
owner
,
o1
.
owner
]
def
test_is_in_ancestors
():
r1
,
r2
,
r3
=
MyVariable
(
1
),
MyVariable
(
2
),
MyVariable
(
3
)
o1
=
MyOp
(
r1
,
r2
)
o1
.
name
=
"o1"
o2
=
MyOp
(
r3
,
o1
)
o2
.
name
=
"o2"
assert
is_in_ancestors
(
o2
.
owner
,
o1
.
owner
)
@pytest.mark.xfail
(
reason
=
"Not implemented"
)
def
test_io_connection_pattern
():
raise
AssertionError
()
@pytest.mark.xfail
(
reason
=
"Not implemented"
)
def
test_view_roots
():
raise
AssertionError
()
tests/tensor/test_elemwise.py
浏览文件 @
9abca4b8
...
...
@@ -1336,7 +1336,6 @@ def test_grad_useless_sum():
x
=
TensorType
(
theano
.
config
.
floatX
,
(
True
,))(
"x"
)
l
=
tt
.
log
(
1.0
-
sigmoid
(
x
))[
0
]
g
=
tt
.
grad
(
l
,
x
)
nodes
=
theano
.
gof
.
graph
.
ops
([
x
],
[
g
])
f
=
theano
.
function
([
x
],
g
,
mode
=
mode
)
test_values
=
[
-
100
,
-
1
,
0
,
1
,
100
]
...
...
@@ -1349,7 +1348,9 @@ def test_grad_useless_sum():
finally
:
TensorType
.
values_eq_approx
=
old_values_eq_approx
assert
not
any
([
isinstance
(
node
.
op
,
Sum
)
for
node
in
nodes
])
assert
not
any
(
[
isinstance
(
node
.
op
,
Sum
)
for
node
in
theano
.
gof
.
graph
.
ops
([
x
],
[
g
])]
)
assert
np
.
allclose
(
outputs
,
[[
-
3.72007598e-44
],
[
-
0.26894142
],
[
-
0.5
],
[
-
0.73105858
],
[
-
1.0
]]
)
...
...
tests/test_gradient.py
浏览文件 @
9abca4b8
...
...
@@ -22,7 +22,7 @@ def grad_sources_inputs(sources, inputs):
the new interface so the tests don't need to be rewritten.
"""
if
inputs
is
None
:
inputs
=
theano
.
gof
.
graph
.
inputs
([
source
[
0
]
for
source
in
sources
]
)
inputs
=
list
(
theano
.
gof
.
graph
.
inputs
([
source
[
0
]
for
source
in
sources
])
)
return
dict
(
zip
(
inputs
,
...
...
theano/compile/debugmode.py
浏览文件 @
9abca4b8
...
...
@@ -2415,9 +2415,11 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
inputs
=
[
self
.
wrap_in
(
i
)
for
i
in
inputs
]
outputs
=
[
self
.
wrap_out
(
o
)
for
o
in
outputs
]
_inputs
=
gof
.
graph
.
inputs
(
[
o
.
variable
for
o
in
outputs
]
+
[
i
.
update
for
i
in
inputs
if
getattr
(
i
,
"update"
,
False
)]
_inputs
=
list
(
gof
.
graph
.
inputs
(
[
o
.
variable
for
o
in
outputs
]
+
[
i
.
update
for
i
in
inputs
if
getattr
(
i
,
"update"
,
False
)]
)
)
# Check if some input variables are unused
...
...
theano/compile/function/types.py
浏览文件 @
9abca4b8
...
...
@@ -1206,7 +1206,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
}
# We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs
=
gof
.
graph
.
inputs
(
fgraph
.
outputs
)
all_graph_inputs
=
list
(
gof
.
graph
.
inputs
(
fgraph
.
outputs
)
)
has_destroyers_attr
=
hasattr
(
fgraph
,
"has_destroyers"
)
for
i
in
range
(
len
(
fgraph
.
outputs
)):
...
...
@@ -1553,9 +1553,11 @@ class FunctionMaker:
# Wrap them in In or Out instances if needed.
inputs
=
[
self
.
wrap_in
(
i
)
for
i
in
inputs
]
outputs
=
[
self
.
wrap_out
(
o
)
for
o
in
outputs
]
_inputs
=
gof
.
graph
.
inputs
(
[
o
.
variable
for
o
in
outputs
]
+
[
i
.
update
for
i
in
inputs
if
getattr
(
i
,
"update"
,
False
)]
_inputs
=
list
(
gof
.
graph
.
inputs
(
[
o
.
variable
for
o
in
outputs
]
+
[
i
.
update
for
i
in
inputs
if
getattr
(
i
,
"update"
,
False
)]
)
)
# Check if some input variables are unused
...
...
@@ -1697,12 +1699,14 @@ class FunctionMaker:
# There should be two categories of variables in inputs:
# - variables that have to be provided (used_inputs)
# - shared variables that will be updated
used_inputs
=
gof
.
graph
.
ancestors
(
(
[
o
.
variable
for
o
in
outputs
]
+
[
i
.
update
for
i
in
inputs
if
getattr
(
i
,
"update"
,
False
)]
),
blockers
=
[
i
.
variable
for
i
in
inputs
],
used_inputs
=
list
(
gof
.
graph
.
ancestors
(
(
[
o
.
variable
for
o
in
outputs
]
+
[
i
.
update
for
i
in
inputs
if
getattr
(
i
,
"update"
,
False
)]
),
blockers
=
[
i
.
variable
for
i
in
inputs
],
)
)
msg
=
(
...
...
theano/gof/fg.py
浏览文件 @
9abca4b8
...
...
@@ -710,7 +710,7 @@ class FunctionGraph(utils.MetaObject):
Call this for a diagnosis if things go awry.
"""
nodes
=
ops_between
(
self
.
inputs
,
self
.
outputs
)
nodes
=
set
(
ops_between
(
self
.
inputs
,
self
.
outputs
)
)
if
self
.
apply_nodes
!=
nodes
:
missing
=
nodes
.
difference
(
self
.
apply_nodes
)
excess
=
self
.
apply_nodes
.
difference
(
nodes
)
...
...
theano/gof/graph.py
浏览文件 @
9abca4b8
差异被折叠。
点击展开。
theano/gof/opt.py
浏览文件 @
9abca4b8
...
...
@@ -35,10 +35,6 @@ _logger = logging.getLogger("theano.gof.opt")
_optimizer_idx
=
[
0
]
def
_list_of_nodes
(
fgraph
):
return
list
(
graph
.
io_toposort
(
fgraph
.
inputs
,
fgraph
.
outputs
))
class
LocalMetaOptimizerSkipAssertionError
(
AssertionError
):
"""This is an AssertionError, but instead of having the
LocalMetaOptimizer print the error, it just skip that
...
...
@@ -1344,7 +1340,9 @@ class LocalOptGroup(LocalOptimizer):
else
:
# It must be a dict
new_vars
=
list
(
new_repl
.
values
())
if
self
.
profile
:
self
.
node_created
[
opt
]
+=
len
(
graph
.
ops
(
fgraph
.
variables
,
new_vars
))
self
.
node_created
[
opt
]
+=
len
(
list
(
graph
.
ops
(
fgraph
.
variables
,
new_vars
))
)
self
.
applied_true
[
opt
]
+=
1
break
# break from the for loop over optimization.
if
not
new_repl
:
# No optimization applied in the last iteration
...
...
@@ -1454,7 +1452,9 @@ class GraphToGPULocalOptGroup(LocalOptGroup):
if
not
new_repl
:
continue
if
self
.
profile
:
self
.
node_created
[
opt
]
+=
len
(
graph
.
ops
(
fgraph
.
variables
,
new_repl
))
self
.
node_created
[
opt
]
+=
len
(
list
(
graph
.
ops
(
fgraph
.
variables
,
new_repl
))
)
self
.
applied_true
[
opt
]
+=
1
return
new_repl
...
...
theano/gof/toolbox.py
浏览文件 @
9abca4b8
...
...
@@ -807,7 +807,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
vars
=
copied
[
0
:
2
]
givens
=
copied
[
2
]
# Create FunctionGraph.
graph_inputs
=
inputs
(
vars
)
graph_inputs
=
list
(
inputs
(
vars
)
)
# The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens.
fgraph
=
theano
.
gof
.
fg
.
FunctionGraph
(
graph_inputs
,
vars
,
clone
=
False
)
...
...
theano/link/c/basic.py
浏览文件 @
9abca4b8
...
...
@@ -637,7 +637,7 @@ class CLinker(Linker):
# We need to include the unused inputs in our variables,
# otherwise we can't pass them to the module.
self
.
variables
=
[
var
for
var
in
self
.
inputs
if
not
len
(
fgraph
.
clients
[
var
])]
self
.
variables
+=
get_variables
(
self
.
inputs
,
self
.
outputs
)
self
.
variables
+=
list
(
get_variables
(
self
.
inputs
,
self
.
outputs
)
)
# This adds a hidden input which is the params for each node
# that needs it
...
...
theano/printing.py
浏览文件 @
9abca4b8
...
...
@@ -820,7 +820,7 @@ def pydotprint(
fct
=
fct
.
outputs
assert
isinstance
(
fct
,
(
list
,
tuple
))
assert
all
(
isinstance
(
v
,
gof
.
Variable
)
for
v
in
fct
)
fct
=
gof
.
FunctionGraph
(
inputs
=
gof
.
graph
.
inputs
(
fct
),
outputs
=
fct
)
fct
=
gof
.
FunctionGraph
(
inputs
=
list
(
gof
.
graph
.
inputs
(
fct
)
),
outputs
=
fct
)
profile
=
None
outputs
=
fct
.
outputs
topo
=
fct
.
toposort
()
...
...
theano/scan/opt.py
浏览文件 @
9abca4b8
...
...
@@ -150,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
# Same for the outer graph, initialized w/ number of steps
nw_outer
=
[
node
.
inputs
[
0
]]
all_ins
=
gof
.
graph
.
inputs
(
op_outs
)
all_ins
=
list
(
gof
.
graph
.
inputs
(
op_outs
)
)
for
idx
in
range
(
op
.
n_seqs
):
node_inp
=
node
.
inputs
[
idx
+
1
]
if
(
...
...
theano/scan/utils.py
浏览文件 @
9abca4b8
...
...
@@ -268,7 +268,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
return
new_graph
graphs
=
list
(
graphs
)
inputs_
=
list
(
set
(
gof
.
graph
.
inputs
(
graphs
)
+
list
(
additional_inputs
)))
inputs_
=
list
(
set
(
list
(
gof
.
graph
.
inputs
(
graphs
)
)
+
list
(
additional_inputs
)))
# perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are
...
...
@@ -280,7 +280,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
if
new_input
is
not
input_
]
graphs
=
clone
(
graphs
,
share_inputs
=
True
,
replace
=
replacements
)
inputs_
=
list
(
set
(
gof
.
graph
.
inputs
(
graphs
)
+
list
(
additional_inputs
)))
inputs_
=
list
(
set
(
list
(
gof
.
graph
.
inputs
(
graphs
)
)
+
list
(
additional_inputs
)))
fg
=
gof
.
fg
.
FunctionGraph
(
inputs_
,
graphs
,
clone
=
False
)
...
...
@@ -714,7 +714,7 @@ def scan_can_remove_outs(op, out_idxs):
"""
non_removable
=
[
o
for
i
,
o
in
enumerate
(
op
.
outputs
)
if
i
not
in
out_idxs
]
required_inputs
=
gof
.
graph
.
inputs
(
non_removable
)
required_inputs
=
list
(
gof
.
graph
.
inputs
(
non_removable
)
)
out_ins
=
[]
offset
=
op
.
n_seqs
...
...
@@ -734,7 +734,7 @@ def scan_can_remove_outs(op, out_idxs):
if
out_idxs_mask
[
pos
]
and
any
([
x
in
required_inputs
for
x
in
out_ins
[
idx
]]):
# This output is required ..
out_idxs_mask
[
pos
]
=
0
required_inputs
+=
gof
.
graph
.
inputs
([
op
.
outputs
[
idx
]]
)
required_inputs
+=
list
(
gof
.
graph
.
inputs
([
op
.
outputs
[
idx
]])
)
added
=
True
required_outs
=
[
x
for
i
,
x
in
enumerate
(
out_idxs
)
if
out_idxs_mask
[
i
]
==
0
]
...
...
@@ -900,7 +900,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
givens
=
OrderedDict
()
for
nw_x
,
x
in
zip
(
nw_inputs
,
inputs
):
givens
[
x
]
=
nw_x
allinputs
=
theano
.
gof
.
graph
.
inputs
(
outputs
)
allinputs
=
list
(
theano
.
gof
.
graph
.
inputs
(
outputs
)
)
for
inp
in
allinputs
:
if
isinstance
(
inp
,
theano
.
Constant
):
givens
[
inp
]
=
inp
.
clone
()
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论