提交 9abca4b8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert theano.gof.graph graph walking functions to generators

上级 88dfd88f
...@@ -8,22 +8,24 @@ from theano import shared, tensor ...@@ -8,22 +8,24 @@ from theano import shared, tensor
from theano.gof.graph import ( from theano.gof.graph import (
Apply, Apply,
Variable, Variable,
ancestors,
as_string, as_string,
clone, clone,
equal_computations, equal_computations,
general_toposort, general_toposort,
inputs, inputs,
io_toposort, io_toposort,
is_in_ancestors,
list_of_nodes,
ops,
orphans,
stack_search,
variables,
) )
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
def as_variable(x):
    # Test helper: assert that ``x`` is a graph ``Variable`` and pass it
    # through unchanged, so it can be used with ``map`` over op inputs.
    assert isinstance(x, Variable)
    return x
class MyType(Type): class MyType(Type):
def __init__(self, thingy): def __init__(self, thingy):
self.thingy = thingy self.thingy = thingy
...@@ -47,32 +49,16 @@ class MyOp(Op): ...@@ -47,32 +49,16 @@ class MyOp(Op):
__props__ = () __props__ = ()
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = list(map(as_variable, inputs))
for input in inputs: for input in inputs:
if not isinstance(input.type, MyType): assert isinstance(input, Variable)
print(input, input.type, type(input), type(input.type)) assert isinstance(input.type, MyType)
raise Exception("Error 1") outputs = [MyVariable(sum(input.type.thingy for input in inputs))]
outputs = [MyVariable(sum([input.type.thingy for input in inputs]))] return Apply(self, list(inputs), outputs)
return Apply(self, inputs, outputs)
MyOp = MyOp() MyOp = MyOp()
class TestInputs:
    """Tests for ``theano.gof.graph.inputs``."""

    def test_inputs(self):
        # One apply node: the graph inputs are exactly its direct inputs.
        v1, v2 = MyVariable(1), MyVariable(2)
        apply_node = MyOp.make_node(v1, v2)
        assert inputs(apply_node.outputs) == [v1, v2]

    def test_inputs_deep(self):
        # Two stacked apply nodes: leaves are collected across the whole graph.
        v1, v2, v5 = MyVariable(1), MyVariable(2), MyVariable(5)
        inner = MyOp.make_node(v1, v2)
        outer = MyOp.make_node(inner.outputs[0], v5)
        found = inputs(outer.outputs)
        assert found == [v1, v2, v5], found
