提交 d70c1cc6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename theano.gof.graph.variables to vars_between

上级 6a57a3a0
......@@ -19,7 +19,7 @@ from theano.gof.graph import (
is_in_ancestors,
list_of_nodes,
orphans,
variables,
vars_between,
walk,
)
from theano.gof.op import Op
......@@ -405,7 +405,7 @@ def test_variables_and_orphans():
o2 = MyOp(r3, o1)
o2.name = "o2"
vars_res = variables([r1, r2], [o2])
vars_res = vars_between([r1, r2], [o2])
orphans_res = orphans([r1, r2], [o2])
vars_res_list = list(vars_res)
......
......@@ -1444,7 +1444,7 @@ class FunctionMaker:
):
print("loop through outputs node for both graphs")
graph_old.variables = set(
gof.graph.variables(graph_old.inputs, graph_old.outputs)
gof.graph.vars_between(graph_old.inputs, graph_old.outputs)
)
# using clone allowed to avoid a lot of errors
......@@ -1489,7 +1489,7 @@ class FunctionMaker:
# this is a brand new graph, optimize it, save it to graph_db
print("graph not found in graph_db, optimizing the graph")
self.fgraph.variables = set(
gof.graph.variables(self.fgraph.inputs, self.fgraph.outputs)
gof.graph.vars_between(self.fgraph.inputs, self.fgraph.outputs)
)
# check_integrity parameters was added to ignore
# "excess cached variables" errors. Works that way
......
......@@ -9,8 +9,7 @@ from theano.gof import toolbox, utils
from theano.gof.graph import Apply, Constant, Variable, applys_between
from theano.gof.graph import as_string as graph_as_string
from theano.gof.graph import clone as clone_graph
from theano.gof.graph import clone_get_equiv, io_toposort
from theano.gof.graph import variables as variables_between
from theano.gof.graph import clone_get_equiv, io_toposort, vars_between
from theano.gof.utils import TestValueError, get_variable_trace_string
from theano.misc.ordered_set import OrderedSet
......@@ -725,7 +724,7 @@ class FunctionGraph(utils.MetaObject):
raise Exception(
f"Inconsistent clients list {(node, i)} in {clients}"
)
variables = set(variables_between(self.inputs, self.outputs))
variables = set(vars_between(self.inputs, self.outputs))
if set(self.variables) != variables:
missing = variables.difference(self.variables)
excess = self.variables.difference(variables)
......
......@@ -781,7 +781,7 @@ def inputs(
yield from (r for r in ancestors(graphs, blockers) if r.owner is None)
def variables(
def vars_between(
ins: Collection[Variable], outs: Iterable[Variable]
) -> Generator[Variable, None, None]:
"""Extract the `Variable`s within the sub-graph between input and output nodes.
......@@ -835,7 +835,7 @@ def orphans(
[y]
"""
yield from (r for r in variables(ins, outs) if r.owner is None and r not in ins)
yield from (r for r in vars_between(ins, outs) if r.owner is None and r not in ins)
def applys_between(
......@@ -860,7 +860,7 @@ def applys_between(
"""
yield from (
r.owner for r in variables(ins, outs) if r not in ins and r.owner is not None
r.owner for r in vars_between(ins, outs) if r not in ins and r.owner is not None
)
......
......@@ -10,7 +10,7 @@ import numpy as np
import theano
from theano.configdefaults import config
from theano.gof.graph import equal_computations, inputs, io_toposort, variables
from theano.gof.graph import equal_computations, inputs, io_toposort, vars_between
class AlreadyThere(Exception):
......@@ -895,7 +895,7 @@ def is_same_graph(var1, var2, givens=None):
# Compute the sets of all variables found in each computational graph.
inputs_var = list(map(inputs, ([var1], [var2])))
all_vars = [
set(variables(v_i, v_o))
set(vars_between(v_i, v_o))
for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2]))
]
......
......@@ -14,8 +14,7 @@ import numpy as np
from theano.compile.compilelock import lock_ctx
from theano.configdefaults import config
from theano.gof.callcache import CallCache
from theano.gof.graph import Constant, NoParams, io_toposort
from theano.gof.graph import variables as get_variables
from theano.gof.graph import Constant, NoParams, io_toposort, vars_between
from theano.gof.utils import MethodNotDefined
from theano.link.basic import Container, Linker, LocalLinker, PerformLinker
from theano.link.c.cmodule import (
......@@ -637,7 +636,7 @@ class CLinker(Linker):
# We need to include the unused inputs in our variables,
# otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(fgraph.clients[var])]
self.variables += list(get_variables(self.inputs, self.outputs))
self.variables += list(vars_between(self.inputs, self.outputs))
# This adds a hidden input which is the params for each node
# that needs it
......
......@@ -827,7 +827,7 @@ def local_abstract_batch_norm_train(fgraph, node):
for (r, r_orig) in zip(results, node.outputs)
]
for var in theano.gof.graph.variables(node.inputs, results):
for var in theano.gof.graph.vars_between(node.inputs, results):
if var not in node.inputs:
copy_stack_trace(node.outputs[0], var)
return results
......@@ -866,7 +866,7 @@ def local_abstract_batch_norm_train_grad(fgraph, node):
for (r, r_orig) in zip(results, node.outputs)
]
for var in theano.gof.graph.variables(node.inputs, results):
for var in theano.gof.graph.vars_between(node.inputs, results):
if var not in node.inputs:
copy_stack_trace(node.outputs[0], var)
return results
......@@ -898,7 +898,7 @@ def local_abstract_batch_norm_inference(fgraph, node):
) + bias
result = tt.patternbroadcast(result, node.outputs[0].broadcastable)
for var in theano.gof.graph.variables(node.inputs, [result]):
for var in theano.gof.graph.vars_between(node.inputs, [result]):
if var not in node.inputs:
copy_stack_trace(node.outputs[0], var)
return [result]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论