提交 d70c1cc6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename theano.gof.graph.variables to vars_between

上级 6a57a3a0
...@@ -19,7 +19,7 @@ from theano.gof.graph import ( ...@@ -19,7 +19,7 @@ from theano.gof.graph import (
is_in_ancestors, is_in_ancestors,
list_of_nodes, list_of_nodes,
orphans, orphans,
variables, vars_between,
walk, walk,
) )
from theano.gof.op import Op from theano.gof.op import Op
...@@ -405,7 +405,7 @@ def test_variables_and_orphans(): ...@@ -405,7 +405,7 @@ def test_variables_and_orphans():
o2 = MyOp(r3, o1) o2 = MyOp(r3, o1)
o2.name = "o2" o2.name = "o2"
vars_res = variables([r1, r2], [o2]) vars_res = vars_between([r1, r2], [o2])
orphans_res = orphans([r1, r2], [o2]) orphans_res = orphans([r1, r2], [o2])
vars_res_list = list(vars_res) vars_res_list = list(vars_res)
......
...@@ -1444,7 +1444,7 @@ class FunctionMaker: ...@@ -1444,7 +1444,7 @@ class FunctionMaker:
): ):
print("loop through outputs node for both graphs") print("loop through outputs node for both graphs")
graph_old.variables = set( graph_old.variables = set(
gof.graph.variables(graph_old.inputs, graph_old.outputs) gof.graph.vars_between(graph_old.inputs, graph_old.outputs)
) )
# using clone allowed to avoid a lot of errors # using clone allowed to avoid a lot of errors
...@@ -1489,7 +1489,7 @@ class FunctionMaker: ...@@ -1489,7 +1489,7 @@ class FunctionMaker:
# this is a brand new graph, optimize it, save it to graph_db # this is a brand new graph, optimize it, save it to graph_db
print("graph not found in graph_db, optimizing the graph") print("graph not found in graph_db, optimizing the graph")
self.fgraph.variables = set( self.fgraph.variables = set(
gof.graph.variables(self.fgraph.inputs, self.fgraph.outputs) gof.graph.vars_between(self.fgraph.inputs, self.fgraph.outputs)
) )
# check_integrity parameters was added to ignore # check_integrity parameters was added to ignore
# "excess cached variables" errors. Works that way # "excess cached variables" errors. Works that way
......
...@@ -9,8 +9,7 @@ from theano.gof import toolbox, utils ...@@ -9,8 +9,7 @@ from theano.gof import toolbox, utils
from theano.gof.graph import Apply, Constant, Variable, applys_between from theano.gof.graph import Apply, Constant, Variable, applys_between
from theano.gof.graph import as_string as graph_as_string from theano.gof.graph import as_string as graph_as_string
from theano.gof.graph import clone as clone_graph from theano.gof.graph import clone as clone_graph
from theano.gof.graph import clone_get_equiv, io_toposort from theano.gof.graph import clone_get_equiv, io_toposort, vars_between
from theano.gof.graph import variables as variables_between
from theano.gof.utils import TestValueError, get_variable_trace_string from theano.gof.utils import TestValueError, get_variable_trace_string
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
...@@ -725,7 +724,7 @@ class FunctionGraph(utils.MetaObject): ...@@ -725,7 +724,7 @@ class FunctionGraph(utils.MetaObject):
raise Exception( raise Exception(
f"Inconsistent clients list {(node, i)} in {clients}" f"Inconsistent clients list {(node, i)} in {clients}"
) )
variables = set(variables_between(self.inputs, self.outputs)) variables = set(vars_between(self.inputs, self.outputs))
if set(self.variables) != variables: if set(self.variables) != variables:
missing = variables.difference(self.variables) missing = variables.difference(self.variables)
excess = self.variables.difference(variables) excess = self.variables.difference(variables)
......
...@@ -781,7 +781,7 @@ def inputs( ...@@ -781,7 +781,7 @@ def inputs(
yield from (r for r in ancestors(graphs, blockers) if r.owner is None) yield from (r for r in ancestors(graphs, blockers) if r.owner is None)
def variables( def vars_between(
ins: Collection[Variable], outs: Iterable[Variable] ins: Collection[Variable], outs: Iterable[Variable]
) -> Generator[Variable, None, None]: ) -> Generator[Variable, None, None]:
"""Extract the `Variable`s within the sub-graph between input and output nodes. """Extract the `Variable`s within the sub-graph between input and output nodes.
...@@ -835,7 +835,7 @@ def orphans( ...@@ -835,7 +835,7 @@ def orphans(
[y] [y]
""" """
yield from (r for r in variables(ins, outs) if r.owner is None and r not in ins) yield from (r for r in vars_between(ins, outs) if r.owner is None and r not in ins)
def applys_between( def applys_between(
...@@ -860,7 +860,7 @@ def applys_between( ...@@ -860,7 +860,7 @@ def applys_between(
""" """
yield from ( yield from (
r.owner for r in variables(ins, outs) if r not in ins and r.owner is not None r.owner for r in vars_between(ins, outs) if r not in ins and r.owner is not None
) )
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
import theano import theano
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.graph import equal_computations, inputs, io_toposort, variables from theano.gof.graph import equal_computations, inputs, io_toposort, vars_between
class AlreadyThere(Exception): class AlreadyThere(Exception):
...@@ -895,7 +895,7 @@ def is_same_graph(var1, var2, givens=None): ...@@ -895,7 +895,7 @@ def is_same_graph(var1, var2, givens=None):
# Compute the sets of all variables found in each computational graph. # Compute the sets of all variables found in each computational graph.
inputs_var = list(map(inputs, ([var1], [var2]))) inputs_var = list(map(inputs, ([var1], [var2])))
all_vars = [ all_vars = [
set(variables(v_i, v_o)) set(vars_between(v_i, v_o))
for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2])) for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2]))
] ]
......
...@@ -14,8 +14,7 @@ import numpy as np ...@@ -14,8 +14,7 @@ import numpy as np
from theano.compile.compilelock import lock_ctx from theano.compile.compilelock import lock_ctx
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.callcache import CallCache from theano.gof.callcache import CallCache
from theano.gof.graph import Constant, NoParams, io_toposort from theano.gof.graph import Constant, NoParams, io_toposort, vars_between
from theano.gof.graph import variables as get_variables
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.link.basic import Container, Linker, LocalLinker, PerformLinker from theano.link.basic import Container, Linker, LocalLinker, PerformLinker
from theano.link.c.cmodule import ( from theano.link.c.cmodule import (
...@@ -637,7 +636,7 @@ class CLinker(Linker): ...@@ -637,7 +636,7 @@ class CLinker(Linker):
# We need to include the unused inputs in our variables, # We need to include the unused inputs in our variables,
# otherwise we can't pass them to the module. # otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(fgraph.clients[var])] self.variables = [var for var in self.inputs if not len(fgraph.clients[var])]
self.variables += list(get_variables(self.inputs, self.outputs)) self.variables += list(vars_between(self.inputs, self.outputs))
# This adds a hidden input which is the params for each node # This adds a hidden input which is the params for each node
# that needs it # that needs it
......
...@@ -827,7 +827,7 @@ def local_abstract_batch_norm_train(fgraph, node): ...@@ -827,7 +827,7 @@ def local_abstract_batch_norm_train(fgraph, node):
for (r, r_orig) in zip(results, node.outputs) for (r, r_orig) in zip(results, node.outputs)
] ]
for var in theano.gof.graph.variables(node.inputs, results): for var in theano.gof.graph.vars_between(node.inputs, results):
if var not in node.inputs: if var not in node.inputs:
copy_stack_trace(node.outputs[0], var) copy_stack_trace(node.outputs[0], var)
return results return results
...@@ -866,7 +866,7 @@ def local_abstract_batch_norm_train_grad(fgraph, node): ...@@ -866,7 +866,7 @@ def local_abstract_batch_norm_train_grad(fgraph, node):
for (r, r_orig) in zip(results, node.outputs) for (r, r_orig) in zip(results, node.outputs)
] ]
for var in theano.gof.graph.variables(node.inputs, results): for var in theano.gof.graph.vars_between(node.inputs, results):
if var not in node.inputs: if var not in node.inputs:
copy_stack_trace(node.outputs[0], var) copy_stack_trace(node.outputs[0], var)
return results return results
...@@ -898,7 +898,7 @@ def local_abstract_batch_norm_inference(fgraph, node): ...@@ -898,7 +898,7 @@ def local_abstract_batch_norm_inference(fgraph, node):
) + bias ) + bias
result = tt.patternbroadcast(result, node.outputs[0].broadcastable) result = tt.patternbroadcast(result, node.outputs[0].broadcastable)
for var in theano.gof.graph.variables(node.inputs, [result]): for var in theano.gof.graph.vars_between(node.inputs, [result]):
if var not in node.inputs: if var not in node.inputs:
copy_stack_trace(node.outputs[0], var) copy_stack_trace(node.outputs[0], var)
return [result] return [result]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论