class X: class X:
def leaf_formatter(self, leaf): def leaf_formatter(self, leaf):
return str(leaf.type) return str(leaf.type)
...@@ -145,7 +131,7 @@ class TestClone(X): ...@@ -145,7 +131,7 @@ class TestClone(X):
node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5) node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5)
_, new = clone([r1, r2, r5], node.outputs, False) _, new = clone([r1, r2, r5], node.outputs, False)
new_node = new[0].owner new_node = new[0].owner
new_node.inputs = MyVariable(7), MyVariable(8) new_node.inputs = [MyVariable(7), MyVariable(8)]
assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"] assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"]
assert self.str(inputs(node.outputs), node.outputs) == [ assert self.str(inputs(node.outputs), node.outputs) == [
"MyOp(MyOp(R1, R2), R5)" "MyOp(MyOp(R1, R2), R5)"
...@@ -156,7 +142,7 @@ class TestClone(X): ...@@ -156,7 +142,7 @@ class TestClone(X):
node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5) node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5)
_, new = clone([r1, r2, r5], node.outputs, False) _, new = clone([r1, r2, r5], node.outputs, False)
new_node = new[0].owner new_node = new[0].owner
new_node.inputs = MyVariable(7), MyVariable(8) new_node.inputs = [MyVariable(7), MyVariable(8)]
c1 = tensor.constant(1.5) c1 = tensor.constant(1.5)
i, o = clone([c1], [c1]) i, o = clone([c1], [c1])
...@@ -181,19 +167,36 @@ def prenode(obj): ...@@ -181,19 +167,36 @@ def prenode(obj):
class TestToposort: class TestToposort:
def test_0(self): def test_simple(self):
# Test a simple graph # Test a simple graph
r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5)
o = MyOp.make_node(r1, r2) o = MyOp(r1, r2)
o2 = MyOp.make_node(o.outputs[0], r5) o.name = "o1"
o2 = MyOp(o, r5)
all = general_toposort(o2.outputs, prenode) o2.name = "o2"
assert all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]]
clients = {}
all = io_toposort([r5], o2.outputs) res = general_toposort([o2], prenode, clients=clients)
assert all == [o, o2]
assert clients == {
def test_1(self): o2.owner: [o2],
o: [o2.owner],
r5: [o2.owner],
o.owner: [o],
r1: [o.owner],
r2: [o.owner],
}
assert res == [r5, r2, r1, o.owner, o, o2.owner, o2]
with pytest.raises(ValueError):
general_toposort(
[o2], prenode, compute_deps_cache=lambda x: None, deps_cache=None
)
res = io_toposort([r5], [o2])
assert res == [o.owner, o2.owner]
def test_double_dependencies(self):
# Test a graph with double dependencies # Test a graph with double dependencies
r1, r5 = MyVariable(1), MyVariable(5) r1, r5 = MyVariable(1), MyVariable(5)
o = MyOp.make_node(r1, r1) o = MyOp.make_node(r1, r1)
...@@ -201,7 +204,7 @@ class TestToposort: ...@@ -201,7 +204,7 @@ class TestToposort:
all = general_toposort(o2.outputs, prenode) all = general_toposort(o2.outputs, prenode)
assert all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]] assert all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]]
def test_2(self): def test_inputs_owners(self):
# Test a graph where the inputs have owners # Test a graph where the inputs have owners
r1, r5 = MyVariable(1), MyVariable(5) r1, r5 = MyVariable(1), MyVariable(5)
o = MyOp.make_node(r1, r1) o = MyOp.make_node(r1, r1)
...@@ -214,7 +217,7 @@ class TestToposort: ...@@ -214,7 +217,7 @@ class TestToposort:
all = io_toposort([r2b], o2.outputs) all = io_toposort([r2b], o2.outputs)
assert all == [o2] assert all == [o2]
def test_3(self): def test_not_connected(self):
# Test a graph which is not connected # Test a graph which is not connected
r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4)
o0 = MyOp.make_node(r1, r2) o0 = MyOp.make_node(r1, r2)
...@@ -222,7 +225,7 @@ class TestToposort: ...@@ -222,7 +225,7 @@ class TestToposort:
all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs) all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs)
assert all == [o1, o0] or all == [o0, o1] assert all == [o1, o0] or all == [o0, o1]
def test_4(self): def test_io_chain(self):
# Test inputs and outputs mixed together in a chain graph # Test inputs and outputs mixed together in a chain graph
r1, r2 = MyVariable(1), MyVariable(2) r1, r2 = MyVariable(1), MyVariable(2)
o0 = MyOp.make_node(r1, r2) o0 = MyOp.make_node(r1, r2)
...@@ -230,7 +233,7 @@ class TestToposort: ...@@ -230,7 +233,7 @@ class TestToposort:
all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]]) all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]])
assert all == [o1] assert all == [o1]
def test_5(self): def test_outputs_clients(self):
# Test when outputs have clients # Test when outputs have clients
r1, r2, r4 = MyVariable(1), MyVariable(2), MyVariable(4) r1, r2, r4 = MyVariable(1), MyVariable(2), MyVariable(4)
o0 = MyOp.make_node(r1, r2) o0 = MyOp.make_node(r1, r2)
...@@ -326,3 +329,134 @@ def test_equal_computations(): ...@@ -326,3 +329,134 @@ def test_equal_computations():
max_argmax1 = tensor.max_and_argmax(m) max_argmax1 = tensor.max_and_argmax(m)
max_argmax2 = tensor.max_and_argmax(m) max_argmax2 = tensor.max_and_argmax(m)
assert equal_computations(max_argmax1, max_argmax2) assert equal_computations(max_argmax1, max_argmax2)
def test_stack_search():
    """Check BFS/DFS visit order and child reporting of ``stack_search``."""
    v1, v2, v3 = MyVariable(1), MyVariable(2), MyVariable(3)
    mid = MyOp(v1, v2)
    mid.name = "o1"
    top = MyOp(v3, mid)
    top.name = "o2"

    def expand(var):
        # Step from a variable to the inputs of the apply that produced it;
        # leaves (no owner) expand to None.
        return var.owner.inputs if var.owner else None

    bfs_order = list(stack_search([top], expand, bfs=True, return_children=False))
    assert bfs_order == [top, v3, mid, v1, v2]

    dfs_order = list(stack_search([top], expand, bfs=False, return_children=False))
    assert dfs_order == [top, mid, v2, v1, v3]

    with_children = list(stack_search([top], expand, bfs=True, return_children=True))
    assert with_children == [
        (top, [v3, mid]),
        (v3, None),
        (mid, [v1, v2]),
        (v1, None),
        (v2, None),
    ]
