提交 6c6d81c6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename theano.gof.graph.inputs to graph_inputs

上级 27605fae
...@@ -14,7 +14,7 @@ from theano.gof.graph import ( ...@@ -14,7 +14,7 @@ from theano.gof.graph import (
clone, clone,
equal_computations, equal_computations,
general_toposort, general_toposort,
inputs, graph_inputs,
io_toposort, io_toposort,
is_in_ancestors, is_in_ancestors,
list_of_nodes, list_of_nodes,
...@@ -132,8 +132,10 @@ class TestClone(X): ...@@ -132,8 +132,10 @@ class TestClone(X):
_, new = clone([r1, r2, r5], node.outputs, False) _, new = clone([r1, r2, r5], node.outputs, False)
new_node = new[0].owner new_node = new[0].owner
new_node.inputs = [MyVariable(7), MyVariable(8)] new_node.inputs = [MyVariable(7), MyVariable(8)]
assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"] assert self.str(graph_inputs(new_node.outputs), new_node.outputs) == [
assert self.str(inputs(node.outputs), node.outputs) == [ "MyOp(R7, R8)"
]
assert self.str(graph_inputs(node.outputs), node.outputs) == [
"MyOp(MyOp(R1, R2), R5)" "MyOp(MyOp(R1, R2), R5)"
] ]
...@@ -384,7 +386,7 @@ def test_ancestors(): ...@@ -384,7 +386,7 @@ def test_ancestors():
assert res_list == [o2, r3, o1] assert res_list == [o2, r3, o1]
def test_inputs(): def test_graph_inputs():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2) o1 = MyOp(r1, r2)
...@@ -392,7 +394,7 @@ def test_inputs(): ...@@ -392,7 +394,7 @@ def test_inputs():
o2 = MyOp(r3, o1) o2 = MyOp(r3, o1)
o2.name = "o2" o2.name = "o2"
res = inputs([o2], blockers=None) res = graph_inputs([o2], blockers=None)
res_list = list(res) res_list = list(res)
assert res_list == [r3, r1, r2] assert res_list == [r3, r1, r2]
......
...@@ -1966,7 +1966,7 @@ class TestScan: ...@@ -1966,7 +1966,7 @@ class TestScan:
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = theano.clone(f1, replace=None, strict=True, share_inputs=True) f2 = theano.clone(f1, replace=None, strict=True, share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.graph_inputs([f2])
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
...@@ -1982,7 +1982,7 @@ class TestScan: ...@@ -1982,7 +1982,7 @@ class TestScan:
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = theano.clone(f1, replace=None, strict=True, share_inputs=False) f2 = theano.clone(f1, replace=None, strict=True, share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.graph_inputs([f2])
assert z not in f2_inp assert z not in f2_inp
assert x not in f2_inp assert x not in f2_inp
...@@ -2001,7 +2001,7 @@ class TestScan: ...@@ -2001,7 +2001,7 @@ class TestScan:
f2 = theano.clone( f2 = theano.clone(
f1, replace=OrderedDict([(y, y2)]), strict=True, share_inputs=True f1, replace=OrderedDict([(y, y2)]), strict=True, share_inputs=True
) )
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.graph_inputs([f2])
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
assert y2 in f2_inp assert y2 in f2_inp
...@@ -2019,7 +2019,7 @@ class TestScan: ...@@ -2019,7 +2019,7 @@ class TestScan:
f2 = theano.clone( f2 = theano.clone(
f1, replace=OrderedDict([(y, y2)]), strict=False, share_inputs=True f1, replace=OrderedDict([(y, y2)]), strict=False, share_inputs=True
) )
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.graph_inputs([f2])
assert z in f2_inp assert z in f2_inp
assert x in f2_inp assert x in f2_inp
assert y2 in f2_inp assert y2 in f2_inp
...@@ -2035,7 +2035,7 @@ class TestScan: ...@@ -2035,7 +2035,7 @@ class TestScan:
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = theano.clone(f1, replace=[(y, y2)], strict=True, share_inputs=False) f2 = theano.clone(f1, replace=[(y, y2)], strict=True, share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.graph_inputs([f2])
assert z not in f2_inp assert z not in f2_inp
assert x not in f2_inp assert x not in f2_inp
assert y2 not in f2_inp assert y2 not in f2_inp
...@@ -2051,7 +2051,7 @@ class TestScan: ...@@ -2051,7 +2051,7 @@ class TestScan:
f1 = z * (x + y) ** 2 + 5 f1 = z * (x + y) ** 2 + 5
f2 = theano.clone(f1, replace=[(y, y2)], strict=False, share_inputs=False) f2 = theano.clone(f1, replace=[(y, y2)], strict=False, share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2]) f2_inp = theano.gof.graph.graph_inputs([f2])
assert z not in f2_inp assert z not in f2_inp
assert x not in f2_inp assert x not in f2_inp
assert y2 not in f2_inp assert y2 not in f2_inp
......
...@@ -8,8 +8,7 @@ from pytest import fixture, importorskip, raises ...@@ -8,8 +8,7 @@ from pytest import fixture, importorskip, raises
import theano.tensor as tt import theano.tensor as tt
from theano import change_flags, config from theano import change_flags, config
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
from theano.gof.graph import Variable from theano.gof.graph import Variable, graph_inputs
from theano.gof.graph import inputs as tt_inputs
from theano.gof.op import get_test_value from theano.gof.op import get_test_value
from theano.tensor.random.basic import ( from theano.tensor.random.basic import (
bernoulli, bernoulli,
...@@ -145,7 +144,7 @@ def test_normal_ShapeFeature(): ...@@ -145,7 +144,7 @@ def test_normal_ShapeFeature():
d_rv.tag.test_value d_rv.tag.test_value
fg = FunctionGraph( fg = FunctionGraph(
[i for i in tt_inputs([d_rv]) if not isinstance(i, tt.Constant)], [i for i in graph_inputs([d_rv]) if not isinstance(i, tt.Constant)],
[d_rv], [d_rv],
clone=False, clone=False,
features=[tt.opt.ShapeFeature()], features=[tt.opt.ShapeFeature()],
...@@ -296,7 +295,7 @@ def test_mvnormal_ShapeFeature(): ...@@ -296,7 +295,7 @@ def test_mvnormal_ShapeFeature():
d_rv = multivariate_normal(tt.ones((M_tt,)), tt.eye(M_tt), size=2) d_rv = multivariate_normal(tt.ones((M_tt,)), tt.eye(M_tt), size=2)
fg = FunctionGraph( fg = FunctionGraph(
[i for i in tt_inputs([d_rv]) if not isinstance(i, tt.Constant)], [i for i in graph_inputs([d_rv]) if not isinstance(i, tt.Constant)],
[d_rv], [d_rv],
clone=False, clone=False,
features=[tt.opt.ShapeFeature()], features=[tt.opt.ShapeFeature()],
...@@ -305,7 +304,7 @@ def test_mvnormal_ShapeFeature(): ...@@ -305,7 +304,7 @@ def test_mvnormal_ShapeFeature():
s1, s2 = fg.shape_feature.shape_of[d_rv] s1, s2 = fg.shape_feature.shape_of[d_rv]
assert get_test_value(s1) == 2 assert get_test_value(s1) == 2
assert M_tt in tt_inputs([s2]) assert M_tt in graph_inputs([s2])
# Test broadcasted shapes # Test broadcasted shapes
mean = tt.tensor(config.floatX, [True, False]) mean = tt.tensor(config.floatX, [True, False])
...@@ -319,7 +318,7 @@ def test_mvnormal_ShapeFeature(): ...@@ -319,7 +318,7 @@ def test_mvnormal_ShapeFeature():
d_rv = multivariate_normal(mean, cov, size=[2, 3]) d_rv = multivariate_normal(mean, cov, size=[2, 3])
fg = FunctionGraph( fg = FunctionGraph(
[i for i in tt_inputs([d_rv]) if not isinstance(i, tt.Constant)], [i for i in graph_inputs([d_rv]) if not isinstance(i, tt.Constant)],
[d_rv], [d_rv],
clone=False, clone=False,
features=[tt.opt.ShapeFeature()], features=[tt.opt.ShapeFeature()],
...@@ -392,7 +391,7 @@ def test_dirichlet_ShapeFeature(): ...@@ -392,7 +391,7 @@ def test_dirichlet_ShapeFeature():
d_rv = dirichlet(tt.ones((M_tt, N_tt)), name="Gamma") d_rv = dirichlet(tt.ones((M_tt, N_tt)), name="Gamma")
fg = FunctionGraph( fg = FunctionGraph(
[i for i in tt_inputs([d_rv]) if not isinstance(i, tt.Constant)], [i for i in graph_inputs([d_rv]) if not isinstance(i, tt.Constant)],
[d_rv], [d_rv],
clone=False, clone=False,
features=[tt.opt.ShapeFeature()], features=[tt.opt.ShapeFeature()],
...@@ -400,8 +399,8 @@ def test_dirichlet_ShapeFeature(): ...@@ -400,8 +399,8 @@ def test_dirichlet_ShapeFeature():
s1, s2 = fg.shape_feature.shape_of[d_rv] s1, s2 = fg.shape_feature.shape_of[d_rv]
assert M_tt in tt_inputs([s1]) assert M_tt in graph_inputs([s1])
assert N_tt in tt_inputs([s2]) assert N_tt in graph_inputs([s2])
def test_poisson_samples(): def test_poisson_samples():
......
...@@ -22,7 +22,7 @@ def grad_sources_inputs(sources, inputs): ...@@ -22,7 +22,7 @@ def grad_sources_inputs(sources, inputs):
the new interface so the tests don't need to be rewritten. the new interface so the tests don't need to be rewritten.
""" """
if inputs is None: if inputs is None:
inputs = list(theano.gof.graph.inputs([source[0] for source in sources])) inputs = list(theano.gof.graph.graph_inputs([source[0] for source in sources]))
return dict( return dict(
zip( zip(
inputs, inputs,
......
...@@ -339,7 +339,9 @@ class OpFromGraph(Op): ...@@ -339,7 +339,9 @@ class OpFromGraph(Op):
# To correctly support shared variables the inner fct should # To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient. # not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = [ self.shared_inputs = [
var for var in gof.graph.inputs(outputs) if isinstance(var, SharedVariable) var
for var in gof.graph.graph_inputs(outputs)
if isinstance(var, SharedVariable)
] ]
shared_vars = [var.type() for var in self.shared_inputs] shared_vars = [var.type() for var in self.shared_inputs]
......
...@@ -2416,7 +2416,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions ...@@ -2416,7 +2416,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
outputs = [self.wrap_out(o) for o in outputs] outputs = [self.wrap_out(o) for o in outputs]
_inputs = list( _inputs = list(
gof.graph.inputs( gof.graph.graph_inputs(
[o.variable for o in outputs] [o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)] + [i.update for i in inputs if getattr(i, "update", False)]
) )
......
...@@ -1206,7 +1206,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1206,7 +1206,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
} }
# We can't use fgraph.inputs as this don't include Constant Value. # We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = list(gof.graph.inputs(fgraph.outputs)) all_graph_inputs = list(gof.graph.graph_inputs(fgraph.outputs))
has_destroyers_attr = hasattr(fgraph, "has_destroyers") has_destroyers_attr = hasattr(fgraph, "has_destroyers")
for i in range(len(fgraph.outputs)): for i in range(len(fgraph.outputs)):
...@@ -1454,10 +1454,18 @@ class FunctionMaker: ...@@ -1454,10 +1454,18 @@ class FunctionMaker:
t2 = f2.outputs[i] t2 = f2.outputs[i]
givens = dict( givens = dict(
zip(gof.graph.inputs([t1]), gof.graph.inputs([t2])) zip(
gof.graph.graph_inputs([t1]),
gof.graph.graph_inputs([t2]),
)
) )
temp = dict(zip(gof.graph.inputs([t1]), gof.graph.inputs([t2]))) temp = dict(
zip(
gof.graph.graph_inputs([t1]),
gof.graph.graph_inputs([t2]),
)
)
# hack to remove inconstent entry in givens # hack to remove inconstent entry in givens
# seems to work that but source of inconsistency # seems to work that but source of inconsistency
...@@ -1554,7 +1562,7 @@ class FunctionMaker: ...@@ -1554,7 +1562,7 @@ class FunctionMaker:
inputs = [self.wrap_in(i) for i in inputs] inputs = [self.wrap_in(i) for i in inputs]
outputs = [self.wrap_out(o) for o in outputs] outputs = [self.wrap_out(o) for o in outputs]
_inputs = list( _inputs = list(
gof.graph.inputs( gof.graph.graph_inputs(
[o.variable for o in outputs] [o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)] + [i.update for i in inputs if getattr(i, "update", False)]
) )
......
...@@ -11,7 +11,7 @@ import theano ...@@ -11,7 +11,7 @@ import theano
from theano.compile import Function, builders from theano.compile import Function, builders
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Constant, Variable from theano.gof.graph import Apply, Constant, Variable
from theano.gof.graph import inputs as graph_inputs from theano.gof.graph import graph_inputs as graph_inputs
from theano.printing import pydot_imported, pydot_imported_msg from theano.printing import pydot_imported, pydot_imported_msg
......
...@@ -757,7 +757,7 @@ def ancestors( ...@@ -757,7 +757,7 @@ def ancestors(
yield from walk(graphs, expand, False) yield from walk(graphs, expand, False)
def inputs( def graph_inputs(
graphs: Iterable[Variable], blockers: Collection[Variable] = None graphs: Iterable[Variable], blockers: Collection[Variable] = None
) -> Generator[Variable, None, None]: ) -> Generator[Variable, None, None]:
"""Return the inputs required to compute the given Variables. """Return the inputs required to compute the given Variables.
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
import theano import theano
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.graph import equal_computations, inputs, io_toposort, vars_between from theano.gof.graph import equal_computations, graph_inputs, io_toposort, vars_between
class AlreadyThere(Exception): class AlreadyThere(Exception):
...@@ -807,10 +807,10 @@ def is_same_graph_with_merge(var1, var2, givens=None): ...@@ -807,10 +807,10 @@ def is_same_graph_with_merge(var1, var2, givens=None):
vars = copied[0:2] vars = copied[0:2]
givens = copied[2] givens = copied[2]
# Create FunctionGraph. # Create FunctionGraph.
graph_inputs = list(inputs(vars)) inputs = list(graph_inputs(vars))
# The clone isn't needed as we did a deepcopy and we cloning will # The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens. # break the mapping in givens.
fgraph = theano.gof.fg.FunctionGraph(graph_inputs, vars, clone=False) fgraph = theano.gof.fg.FunctionGraph(inputs, vars, clone=False)
# Perform Variable substitution. # Perform Variable substitution.
for to_replace, replace_by in givens.items(): for to_replace, replace_by in givens.items():
fgraph.replace(to_replace, replace_by) fgraph.replace(to_replace, replace_by)
...@@ -893,7 +893,7 @@ def is_same_graph(var1, var2, givens=None): ...@@ -893,7 +893,7 @@ def is_same_graph(var1, var2, givens=None):
in_xs = [] in_xs = []
in_ys = [] in_ys = []
# Compute the sets of all variables found in each computational graph. # Compute the sets of all variables found in each computational graph.
inputs_var = list(map(inputs, ([var1], [var2]))) inputs_var = list(map(graph_inputs, ([var1], [var2])))
all_vars = [ all_vars = [
set(vars_between(v_i, v_o)) set(vars_between(v_i, v_o))
for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2])) for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2]))
......
...@@ -820,7 +820,7 @@ def pydotprint( ...@@ -820,7 +820,7 @@ def pydotprint(
fct = fct.outputs fct = fct.outputs
assert isinstance(fct, (list, tuple)) assert isinstance(fct, (list, tuple))
assert all(isinstance(v, gof.Variable) for v in fct) assert all(isinstance(v, gof.Variable) for v in fct)
fct = gof.FunctionGraph(inputs=list(gof.graph.inputs(fct)), outputs=fct) fct = gof.FunctionGraph(inputs=list(gof.graph.graph_inputs(fct)), outputs=fct)
profile = None profile = None
outputs = fct.outputs outputs = fct.outputs
topo = fct.toposort() topo = fct.toposort()
......
...@@ -803,7 +803,7 @@ def scan( ...@@ -803,7 +803,7 @@ def scan(
and not isinstance(x, SharedVariable) and not isinstance(x, SharedVariable)
and not isinstance(x, gof.Constant) and not isinstance(x, gof.Constant)
), ),
gof.graph.inputs(fake_outputs), gof.graph.graph_inputs(fake_outputs),
) )
extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs] extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs]
non_seqs += extra_inputs non_seqs += extra_inputs
......
...@@ -62,7 +62,7 @@ from theano.compile.profiling import ScanProfileStats, register_profiler_printer ...@@ -62,7 +62,7 @@ from theano.compile.profiling import ScanProfileStats, register_profiler_printer
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.fg import MissingInputError from theano.gof.fg import MissingInputError
from theano.gof.graph import Apply, Variable, equal_computations from theano.gof.graph import Apply, Variable, equal_computations
from theano.gof.graph import inputs as graph_inputs from theano.gof.graph import graph_inputs as graph_inputs
from theano.gof.graph import io_connection_pattern from theano.gof.graph import io_connection_pattern
from theano.gof.op import Op, ops_with_inner_function from theano.gof.op import Op, ops_with_inner_function
from theano.gof.toolbox import NoOutputFromInplace from theano.gof.toolbox import NoOutputFromInplace
......
...@@ -150,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -150,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
# Same for the outer graph, initialized w/ number of steps # Same for the outer graph, initialized w/ number of steps
nw_outer = [node.inputs[0]] nw_outer = [node.inputs[0]]
all_ins = list(gof.graph.inputs(op_outs)) all_ins = list(gof.graph.graph_inputs(op_outs))
for idx in range(op.n_seqs): for idx in range(op.n_seqs):
node_inp = node.inputs[idx + 1] node_inp = node.inputs[idx + 1]
if ( if (
......
...@@ -268,7 +268,7 @@ def map_variables(replacer, graphs, additional_inputs=None): ...@@ -268,7 +268,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
return new_graph return new_graph
graphs = list(graphs) graphs = list(graphs)
inputs_ = list(set(list(gof.graph.inputs(graphs)) + list(additional_inputs))) inputs_ = list(set(list(gof.graph.graph_inputs(graphs)) + list(additional_inputs)))
# perform any desired replacement of input variables. these # perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are # aren't replaced by the local optimizer approach because they are
...@@ -280,7 +280,7 @@ def map_variables(replacer, graphs, additional_inputs=None): ...@@ -280,7 +280,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
if new_input is not input_ if new_input is not input_
] ]
graphs = clone(graphs, share_inputs=True, replace=replacements) graphs = clone(graphs, share_inputs=True, replace=replacements)
inputs_ = list(set(list(gof.graph.inputs(graphs)) + list(additional_inputs))) inputs_ = list(set(list(gof.graph.graph_inputs(graphs)) + list(additional_inputs)))
fg = gof.fg.FunctionGraph(inputs_, graphs, clone=False) fg = gof.fg.FunctionGraph(inputs_, graphs, clone=False)
...@@ -330,7 +330,7 @@ def map_variables(replacer, graphs, additional_inputs=None): ...@@ -330,7 +330,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
replacements = [wrapped_replacer(o) for o in node.outputs] replacements = [wrapped_replacer(o) for o in node.outputs]
# Add inputs to replacement graphs as inputs to this `fgraph` # Add inputs to replacement graphs as inputs to this `fgraph`
for i in gof.graph.inputs(replacements): for i in gof.graph.graph_inputs(replacements):
fgraph.add_input(i) fgraph.add_input(i)
return replacements return replacements
...@@ -370,7 +370,7 @@ def _map_variables_inner( ...@@ -370,7 +370,7 @@ def _map_variables_inner(
other_inputs = [] other_inputs = []
constants = [] constants = []
for input_ in gof.graph.inputs([new_graph]): for input_ in gof.graph.graph_inputs([new_graph]):
if isinstance(input_, gof.Variable): if isinstance(input_, gof.Variable):
if isinstance(input_, gof.Constant): if isinstance(input_, gof.Constant):
constants.append(input_) constants.append(input_)
...@@ -714,7 +714,7 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -714,7 +714,7 @@ def scan_can_remove_outs(op, out_idxs):
""" """
non_removable = [o for i, o in enumerate(op.outputs) if i not in out_idxs] non_removable = [o for i, o in enumerate(op.outputs) if i not in out_idxs]
required_inputs = list(gof.graph.inputs(non_removable)) required_inputs = list(gof.graph.graph_inputs(non_removable))
out_ins = [] out_ins = []
offset = op.n_seqs offset = op.n_seqs
...@@ -734,7 +734,7 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -734,7 +734,7 @@ def scan_can_remove_outs(op, out_idxs):
if out_idxs_mask[pos] and any([x in required_inputs for x in out_ins[idx]]): if out_idxs_mask[pos] and any([x in required_inputs for x in out_ins[idx]]):
# This output is required .. # This output is required ..
out_idxs_mask[pos] = 0 out_idxs_mask[pos] = 0
required_inputs += list(gof.graph.inputs([op.outputs[idx]])) required_inputs += list(gof.graph.graph_inputs([op.outputs[idx]]))
added = True added = True
required_outs = [x for i, x in enumerate(out_idxs) if out_idxs_mask[i] == 0] required_outs = [x for i, x in enumerate(out_idxs) if out_idxs_mask[i] == 0]
...@@ -900,7 +900,7 @@ def reconstruct_graph(inputs, outputs, tag=None): ...@@ -900,7 +900,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
givens = OrderedDict() givens = OrderedDict()
for nw_x, x in zip(nw_inputs, inputs): for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x givens[x] = nw_x
allinputs = list(theano.gof.graph.inputs(outputs)) allinputs = list(theano.gof.graph.graph_inputs(outputs))
for inp in allinputs: for inp in allinputs:
if isinstance(inp, theano.Constant): if isinstance(inp, theano.Constant):
givens[inp] = inp.clone() givens[inp] = inp.clone()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论