提交 6a57a3a0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Rename theano.gof.graph.ops to applys_between

上级 af03b72f
...@@ -9,6 +9,7 @@ from theano.gof.graph import ( ...@@ -9,6 +9,7 @@ from theano.gof.graph import (
Apply, Apply,
Variable, Variable,
ancestors, ancestors,
applys_between,
as_string, as_string,
clone, clone,
equal_computations, equal_computations,
...@@ -17,7 +18,6 @@ from theano.gof.graph import ( ...@@ -17,7 +18,6 @@ from theano.gof.graph import (
io_toposort, io_toposort,
is_in_ancestors, is_in_ancestors,
list_of_nodes, list_of_nodes,
ops,
orphans, orphans,
variables, variables,
walk, walk,
...@@ -424,7 +424,7 @@ def test_ops(): ...@@ -424,7 +424,7 @@ def test_ops():
o3 = MyOp(r3, o1, o2) o3 = MyOp(r3, o1, o2)
o3.name = "o3" o3.name = "o3"
res = ops([r1, r2], [o3]) res = applys_between([r1, r2], [o3])
res_list = list(res) res_list = list(res)
assert res_list == [o3.owner, o2.owner, o1.owner] assert res_list == [o3.owner, o2.owner, o1.owner]
......
...@@ -1349,7 +1349,7 @@ def test_grad_useless_sum(): ...@@ -1349,7 +1349,7 @@ def test_grad_useless_sum():
TensorType.values_eq_approx = old_values_eq_approx TensorType.values_eq_approx = old_values_eq_approx
assert not any( assert not any(
[isinstance(node.op, Sum) for node in theano.gof.graph.ops([x], [g])] [isinstance(node.op, Sum) for node in theano.gof.graph.applys_between([x], [g])]
) )
assert np.allclose( assert np.allclose(
outputs, [[-3.72007598e-44], [-0.26894142], [-0.5], [-0.73105858], [-1.0]] outputs, [[-3.72007598e-44], [-0.26894142], [-0.5], [-0.73105858], [-1.0]]
......
...@@ -6,7 +6,7 @@ from tests import unittest_tools as utt ...@@ -6,7 +6,7 @@ from tests import unittest_tools as utt
from theano import function from theano import function
from theano import tensor as tt from theano import tensor as tt
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof.graph import ops as graph_ops from theano.gof.graph import applys_between
from theano.tensor.extra_ops import ( from theano.tensor.extra_ops import (
Bartlett, Bartlett,
BroadcastTo, BroadcastTo,
...@@ -1221,7 +1221,8 @@ def test_broadcast_shape(): ...@@ -1221,7 +1221,8 @@ def test_broadcast_shape():
arrays_are_shapes=True, arrays_are_shapes=True,
) )
assert any( assert any(
isinstance(node.op, tt.opt.Assert) for node in graph_ops([x_tt, y_tt], b_tt) isinstance(node.op, tt.opt.Assert)
for node in applys_between([x_tt, y_tt], b_tt)
) )
assert np.array_equal([z.eval() for z in b_tt], b.shape) assert np.array_equal([z.eval() for z in b_tt], b.shape)
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True) b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True)
...@@ -1229,7 +1230,8 @@ def test_broadcast_shape(): ...@@ -1229,7 +1230,8 @@ def test_broadcast_shape():
# These are all constants, so there shouldn't be any asserts in the # These are all constants, so there shouldn't be any asserts in the
# resulting graph. # resulting graph.
assert not any( assert not any(
isinstance(node.op, tt.opt.Assert) for node in graph_ops([x_tt, y_tt], b_tt) isinstance(node.op, tt.opt.Assert)
for node in applys_between([x_tt, y_tt], b_tt)
) )
x = np.array([1, 2, 3]) x = np.array([1, 2, 3])
...@@ -1257,7 +1259,8 @@ def test_broadcast_shape(): ...@@ -1257,7 +1259,8 @@ def test_broadcast_shape():
assert b_tt[0].value == 1 assert b_tt[0].value == 1
assert np.array_equal([z.eval() for z in b_tt], b.shape) assert np.array_equal([z.eval() for z in b_tt], b.shape)
assert not any( assert not any(
isinstance(node.op, tt.opt.Assert) for node in graph_ops([x_tt, y_tt], b_tt) isinstance(node.op, tt.opt.Assert)
for node in applys_between([x_tt, y_tt], b_tt)
) )
b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True) b_tt = broadcast_shape(shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_tt], b.shape) assert np.array_equal([z.eval() for z in b_tt], b.shape)
......
...@@ -6,11 +6,10 @@ from io import StringIO ...@@ -6,11 +6,10 @@ from io import StringIO
import theano import theano
from theano.configdefaults import config from theano.configdefaults import config
from theano.gof import toolbox, utils from theano.gof import toolbox, utils
from theano.gof.graph import Apply, Constant, Variable from theano.gof.graph import Apply, Constant, Variable, applys_between
from theano.gof.graph import as_string as graph_as_string from theano.gof.graph import as_string as graph_as_string
from theano.gof.graph import clone as clone_graph from theano.gof.graph import clone as clone_graph
from theano.gof.graph import clone_get_equiv, io_toposort from theano.gof.graph import clone_get_equiv, io_toposort
from theano.gof.graph import ops as ops_between
from theano.gof.graph import variables as variables_between from theano.gof.graph import variables as variables_between
from theano.gof.utils import TestValueError, get_variable_trace_string from theano.gof.utils import TestValueError, get_variable_trace_string
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
...@@ -710,7 +709,7 @@ class FunctionGraph(utils.MetaObject): ...@@ -710,7 +709,7 @@ class FunctionGraph(utils.MetaObject):
Call this for a diagnosis if things go awry. Call this for a diagnosis if things go awry.
""" """
nodes = set(ops_between(self.inputs, self.outputs)) nodes = set(applys_between(self.inputs, self.outputs))
if self.apply_nodes != nodes: if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes) missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes) excess = self.apply_nodes.difference(nodes)
......
...@@ -838,7 +838,7 @@ def orphans( ...@@ -838,7 +838,7 @@ def orphans(
yield from (r for r in variables(ins, outs) if r.owner is None and r not in ins) yield from (r for r in variables(ins, outs) if r.owner is None and r not in ins)
def ops( def applys_between(
ins: Collection[Variable], outs: Iterable[Variable] ins: Collection[Variable], outs: Iterable[Variable]
) -> Generator[Apply, None, None]: ) -> Generator[Apply, None, None]:
"""Extract the `Apply`s contained within the sub-graph between given input and output variables. """Extract the `Apply`s contained within the sub-graph between given input and output variables.
...@@ -1291,7 +1291,7 @@ def as_string( ...@@ -1291,7 +1291,7 @@ def as_string(
multi.add(op) multi.add(op)
else: else:
seen.add(op) seen.add(op)
for op in ops(i, outputs): for op in applys_between(i, outputs):
for input in op.inputs: for input in op.inputs:
op2 = input.owner op2 = input.owner
if input in i or input in orph or op2 is None: if input in i or input in orph or op2 is None:
......
...@@ -1341,7 +1341,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1341,7 +1341,7 @@ class LocalOptGroup(LocalOptimizer):
new_vars = list(new_repl.values()) new_vars = list(new_repl.values())
if self.profile: if self.profile:
self.node_created[opt] += len( self.node_created[opt] += len(
list(graph.ops(fgraph.variables, new_vars)) list(graph.applys_between(fgraph.variables, new_vars))
) )
self.applied_true[opt] += 1 self.applied_true[opt] += 1
break # break from the for loop over optimization. break # break from the for loop over optimization.
...@@ -1453,7 +1453,7 @@ class GraphToGPULocalOptGroup(LocalOptGroup): ...@@ -1453,7 +1453,7 @@ class GraphToGPULocalOptGroup(LocalOptGroup):
continue continue
if self.profile: if self.profile:
self.node_created[opt] += len( self.node_created[opt] += len(
list(graph.ops(fgraph.variables, new_repl)) list(graph.applys_between(fgraph.variables, new_repl))
) )
self.applied_true[opt] += 1 self.applied_true[opt] += 1
......
...@@ -397,7 +397,7 @@ class GraphToGPU(GlobalOptimizer): ...@@ -397,7 +397,7 @@ class GraphToGPU(GlobalOptimizer):
if new_ops: if new_ops:
node_created[lopt] += len( node_created[lopt] += len(
graph.ops([mapping[i] for i in node.inputs], outputs) graph.applys_between([mapping[i] for i in node.inputs], outputs)
) )
if any( if any(
[ [
......
...@@ -4154,9 +4154,6 @@ class Composite(ScalarOp): ...@@ -4154,9 +4154,6 @@ class Composite(ScalarOp):
assert len(res[0]) == len(inputs) assert len(res[0]) == len(inputs)
assert res[0] != inputs assert res[0] != inputs
inputs, outputs = res[0], res2[1] inputs, outputs = res[0], res2[1]
# Next assert comment just for speed
# assert not any([isinstance(node.op, Composite) for node in
# ops(inputs, outputs)])
self.inputs = copy(inputs) self.inputs = copy(inputs)
self.outputs = copy(outputs) self.outputs = copy(outputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论