def test_ancestors():
    """Check ``ancestors`` ordering, generator laziness, and blockers."""
    v1, v2, v3 = MyVariable(1), MyVariable(2), MyVariable(3)
    mid = MyOp(v1, v2)
    mid.name = "o1"
    top = MyOp(v3, mid)
    top.name = "o2"

    assert list(ancestors([top], blockers=None)) == [top, v3, mid, v1, v2]

    # The result is a lazy generator: an ``in`` test consumes it only up to
    # the match, and the remainder can still be iterated afterwards.
    gen = ancestors([top], blockers=None)
    assert v3 in gen
    assert list(gen) == [mid, v1, v2]

    # Blockers stop the walk from descending past the blocked variable.
    assert list(ancestors([top], blockers=[mid])) == [top, v3, mid]
def test_inputs():
    """``inputs`` yields only the ownerless leaves of the graph."""
    v1, v2, v3 = MyVariable(1), MyVariable(2), MyVariable(3)
    mid = MyOp(v1, v2)
    mid.name = "o1"
    top = MyOp(v3, mid)
    top.name = "o2"

    assert list(inputs([top], blockers=None)) == [v3, v1, v2]
def test_variables_and_orphans():
    """``variables`` and ``orphans`` between declared inputs and outputs."""
    v1, v2, v3 = MyVariable(1), MyVariable(2), MyVariable(3)
    mid = MyOp(v1, v2)
    mid.name = "o1"
    top = MyOp(v3, mid)
    top.name = "o2"

    # Every variable on a path from the given inputs to the outputs...
    assert list(variables([v1, v2], [top])) == [top, mid, v3, v2, v1]
    # ...and the reachable leaves that were not declared as inputs.
    assert list(orphans([v1, v2], [top])) == [v3]
def test_ops():
    """``ops`` yields the apply nodes lying between inputs and outputs."""
    v1, v2, v3, v4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4)
    left = MyOp(v1, v2)
    left.name = "o1"
    right = MyOp(v3, v4)
    right.name = "o2"
    top = MyOp(v3, left, right)
    top.name = "o3"

    assert list(ops([v1, v2], [top])) == [top.owner, right.owner, left.owner]
def test_list_of_nodes():
    """``list_of_nodes`` returns the apply nodes between inputs and outputs."""
    v1, v2, v3 = MyVariable(1), MyVariable(2), MyVariable(3)
    mid = MyOp(v1, v2)
    mid.name = "o1"
    top = MyOp(v3, mid)
    top.name = "o2"

    assert list_of_nodes([v1, v2], [top]) == [top.owner, mid.owner]
def test_is_in_ancestors():
    """An inner apply node is an ancestor of the apply that consumes it."""
    v1, v2, v3 = MyVariable(1), MyVariable(2), MyVariable(3)
    mid = MyOp(v1, v2)
    mid.name = "o1"
    top = MyOp(v3, mid)
    top.name = "o2"

    assert is_in_ancestors(top.owner, mid.owner)
@pytest.mark.xfail(reason="Not implemented")
def test_io_connection_pattern():
    # Placeholder: deliberately fails (and is marked xfail) until real
    # coverage for io_connection_pattern is written.
    raise AssertionError()
@pytest.mark.xfail(reason="Not implemented")
def test_view_roots():
    # Placeholder: deliberately fails (and is marked xfail) until real
    # coverage for view_roots is written.
    raise AssertionError()
