提交 6c6d81c6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename theano.gof.graph.inputs to graph_inputs

上级 27605fae
......@@ -14,7 +14,7 @@ from theano.gof.graph import (
clone,
equal_computations,
general_toposort,
inputs,
graph_inputs,
io_toposort,
is_in_ancestors,
list_of_nodes,
......@@ -132,8 +132,10 @@ class TestClone(X):
_, new = clone([r1, r2, r5], node.outputs, False)
new_node = new[0].owner
new_node.inputs = [MyVariable(7), MyVariable(8)]
assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"]
assert self.str(inputs(node.outputs), node.outputs) == [
assert self.str(graph_inputs(new_node.outputs), new_node.outputs) == [
"MyOp(R7, R8)"
]
assert self.str(graph_inputs(node.outputs), node.outputs) == [
"MyOp(MyOp(R1, R2), R5)"
]
......@@ -384,7 +386,7 @@ def test_ancestors():
assert res_list == [o2, r3, o1]
def test_inputs():
def test_graph_inputs():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
......@@ -392,7 +394,7 @@ def test_inputs():
o2 = MyOp(r3, o1)
o2.name = "o2"
res = inputs([o2], blockers=None)
res = graph_inputs([o2], blockers=None)
res_list = list(res)
assert res_list == [r3, r1, r2]
......
......@@ -1966,7 +1966,7 @@ class TestScan:
f1 = z * (x + y) ** 2 + 5
f2 = theano.clone(f1, replace=None, strict=True, share_inputs=True)
f2_inp = theano.gof.graph.inputs([f2])
f2_inp = theano.gof.graph.graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
......@@ -1982,7 +1982,7 @@ class TestScan:
f1 = z * (x + y) ** 2 + 5
f2 = theano.clone(f1, replace=None, strict=True, share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2])
f2_inp = theano.gof.graph.graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
......@@ -2001,7 +2001,7 @@ class TestScan:
f2 = theano.clone(
f1, replace=OrderedDict([(y, y2)]), strict=True, share_inputs=True
)
f2_inp = theano.gof.graph.inputs([f2])
f2_inp = theano.gof.graph.graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
......@@ -2019,7 +2019,7 @@ class TestScan:
f2 = theano.clone(
f1, replace=OrderedDict([(y, y2)]), strict=False, share_inputs=True
)
f2_inp = theano.gof.graph.inputs([f2])
f2_inp = theano.gof.graph.graph_inputs([f2])
assert z in f2_inp
assert x in f2_inp
assert y2 in f2_inp
......@@ -2035,7 +2035,7 @@ class TestScan:
f1 = z * (x + y) ** 2 + 5
f2 = theano.clone(f1, replace=[(y, y2)], strict=True, share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2])
f2_inp = theano.gof.graph.graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
......@@ -2051,7 +2051,7 @@ class TestScan:
f1 = z * (x + y) ** 2 + 5
f2 = theano.clone(f1, replace=[(y, y2)], strict=False, share_inputs=False)
f2_inp = theano.gof.graph.inputs([f2])
f2_inp = theano.gof.graph.graph_inputs([f2])
assert z not in f2_inp
assert x not in f2_inp
assert y2 not in f2_inp
......
......@@ -8,8 +8,7 @@ from pytest import fixture, importorskip, raises
import theano.tensor as tt
from theano import change_flags, config
from theano.gof.fg import FunctionGraph
from theano.gof.graph import Variable
from theano.gof.graph import inputs as tt_inputs
from theano.gof.graph import Variable, graph_inputs
from theano.gof.op import get_test_value
from theano.tensor.random.basic import (
bernoulli,
......@@ -145,7 +144,7 @@ def test_normal_ShapeFeature():
d_rv.tag.test_value
fg = FunctionGraph(
[i for i in tt_inputs([d_rv]) if not isinstance(i, tt.Constant)],
[i for i in graph_inputs([d_rv]) if not isinstance(i, tt.Constant)],
[d_rv],
clone=False,
features=[tt.opt.ShapeFeature()],
......@@ -296,7 +295,7 @@ def test_mvnormal_ShapeFeature():
d_rv = multivariate_normal(tt.ones((M_tt,)), tt.eye(M_tt), size=2)
fg = FunctionGraph(
[i for i in tt_inputs([d_rv]) if not isinstance(i, tt.Constant)],
[i for i in graph_inputs([d_rv]) if not isinstance(i, tt.Constant)],
[d_rv],
clone=False,
features=[tt.opt.ShapeFeature()],
......@@ -305,7 +304,7 @@ def test_mvnormal_ShapeFeature():
s1, s2 = fg.shape_feature.shape_of[d_rv]
assert get_test_value(s1) == 2
assert M_tt in tt_inputs([s2])
assert M_tt in graph_inputs([s2])
# Test broadcasted shapes
mean = tt.tensor(config.floatX, [True, False])
......@@ -319,7 +318,7 @@ def test_mvnormal_ShapeFeature():
d_rv = multivariate_normal(mean, cov, size=[2, 3])
fg = FunctionGraph(
[i for i in tt_inputs([d_rv]) if not isinstance(i, tt.Constant)],
[i for i in graph_inputs([d_rv]) if not isinstance(i, tt.Constant)],
[d_rv],
clone=False,
features=[tt.opt.ShapeFeature()],
......@@ -392,7 +391,7 @@ def test_dirichlet_ShapeFeature():
d_rv = dirichlet(tt.ones((M_tt, N_tt)), name="Gamma")
fg = FunctionGraph(
[i for i in tt_inputs([d_rv]) if not isinstance(i, tt.Constant)],
[i for i in graph_inputs([d_rv]) if not isinstance(i, tt.Constant)],
[d_rv],
clone=False,
features=[tt.opt.ShapeFeature()],
......@@ -400,8 +399,8 @@ def test_dirichlet_ShapeFeature():
s1, s2 = fg.shape_feature.shape_of[d_rv]
assert M_tt in tt_inputs([s1])
assert N_tt in tt_inputs([s2])
assert M_tt in graph_inputs([s1])
assert N_tt in graph_inputs([s2])
def test_poisson_samples():
......
......@@ -22,7 +22,7 @@ def grad_sources_inputs(sources, inputs):
the new interface so the tests don't need to be rewritten.
"""
if inputs is None:
inputs = list(theano.gof.graph.inputs([source[0] for source in sources]))
inputs = list(theano.gof.graph.graph_inputs([source[0] for source in sources]))
return dict(
zip(
inputs,
......
......@@ -339,7 +339,9 @@ class OpFromGraph(Op):
# To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = [
var for var in gof.graph.inputs(outputs) if isinstance(var, SharedVariable)
var
for var in gof.graph.graph_inputs(outputs)
if isinstance(var, SharedVariable)
]
shared_vars = [var.type() for var in self.shared_inputs]
......
......@@ -2416,7 +2416,7 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
outputs = [self.wrap_out(o) for o in outputs]
_inputs = list(
gof.graph.inputs(
gof.graph.graph_inputs(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
)
......
......@@ -1206,7 +1206,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
}
# We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = list(gof.graph.inputs(fgraph.outputs))
all_graph_inputs = list(gof.graph.graph_inputs(fgraph.outputs))
has_destroyers_attr = hasattr(fgraph, "has_destroyers")
for i in range(len(fgraph.outputs)):
......@@ -1454,10 +1454,18 @@ class FunctionMaker:
t2 = f2.outputs[i]
givens = dict(
zip(gof.graph.inputs([t1]), gof.graph.inputs([t2]))
zip(
gof.graph.graph_inputs([t1]),
gof.graph.graph_inputs([t2]),
)
)
temp = dict(zip(gof.graph.inputs([t1]), gof.graph.inputs([t2])))
temp = dict(
zip(
gof.graph.graph_inputs([t1]),
gof.graph.graph_inputs([t2]),
)
)
# hack to remove inconsistent entries in givens
# seems to work like that, but the source of the inconsistency
......@@ -1554,7 +1562,7 @@ class FunctionMaker:
inputs = [self.wrap_in(i) for i in inputs]
outputs = [self.wrap_out(o) for o in outputs]
_inputs = list(
gof.graph.inputs(
gof.graph.graph_inputs(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
)
......
......@@ -11,7 +11,7 @@ import theano
from theano.compile import Function, builders
from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Constant, Variable
from theano.gof.graph import inputs as graph_inputs
from theano.gof.graph import graph_inputs as graph_inputs
from theano.printing import pydot_imported, pydot_imported_msg
......
......@@ -757,7 +757,7 @@ def ancestors(
yield from walk(graphs, expand, False)
def inputs(
def graph_inputs(
graphs: Iterable[Variable], blockers: Collection[Variable] = None
) -> Generator[Variable, None, None]:
"""Return the inputs required to compute the given Variables.
......
......@@ -10,7 +10,7 @@ import numpy as np
import theano
from theano.configdefaults import config
from theano.gof.graph import equal_computations, inputs, io_toposort, vars_between
from theano.gof.graph import equal_computations, graph_inputs, io_toposort, vars_between
class AlreadyThere(Exception):
......@@ -807,10 +807,10 @@ def is_same_graph_with_merge(var1, var2, givens=None):
vars = copied[0:2]
givens = copied[2]
# Create FunctionGraph.
graph_inputs = list(inputs(vars))
inputs = list(graph_inputs(vars))
# The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens.
fgraph = theano.gof.fg.FunctionGraph(graph_inputs, vars, clone=False)
fgraph = theano.gof.fg.FunctionGraph(inputs, vars, clone=False)
# Perform Variable substitution.
for to_replace, replace_by in givens.items():
fgraph.replace(to_replace, replace_by)
......@@ -893,7 +893,7 @@ def is_same_graph(var1, var2, givens=None):
in_xs = []
in_ys = []
# Compute the sets of all variables found in each computational graph.
inputs_var = list(map(inputs, ([var1], [var2])))
inputs_var = list(map(graph_inputs, ([var1], [var2])))
all_vars = [
set(vars_between(v_i, v_o))
for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2]))
......
......@@ -820,7 +820,7 @@ def pydotprint(
fct = fct.outputs
assert isinstance(fct, (list, tuple))
assert all(isinstance(v, gof.Variable) for v in fct)
fct = gof.FunctionGraph(inputs=list(gof.graph.inputs(fct)), outputs=fct)
fct = gof.FunctionGraph(inputs=list(gof.graph.graph_inputs(fct)), outputs=fct)
profile = None
outputs = fct.outputs
topo = fct.toposort()
......
......@@ -803,7 +803,7 @@ def scan(
and not isinstance(x, SharedVariable)
and not isinstance(x, gof.Constant)
),
gof.graph.inputs(fake_outputs),
gof.graph.graph_inputs(fake_outputs),
)
extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs]
non_seqs += extra_inputs
......
......@@ -62,7 +62,7 @@ from theano.compile.profiling import ScanProfileStats, register_profiler_printer
from theano.configdefaults import config
from theano.gof.fg import MissingInputError
from theano.gof.graph import Apply, Variable, equal_computations
from theano.gof.graph import inputs as graph_inputs
from theano.gof.graph import graph_inputs as graph_inputs
from theano.gof.graph import io_connection_pattern
from theano.gof.op import Op, ops_with_inner_function
from theano.gof.toolbox import NoOutputFromInplace
......
......@@ -150,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
# Same for the outer graph, initialized w/ number of steps
nw_outer = [node.inputs[0]]
all_ins = list(gof.graph.inputs(op_outs))
all_ins = list(gof.graph.graph_inputs(op_outs))
for idx in range(op.n_seqs):
node_inp = node.inputs[idx + 1]
if (
......
......@@ -268,7 +268,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
return new_graph
graphs = list(graphs)
inputs_ = list(set(list(gof.graph.inputs(graphs)) + list(additional_inputs)))
inputs_ = list(set(list(gof.graph.graph_inputs(graphs)) + list(additional_inputs)))
# perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are
......@@ -280,7 +280,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
if new_input is not input_
]
graphs = clone(graphs, share_inputs=True, replace=replacements)
inputs_ = list(set(list(gof.graph.inputs(graphs)) + list(additional_inputs)))
inputs_ = list(set(list(gof.graph.graph_inputs(graphs)) + list(additional_inputs)))
fg = gof.fg.FunctionGraph(inputs_, graphs, clone=False)
......@@ -330,7 +330,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
replacements = [wrapped_replacer(o) for o in node.outputs]
# Add inputs to replacement graphs as inputs to this `fgraph`
for i in gof.graph.inputs(replacements):
for i in gof.graph.graph_inputs(replacements):
fgraph.add_input(i)
return replacements
......@@ -370,7 +370,7 @@ def _map_variables_inner(
other_inputs = []
constants = []
for input_ in gof.graph.inputs([new_graph]):
for input_ in gof.graph.graph_inputs([new_graph]):
if isinstance(input_, gof.Variable):
if isinstance(input_, gof.Constant):
constants.append(input_)
......@@ -714,7 +714,7 @@ def scan_can_remove_outs(op, out_idxs):
"""
non_removable = [o for i, o in enumerate(op.outputs) if i not in out_idxs]
required_inputs = list(gof.graph.inputs(non_removable))
required_inputs = list(gof.graph.graph_inputs(non_removable))
out_ins = []
offset = op.n_seqs
......@@ -734,7 +734,7 @@ def scan_can_remove_outs(op, out_idxs):
if out_idxs_mask[pos] and any([x in required_inputs for x in out_ins[idx]]):
# This output is required ..
out_idxs_mask[pos] = 0
required_inputs += list(gof.graph.inputs([op.outputs[idx]]))
required_inputs += list(gof.graph.graph_inputs([op.outputs[idx]]))
added = True
required_outs = [x for i, x in enumerate(out_idxs) if out_idxs_mask[i] == 0]
......@@ -900,7 +900,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
givens = OrderedDict()
for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x
allinputs = list(theano.gof.graph.inputs(outputs))
allinputs = list(theano.gof.graph.graph_inputs(outputs))
for inp in allinputs:
if isinstance(inp, theano.Constant):
givens[inp] = inp.clone()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论