...@@ -1336,7 +1336,6 @@ def test_grad_useless_sum(): ...@@ -1336,7 +1336,6 @@ def test_grad_useless_sum():
x = TensorType(theano.config.floatX, (True,))("x") x = TensorType(theano.config.floatX, (True,))("x")
l = tt.log(1.0 - sigmoid(x))[0] l = tt.log(1.0 - sigmoid(x))[0]
g = tt.grad(l, x) g = tt.grad(l, x)
nodes = theano.gof.graph.ops([x], [g])
f = theano.function([x], g, mode=mode) f = theano.function([x], g, mode=mode)
test_values = [-100, -1, 0, 1, 100] test_values = [-100, -1, 0, 1, 100]
...@@ -1349,7 +1348,9 @@ def test_grad_useless_sum(): ...@@ -1349,7 +1348,9 @@ def test_grad_useless_sum():
finally: finally:
TensorType.values_eq_approx = old_values_eq_approx TensorType.values_eq_approx = old_values_eq_approx
assert not any([isinstance(node.op, Sum) for node in nodes]) assert not any(
[isinstance(node.op, Sum) for node in theano.gof.graph.ops([x], [g])]
)
assert np.allclose( assert np.allclose(
outputs, [[-3.72007598e-44], [-0.26894142], [-0.5], [-0.73105858], [-1.0]] outputs, [[-3.72007598e-44], [-0.26894142], [-0.5], [-0.73105858], [-1.0]]
) )
......
...@@ -22,7 +22,7 @@ def grad_sources_inputs(sources, inputs): ...@@ -22,7 +22,7 @@ def grad_sources_inputs(sources, inputs):
the new interface so the tests don't need to be rewritten. the new interface so the tests don't need to be rewritten.
""" """
if inputs is None: if inputs is None:
inputs = theano.gof.graph.inputs([source[0] for source in sources]) inputs = list(theano.gof.graph.inputs([source[0] for source in sources]))
return dict( return dict(
zip( zip(
inputs, inputs,
......
...@@ -2415,9 +2415,11 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2415,9 +2415,11 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
inputs = [self.wrap_in(i) for i in inputs] inputs = [self.wrap_in(i) for i in inputs]
outputs = [self.wrap_out(o) for o in outputs] outputs = [self.wrap_out(o) for o in outputs]
_inputs = gof.graph.inputs( _inputs = list(
[o.variable for o in outputs] gof.graph.inputs(
+ [i.update for i in inputs if getattr(i, "update", False)] [o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
)
) )
# Check if some input variables are unused # Check if some input variables are unused
......
...@@ -1206,7 +1206,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1206,7 +1206,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
} }
# We can't use fgraph.inputs as this don't include Constant Value. # We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs) all_graph_inputs = list(gof.graph.inputs(fgraph.outputs))
has_destroyers_attr = hasattr(fgraph, "has_destroyers") has_destroyers_attr = hasattr(fgraph, "has_destroyers")
for i in range(len(fgraph.outputs)): for i in range(len(fgraph.outputs)):
...@@ -1553,9 +1553,11 @@ class FunctionMaker: ...@@ -1553,9 +1553,11 @@ class FunctionMaker:
# Wrap them in In or Out instances if needed. # Wrap them in In or Out instances if needed.
inputs = [self.wrap_in(i) for i in inputs] inputs = [self.wrap_in(i) for i in inputs]
outputs = [self.wrap_out(o) for o in outputs] outputs = [self.wrap_out(o) for o in outputs]
_inputs = gof.graph.inputs( _inputs = list(
[o.variable for o in outputs] gof.graph.inputs(
+ [i.update for i in inputs if getattr(i, "update", False)] [o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
)
) )
# Check if some input variables are unused # Check if some input variables are unused
...@@ -1697,12 +1699,14 @@ class FunctionMaker: ...@@ -1697,12 +1699,14 @@ class FunctionMaker:
# There should be two categories of variables in inputs: # There should be two categories of variables in inputs:
# - variables that have to be provided (used_inputs) # - variables that have to be provided (used_inputs)
# - shared variables that will be updated # - shared variables that will be updated
used_inputs = gof.graph.ancestors( used_inputs = list(
( gof.graph.ancestors(
[o.variable for o in outputs] (
+ [i.update for i in inputs if getattr(i, "update", False)] [o.variable for o in outputs]
), + [i.update for i in inputs if getattr(i, "update", False)]
blockers=[i.variable for i in inputs], ),
blockers=[i.variable for i in inputs],
)
) )
msg = ( msg = (
......
...@@ -710,7 +710,7 @@ class FunctionGraph(utils.MetaObject): ...@@ -710,7 +710,7 @@ class FunctionGraph(utils.MetaObject):
Call this for a diagnosis if things go awry. Call this for a diagnosis if things go awry.
""" """
nodes = ops_between(self.inputs, self.outputs) nodes = set(ops_between(self.inputs, self.outputs))
if self.apply_nodes != nodes: if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes) missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes) excess = self.apply_nodes.difference(nodes)
......
差异被折叠。
...@@ -35,10 +35,6 @@ _logger = logging.getLogger("theano.gof.opt") ...@@ -35,10 +35,6 @@ _logger = logging.getLogger("theano.gof.opt")
_optimizer_idx = [0] _optimizer_idx = [0]
def _list_of_nodes(fgraph):
return list(graph.io_toposort(fgraph.inputs, fgraph.outputs))
class LocalMetaOptimizerSkipAssertionError(AssertionError): class LocalMetaOptimizerSkipAssertionError(AssertionError):
"""This is an AssertionError, but instead of having the """This is an AssertionError, but instead of having the
LocalMetaOptimizer print the error, it just skip that LocalMetaOptimizer print the error, it just skip that
...@@ -1344,7 +1340,9 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1344,7 +1340,9 @@ class LocalOptGroup(LocalOptimizer):
else: # It must be a dict else: # It must be a dict
new_vars = list(new_repl.values()) new_vars = list(new_repl.values())
if self.profile: if self.profile:
self.node_created[opt] += len(graph.ops(fgraph.variables, new_vars)) self.node_created[opt] += len(
list(graph.ops(fgraph.variables, new_vars))
)
self.applied_true[opt] += 1 self.applied_true[opt] += 1
break # break from the for loop over optimization. break # break from the for loop over optimization.
if not new_repl: # No optimization applied in the last iteration if not new_repl: # No optimization applied in the last iteration
...@@ -1454,7 +1452,9 @@ class GraphToGPULocalOptGroup(LocalOptGroup): ...@@ -1454,7 +1452,9 @@ class GraphToGPULocalOptGroup(LocalOptGroup):
if not new_repl: if not new_repl:
continue continue
if self.profile: if self.profile:
self.node_created[opt] += len(graph.ops(fgraph.variables, new_repl)) self.node_created[opt] += len(
list(graph.ops(fgraph.variables, new_repl))
)
self.applied_true[opt] += 1 self.applied_true[opt] += 1
return new_repl return new_repl
......
...@@ -807,7 +807,7 @@ def is_same_graph_with_merge(var1, var2, givens=None): ...@@ -807,7 +807,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
vars = copied[0:2] vars = copied[0:2]
givens = copied[2] givens = copied[2]
# Create FunctionGraph. # Create FunctionGraph.
graph_inputs = inputs(vars) graph_inputs = list(inputs(vars))
# The clone isn't needed as we did a deepcopy and we cloning will # The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens. # break the mapping in givens.
fgraph = theano.gof.fg.FunctionGraph(graph_inputs, vars, clone=False) fgraph = theano.gof.fg.FunctionGraph(graph_inputs, vars, clone=False)
......
...@@ -637,7 +637,7 @@ class CLinker(Linker): ...@@ -637,7 +637,7 @@ class CLinker(Linker):
# We need to include the unused inputs in our variables, # We need to include the unused inputs in our variables,
# otherwise we can't pass them to the module. # otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(fgraph.clients[var])] self.variables = [var for var in self.inputs if not len(fgraph.clients[var])]
self.variables += get_variables(self.inputs, self.outputs) self.variables += list(get_variables(self.inputs, self.outputs))
# This adds a hidden input which is the params for each node # This adds a hidden input which is the params for each node
# that needs it # that needs it
......
...@@ -820,7 +820,7 @@ def pydotprint( ...@@ -820,7 +820,7 @@ def pydotprint(
fct = fct.outputs fct = fct.outputs
assert isinstance(fct, (list, tuple)) assert isinstance(fct, (list, tuple))
assert all(isinstance(v, gof.Variable) for v in fct) assert all(isinstance(v, gof.Variable) for v in fct)
fct = gof.FunctionGraph(inputs=gof.graph.inputs(fct), outputs=fct) fct = gof.FunctionGraph(inputs=list(gof.graph.inputs(fct)), outputs=fct)
profile = None profile = None
outputs = fct.outputs outputs = fct.outputs
topo = fct.toposort() topo = fct.toposort()
......
...@@ -150,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -150,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
# Same for the outer graph, initialized w/ number of steps # Same for the outer graph, initialized w/ number of steps
nw_outer = [node.inputs[0]] nw_outer = [node.inputs[0]]
all_ins = gof.graph.inputs(op_outs) all_ins = list(gof.graph.inputs(op_outs))
for idx in range(op.n_seqs): for idx in range(op.n_seqs):
node_inp = node.inputs[idx + 1] node_inp = node.inputs[idx + 1]
if ( if (
......
...@@ -268,7 +268,7 @@ def map_variables(replacer, graphs, additional_inputs=None): ...@@ -268,7 +268,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
return new_graph return new_graph
graphs = list(graphs) graphs = list(graphs)
inputs_ = list(set(gof.graph.inputs(graphs) + list(additional_inputs))) inputs_ = list(set(list(gof.graph.inputs(graphs)) + list(additional_inputs)))
# perform any desired replacement of input variables. these # perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are # aren't replaced by the local optimizer approach because they are
...@@ -280,7 +280,7 @@ def map_variables(replacer, graphs, additional_inputs=None): ...@@ -280,7 +280,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
if new_input is not input_ if new_input is not input_
] ]
graphs = clone(graphs, share_inputs=True, replace=replacements) graphs = clone(graphs, share_inputs=True, replace=replacements)
inputs_ = list(set(gof.graph.inputs(graphs) + list(additional_inputs))) inputs_ = list(set(list(gof.graph.inputs(graphs)) + list(additional_inputs)))
fg = gof.fg.FunctionGraph(inputs_, graphs, clone=False) fg = gof.fg.FunctionGraph(inputs_, graphs, clone=False)
...@@ -714,7 +714,7 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -714,7 +714,7 @@ def scan_can_remove_outs(op, out_idxs):
""" """
non_removable = [o for i, o in enumerate(op.outputs) if i not in out_idxs] non_removable = [o for i, o in enumerate(op.outputs) if i not in out_idxs]
required_inputs = gof.graph.inputs(non_removable) required_inputs = list(gof.graph.inputs(non_removable))
out_ins = [] out_ins = []
offset = op.n_seqs offset = op.n_seqs
...@@ -734,7 +734,7 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -734,7 +734,7 @@ def scan_can_remove_outs(op, out_idxs):
if out_idxs_mask[pos] and any([x in required_inputs for x in out_ins[idx]]): if out_idxs_mask[pos] and any([x in required_inputs for x in out_ins[idx]]):
# This output is required .. # This output is required ..
out_idxs_mask[pos] = 0 out_idxs_mask[pos] = 0
required_inputs += gof.graph.inputs([op.outputs[idx]]) required_inputs += list(gof.graph.inputs([op.outputs[idx]]))
added = True added = True
required_outs = [x for i, x in enumerate(out_idxs) if out_idxs_mask[i] == 0] required_outs = [x for i, x in enumerate(out_idxs) if out_idxs_mask[i] == 0]
...@@ -900,7 +900,7 @@ def reconstruct_graph(inputs, outputs, tag=None): ...@@ -900,7 +900,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
givens = OrderedDict() givens = OrderedDict()
for nw_x, x in zip(nw_inputs, inputs): for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x givens[x] = nw_x
allinputs = theano.gof.graph.inputs(outputs) allinputs = list(theano.gof.graph.inputs(outputs))
for inp in allinputs: for inp in allinputs:
if isinstance(inp, theano.Constant): if isinstance(inp, theano.Constant):
givens[inp] = inp.clone() givens[inp] = inp.clone()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论