Commit 1ff98d25 authored by Brandon T. Willard

Move merge-based graph functions and equal_computations

The merge-based functions were moved to `toolbox` to reduce unnecessary cross-module dependencies (especially among core modules). `equal_computations` was moved from `scan_module` because it provides a basic graph object `__eq__` implementation and it has multiple references outside of its own module/sub-package.
上级 f14d4799
...@@ -14,8 +14,8 @@ from theano.gof.graph import ( ...@@ -14,8 +14,8 @@ from theano.gof.graph import (
general_toposort, general_toposort,
inputs, inputs,
io_toposort, io_toposort,
is_same_graph,
Variable, Variable,
equal_computations,
) )
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
...@@ -241,144 +241,6 @@ class TestToposort: ...@@ -241,144 +241,6 @@ class TestToposort:
assert all == [o0] assert all == [o0]
class TestIsSameGraph:
    """Exercise `is_same_graph` on trivial, full, and merge-only graphs."""

    def check(self, expected, debug=True):
        """
        Core function to perform comparison.

        :param expected: A list of tuples (v1, v2, ((g1, o1), ..., (gN, oN)))
        with:
        - `v1` and `v2` two Variables (the graphs to be compared)
        - `gj` a `givens` dictionary to give as input to `is_same_graph`
        - `oj` the expected output of `is_same_graph(v1, v2, givens=gj)`
        :param debug: If True, then we make sure we are testing both
        implementations of `is_same_graph`.

        This function also tries to call `is_same_graph` by inverting `v1` and
        `v2`, and ensures the output remains the same.
        """
        for var1, var2, cases in expected:
            for substitution, should_match in cases:
                # `is_same_graph` is expected to be symmetric in its first
                # two arguments, so check both orderings.
                for lhs, rhs in ((var1, var2), (var2, var1)):
                    result = is_same_graph(lhs, rhs, givens=substitution, debug=debug)
                    assert result == should_match

    def test_single_var(self):
        # Test `is_same_graph` with some trivial graphs (one Variable).
        x, y, z = tensor.vectors("x", "y", "z")
        self.check(
            [
                (x, x, (({}, True),)),
                (x, y, (({}, False), ({y: x}, True))),
                (x, tensor.neg(x), (({}, False),)),
                (x, tensor.neg(y), (({}, False),)),
            ]
        )

    def test_full_graph(self):
        # Test `is_same_graph` with more complex graphs.
        x, y, z = tensor.vectors("x", "y", "z")
        t = x * y
        self.check(
            [
                (x * 2, x * 2, (({}, True),)),
                (x * 2, y * 2, (({}, False), ({y: x}, True))),
                (x * 2, y * 2, (({}, False), ({x: y}, True))),
                (x * 2, y * 3, (({}, False), ({y: x}, False))),
                (t * 2, z * 2, (({}, False), ({t: z}, True))),
                (t * 2, z * 2, (({}, False), ({z: t}, True))),
                # Same leaves but different association: NOT the same graph.
                (x * (y * z), (x * y) * z, (({}, False),)),
            ]
        )

    def test_merge_only(self):
        # Test `is_same_graph` when `equal_computations` cannot be used.
        x, y, z = tensor.vectors("x", "y", "z")
        t = x * y
        self.check(
            [
                (x, t, (({}, False), ({t: x}, True))),
                (t * 2, x * 2, (({}, False), ({t: x}, True))),
                (x * x, x * y, (({}, False), ({y: x}, True))),
                (x * x, x * y, (({}, False), ({y: x}, True))),
                (
                    x * x + z,
                    x * y + t,
                    (({}, False), ({y: x}, False), ({y: x, t: z}, True)),
                ),
            ],
            debug=False,
        )
class TestEval: class TestEval:
def setup_method(self): def setup_method(self):
self.x, self.y = tensor.scalars("x", "y") self.x, self.y = tensor.scalars("x", "y")
...@@ -476,3 +338,13 @@ class TestAutoName: ...@@ -476,3 +338,13 @@ class TestAutoName:
r2 = r1.clone() r2 = r1.clone()
assert r1.auto_name == "auto_" + str(autoname_id) assert r1.auto_name == "auto_" + str(autoname_id)
assert r2.auto_name == "auto_" + str(autoname_id + 1) assert r2.auto_name == "auto_" + str(autoname_id + 1)
def test_equal_computations():
    """Regression test from a user bug report.

    `equal_computations` must handle owner-less constants and
    multi-output graphs such as `max_and_argmax`.
    """
    none_const = tensor.type_other.NoneConst
    assert equal_computations([none_const], [none_const])

    m = tensor.matrix()
    assert equal_computations(tensor.max_and_argmax(m), tensor.max_and_argmax(m))
from theano import tensor
from theano.gof.graph import Variable, Apply from theano.gof.graph import Variable, Apply
from theano.gof.type import Type from theano.gof.type import Type
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.fg import FunctionGraph from theano.gof.fg import FunctionGraph
from theano.gof.toolbox import NodeFinder from theano.gof.toolbox import NodeFinder, is_same_graph
def as_variable(x): class TestNodeFinder:
assert isinstance(x, Variable) def test_straightforward(self):
return x class MyType(Type):
def __init__(self, name):
self.name = name
class MyType(Type):
def __init__(self, name):
self.name = name
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __eq__(self, other):
return isinstance(other, MyType)
def MyVariable(name): def __str__(self):
return Variable(MyType(name), None, None) return self.name
def __repr__(self):
return self.name
class MyOp(Op): def __eq__(self, other):
return isinstance(other, MyType)
__props__ = ("nin", "name") class MyOp(Op):
def __init__(self, nin, name): __props__ = ("nin", "name")
self.nin = nin
self.name = name
def make_node(self, *inputs): def __init__(self, nin, name):
assert len(inputs) == self.nin self.nin = nin
inputs = list(map(as_variable, inputs)) self.name = name
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType(self.name + "_R")()]
return Apply(self, inputs, outputs)
def __str__(self): def make_node(self, *inputs):
return self.name def as_variable(x):
assert isinstance(x, Variable)
return x
assert len(inputs) == self.nin
inputs = list(map(as_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType(self.name + "_R")()]
return Apply(self, inputs, outputs)
sigmoid = MyOp(1, "Sigmoid") def __str__(self):
add = MyOp(2, "Add") return self.name
dot = MyOp(2, "Dot")
sigmoid = MyOp(1, "Sigmoid")
add = MyOp(2, "Add")
dot = MyOp(2, "Dot")
def inputs(): def MyVariable(name):
x = MyVariable("x") return Variable(MyType(name), None, None)
y = MyVariable("y")
z = MyVariable("z")
return x, y, z
def inputs():
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
return x, y, z
class TestNodeFinder:
def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e0 = dot(y, z) e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0)) e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
...@@ -83,3 +78,137 @@ class TestNodeFinder: ...@@ -83,3 +78,137 @@ class TestNodeFinder:
for type, num in ((add, 4), (sigmoid, 3), (dot, 1)): for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
if not len([t for t in g.get_nodes(type)]) == num: if not len([t for t in g.get_nodes(type)]) == num:
raise Exception("Expected: %i times %s" % (num, type)) raise Exception("Expected: %i times %s" % (num, type))
class TestIsSameGraph:
    """Check `is_same_graph` on single-variable, full, and merge-only cases."""

    def check(self, expected):
        """
        Core function to perform comparison.

        :param expected: A list of tuples (v1, v2, ((g1, o1), ..., (gN, oN)))
        with:
        - `v1` and `v2` two Variables (the graphs to be compared)
        - `gj` a `givens` dictionary to give as input to `is_same_graph`
        - `oj` the expected output of `is_same_graph(v1, v2, givens=gj)`

        This function also tries to call `is_same_graph` by inverting `v1` and
        `v2`, and ensures the output remains the same.
        """
        for var1, var2, cases in expected:
            for substitution, should_match in cases:
                # The comparison must be symmetric: swapping the operands
                # must not change the verdict.
                assert is_same_graph(var1, var2, givens=substitution) == should_match
                assert is_same_graph(var2, var1, givens=substitution) == should_match

    def test_single_var(self):
        # Test `is_same_graph` with some trivial graphs (one Variable).
        x, y, z = tensor.vectors("x", "y", "z")
        self.check(
            [
                (x, x, (({}, True),)),
                (x, y, (({}, False), ({y: x}, True))),
                (x, tensor.neg(x), (({}, False),)),
                (x, tensor.neg(y), (({}, False),)),
            ]
        )

    def test_full_graph(self):
        # Test `is_same_graph` with more complex graphs.
        x, y, z = tensor.vectors("x", "y", "z")
        t = x * y
        self.check(
            [
                (x * 2, x * 2, (({}, True),)),
                (x * 2, y * 2, (({}, False), ({y: x}, True))),
                (x * 2, y * 2, (({}, False), ({x: y}, True))),
                (x * 2, y * 3, (({}, False), ({y: x}, False))),
                (t * 2, z * 2, (({}, False), ({t: z}, True))),
                (t * 2, z * 2, (({}, False), ({z: t}, True))),
                # Different association of the same leaves is a different graph.
                (x * (y * z), (x * y) * z, (({}, False),)),
            ]
        )

    def test_merge_only(self):
        # Test `is_same_graph` when `equal_computations` cannot be used.
        x, y, z = tensor.vectors("x", "y", "z")
        t = x * y
        self.check(
            [
                (x, t, (({}, False), ({t: x}, True))),
                (t * 2, x * 2, (({}, False), ({t: x}, True))),
                (x * x, x * y, (({}, False), ({y: x}, True))),
                (x * x, x * y, (({}, False), ({y: x}, True))),
                (
                    x * x + z,
                    x * y + t,
                    (({}, False), ({y: x}, False), ({y: x, t: z}, True)),
                ),
            ],
        )
...@@ -7,20 +7,7 @@ import numpy as np ...@@ -7,20 +7,7 @@ import numpy as np
import theano import theano
from theano import tensor from theano import tensor
from theano.scan_module.scan_utils import equal_computations, map_variables from theano.scan_module.scan_utils import map_variables
from theano.tensor.type_other import NoneConst
def test_equal_computations():
    """Check `equal_computations` on constants and multi-output graphs.

    This was a bug report by a Theano user.

    Note: the test name previously had a typo ("compuations"); fixed here.
    """
    # `NoneConst` has no owner; `equal_computations` must still compare it.
    c = NoneConst
    assert equal_computations([c], [c])

    # `max_and_argmax` returns multiple outputs; two independent builds of
    # the same expression must compare equal.
    m = theano.tensor.matrix()
    max_argmax1 = theano.tensor.max_and_argmax(m)
    max_argmax2 = theano.tensor.max_and_argmax(m)
    assert equal_computations(max_argmax1, max_argmax2)
class TestMapVariables: class TestMapVariables:
......
...@@ -4,6 +4,7 @@ import theano.tensor.inplace ...@@ -4,6 +4,7 @@ import theano.tensor.inplace
from theano import tensor as T, config from theano import tensor as T, config
from theano.tensor import basic as tensor from theano.tensor import basic as tensor
from theano.gof.opt import check_stack_trace from theano.gof.opt import check_stack_trace
from theano.gof.toolbox import is_same_graph
from theano.tensor.nnet import ( from theano.tensor.nnet import (
sigmoid, sigmoid,
sigmoid_inplace, sigmoid_inplace,
...@@ -351,9 +352,7 @@ class TestSigmoidOpts: ...@@ -351,9 +352,7 @@ class TestSigmoidOpts:
trees = [parse_mul_tree(e) for e in (expr1, expr2)] trees = [parse_mul_tree(e) for e in (expr1, expr2)]
perform_sigm_times_exp(trees[0]) perform_sigm_times_exp(trees[0])
trees[0] = simplify_mul(trees[0]) trees[0] = simplify_mul(trees[0])
good = theano.gof.graph.is_same_graph( good = is_same_graph(compute_mul(trees[0]), compute_mul(trees[1]))
compute_mul(trees[0]), compute_mul(trees[1])
)
if not good: if not good:
print(trees[0]) print(trees[0])
print(trees[1]) print(trees[1])
...@@ -541,7 +540,7 @@ class TestSigmoidUtils: ...@@ -541,7 +540,7 @@ class TestSigmoidUtils:
tree = (x * y) * -z tree = (x * y) * -z
mul_tree = parse_mul_tree(tree) mul_tree = parse_mul_tree(tree)
assert parse_mul_tree(compute_mul(mul_tree)) == mul_tree assert parse_mul_tree(compute_mul(mul_tree)) == mul_tree
assert theano.gof.graph.is_same_graph(compute_mul(parse_mul_tree(tree)), tree) assert is_same_graph(compute_mul(parse_mul_tree(tree)), tree)
def test_parse_mul_tree(self): def test_parse_mul_tree(self):
x, y, z = tensor.vectors("x", "y", "z") x, y, z = tensor.vectors("x", "y", "z")
...@@ -566,7 +565,7 @@ class TestSigmoidUtils: ...@@ -566,7 +565,7 @@ class TestSigmoidUtils:
lambda x: is_1pexp(x, only_process_constants=False), lambda x: is_1pexp(x, only_process_constants=False),
[(1 + exp(-x)), (exp(-x) + 1)], [(1 + exp(-x)), (exp(-x) + 1)],
): ):
assert not neg and theano.gof.graph.is_same_graph(exp_arg, -x) assert not neg and is_same_graph(exp_arg, -x)
assert is_1pexp(1 - exp(x), False) is None assert is_1pexp(1 - exp(x), False) is None
assert is_1pexp(2 + exp(x), False) is None assert is_1pexp(2 + exp(x), False) is None
assert is_1pexp(exp(x) + 2, False) is None assert is_1pexp(exp(x) + 2, False) is None
......
...@@ -10,7 +10,7 @@ from six import StringIO ...@@ -10,7 +10,7 @@ from six import StringIO
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
from theano import config from theano import config
from theano.compile import DeepCopyOp from theano.compile import DeepCopyOp
from theano.gof.graph import is_same_graph from theano.gof.toolbox import is_same_graph
from theano.tensor import ( from theano.tensor import (
_shared, _shared,
cscalar, cscalar,
......
...@@ -122,7 +122,7 @@ class OpFromGraph(gof.Op): ...@@ -122,7 +122,7 @@ class OpFromGraph(gof.Op):
.. TODO: .. TODO:
- examples for a multi-layer mlp. where? - examples for a multi-layer mlp. where?
- __hash__, __eq__ otherwise won't merge, try - __hash__, __eq__ otherwise won't merge, try
gof.opt.is_same_graph_with_merge(op1.local_outputs, op2, is_same_graph_with_merge(op1.local_outputs, op2,
local_outputs) local_outputs)
- c_code() to remove the double overhead? - c_code() to remove the double overhead?
- grad() make it support DisconnectedType and the new interface - grad() make it support DisconnectedType and the new interface
......
...@@ -24,7 +24,7 @@ from theano import config, gof ...@@ -24,7 +24,7 @@ from theano import config, gof
from theano.gof import graph from theano.gof import graph
from theano.compile.io import In, SymbolicInput, SymbolicOutput from theano.compile.io import In, SymbolicInput, SymbolicOutput
from theano.compile.ops import deep_copy_op, view_op from theano.compile.ops import deep_copy_op, view_op
from theano.gof.graph import is_same_graph from theano.gof.toolbox import is_same_graph
from theano.gof.op import ops_with_inner_function from theano.gof.op import ops_with_inner_function
_logger = logging.getLogger("theano.compile.function_module") _logger = logging.getLogger("theano.compile.function_module")
......
差异被折叠。
...@@ -1060,45 +1060,6 @@ class MergeOptimizer(Optimizer): ...@@ -1060,45 +1060,6 @@ class MergeOptimizer(Optimizer):
) )
def is_same_graph_with_merge(var1, var2, givens=None):
    """
    Merge-based implementation of `theano.gof.graph.is_same_graph`.

    See help on `theano.gof.graph.is_same_graph` for additional documentation.

    Parameters
    ----------
    var1, var2
        The two Variables whose computation graphs are compared.
    givens : dict, optional
        Substitutions to apply to both graphs before comparison.

    Returns
    -------
    bool
        True iff the two (substituted) graphs merge into a single one.
    """
    if givens is None:
        givens = {}
    # Copy variables since the MergeOptimizer will modify them.
    # (`graph_vars` instead of `vars`, which would shadow the builtin.)
    copied = copy.deepcopy([var1, var2, givens])
    graph_vars = copied[0:2]
    givens = copied[2]
    # Create FunctionGraph.
    inputs = theano.gof.graph.inputs(graph_vars)
    # The clone isn't needed as we did a deepcopy and the cloning would
    # break the mapping in givens.
    fgraph = theano.gof.fg.FunctionGraph(inputs, graph_vars, clone=False)
    # Perform Variable substitution.
    for to_replace, replace_by in givens.items():
        fgraph.replace(to_replace, replace_by)
    # Perform merge optimization.
    MergeOptimizer().optimize(fgraph)
    # When two variables perform the same computations, they will have the same
    # owner in the optimized graph.
    # We need to be careful with the special case where the owner is None,
    # which happens when the graph is made of a single Variable.
    # We also need to make sure we replace a Variable if it is present in
    # `givens`.
    vars_replaced = [givens.get(v, v) for v in graph_vars]
    o1, o2 = [v.owner for v in vars_replaced]
    if o1 is None and o2 is None:
        # Comparing two single-Variable graphs: they are equal if they are
        # the same Variable.
        return vars_replaced[0] == vars_replaced[1]
    else:
        return o1 is o2
def pre_constant_merge(vars): def pre_constant_merge(vars):
""" """
Merge constants in the subgraph used to compute nodes in `vars`. Merge constants in the subgraph used to compute nodes in `vars`.
......
...@@ -13,7 +13,10 @@ from six.moves import StringIO ...@@ -13,7 +13,10 @@ from six.moves import StringIO
from theano import config from theano import config
from theano.gof.graph import ( from theano.gof.graph import (
inputs,
io_toposort, io_toposort,
equal_computations,
variables,
) )
...@@ -805,3 +808,154 @@ class NoOutputFromInplace(Feature): ...@@ -805,3 +808,154 @@ class NoOutputFromInplace(Feature):
"being computed by modifying another variable ", "being computed by modifying another variable ",
"inplace.", "inplace.",
) )
def is_same_graph_with_merge(var1, var2, givens=None):
    """
    Merge-based implementation of `is_same_graph`.

    See help on `is_same_graph` for additional documentation.

    Parameters
    ----------
    var1, var2
        The two Variables whose computation graphs are compared.
    givens : dict, optional
        Substitutions to apply to both graphs before comparison.

    Returns
    -------
    bool
        True iff the two (substituted) graphs merge into a single one.
    """
    # Imported here to avoid a circular import at module load time.
    from theano.gof.opt import MergeOptimizer

    if givens is None:
        givens = {}
    # Copy variables since the MergeOptimizer will modify them.
    # (`merge_vars` instead of `vars`, which would shadow the builtin.)
    copied = copy.deepcopy([var1, var2, givens])
    merge_vars = copied[0:2]
    givens = copied[2]
    # Create FunctionGraph.
    graph_inputs = inputs(merge_vars)
    # The clone isn't needed as we did a deepcopy and the cloning would
    # break the mapping in givens.
    fgraph = theano.gof.fg.FunctionGraph(graph_inputs, merge_vars, clone=False)
    # Perform Variable substitution.
    for to_replace, replace_by in givens.items():
        fgraph.replace(to_replace, replace_by)
    # Perform merge optimization.
    MergeOptimizer().optimize(fgraph)
    # When two variables perform the same computations, they will have the same
    # owner in the optimized graph.
    # We need to be careful with the special case where the owner is None,
    # which happens when the graph is made of a single Variable.
    # We also need to make sure we replace a Variable if it is present in
    # `givens`.
    vars_replaced = [givens.get(v, v) for v in merge_vars]
    o1, o2 = [v.owner for v in vars_replaced]
    if o1 is None and o2 is None:
        # Comparing two single-Variable graphs: they are equal if they are
        # the same Variable.
        return vars_replaced[0] == vars_replaced[1]
    else:
        return o1 is o2
def is_same_graph(var1, var2, givens=None):
    """
    Return True iff Variables `var1` and `var2` perform the same computation.

    By 'performing the same computation', we mean that they must share the same
    graph, so that for instance this function will return False when comparing
    (x * (y * z)) with ((x * y) * z).

    The current implementation is not efficient since, when possible, it
    verifies equality by calling two different functions that are expected to
    return the same output. The goal is to verify this assumption, to
    eventually get rid of one of them in the future.

    Parameters
    ----------
    var1
        The first Variable to compare.
    var2
        The second Variable to compare.
    givens
        Similar to the `givens` argument of `theano.function`, it can be used
        to perform substitutions in the computational graph of `var1` and
        `var2`. This argument is associated to neither `var1` nor `var2`:
        substitutions may affect both graphs if the substituted variable
        is present in both.

    Examples
    --------

        ======  ======  ======  ======
        var1    var2    givens  output
        ======  ======  ======  ======
        x + 1   x + 1   {}      True
        x + 1   y + 1   {}      False
        x + 1   y + 1   {x: y}  True
        ======  ======  ======  ======

    """
    use_equal_computations = True

    if givens is None:
        givens = {}

    if not isinstance(givens, dict):
        givens = dict(givens)

    # Get result from the merge-based function.
    rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens)

    if givens:
        # We need to build the `in_xs` and `in_ys` lists. To do this, we need
        # to be able to tell whether a variable belongs to the computational
        # graph of `var1` or `var2`.
        # The typical case we want to handle is when `to_replace` belongs to
        # one of these graphs, and `replace_by` belongs to the other one. In
        # other situations, the current implementation of `equal_computations`
        # is probably not appropriate, so we do not call it.
        ok = True
        in_xs = []
        in_ys = []
        # Compute the sets of all variables found in each computational graph.
        # `all_vars[0]` is the variable set of `var1`, `all_vars[1]` of `var2`.
        all_vars = [set(variables(inputs([v]), [v])) for v in (var1, var2)]

        for to_replace, replace_by in givens.items():
            # Map each substitution variable to a pair of booleans telling
            # whether it appears in the graph of `var1` and of `var2`.
            inside = {
                v: [v in all_vars[0], v in all_vars[1]]
                for v in (to_replace, replace_by)
            }
            if (
                inside[to_replace][0]
                and not inside[to_replace][1]
                and inside[replace_by][1]
                and not inside[replace_by][0]
            ):
                # Substitute variable in `var1` by one from `var2`.
                in_xs.append(to_replace)
                in_ys.append(replace_by)
            elif (
                inside[to_replace][1]
                and not inside[to_replace][0]
                and inside[replace_by][0]
                and not inside[replace_by][1]
            ):
                # Substitute variable in `var2` by one from `var1`.
                in_xs.append(replace_by)
                in_ys.append(to_replace)
            else:
                # Ambiguous substitution (e.g. variable present in both
                # graphs): `equal_computations` cannot be used safely.
                ok = False
                break
        if not ok:
            # We cannot directly use `equal_computations`.
            use_equal_computations = False
    else:
        in_xs = None
        in_ys = None
    if use_equal_computations:
        rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys)
        # Both implementations are expected to agree; this assertion is the
        # whole point of running them side by side.
        assert rval2 == rval1
    return rval1
...@@ -65,7 +65,7 @@ from theano.compile import function, In, Out ...@@ -65,7 +65,7 @@ from theano.compile import function, In, Out
from theano.compile.mode import AddFeatureOptimizer from theano.compile.mode import AddFeatureOptimizer
from theano import compile, config, gradient, gof, tensor from theano import compile, config, gradient, gof, tensor
from theano.gof import PureOp, Apply from theano.gof import PureOp, Apply
from theano.gof.graph import io_connection_pattern from theano.gof.graph import io_connection_pattern, equal_computations
from theano.gof.toolbox import NoOutputFromInplace from theano.gof.toolbox import NoOutputFromInplace
from theano.tensor import as_tensor_variable, TensorType from theano.tensor import as_tensor_variable, TensorType
...@@ -770,7 +770,7 @@ class Scan(PureOp): ...@@ -770,7 +770,7 @@ class Scan(PureOp):
if self_in.type != other_in.type: if self_in.type != other_in.type:
return False return False
return scan_utils.equal_computations( return equal_computations(
self.outputs, other.outputs, self.inputs, other.inputs self.outputs, other.outputs, self.inputs, other.inputs
) )
......
...@@ -69,9 +69,10 @@ from theano.compile import optdb ...@@ -69,9 +69,10 @@ from theano.compile import optdb
from theano.compile.function_module import deep_copy_op from theano.compile.function_module import deep_copy_op
from theano.gof import toolbox, DestroyHandler, InconsistencyError from theano.gof import toolbox, DestroyHandler, InconsistencyError
from theano.gof.opt import Optimizer, pre_constant_merge, pre_greedy_local_optimizer from theano.gof.opt import Optimizer, pre_constant_merge, pre_greedy_local_optimizer
from theano.gof.graph import equal_computations
from theano.scan_module.scan_utils import equal_computations, scan_args
from theano.scan_module import scan_op, scan_utils from theano.scan_module import scan_op, scan_utils
from theano.scan_module.scan_utils import scan_args
__docformat__ = "restructedtext en" __docformat__ = "restructedtext en"
__authors__ = ( __authors__ = (
...@@ -169,7 +170,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -169,7 +170,7 @@ def remove_constants_and_unused_inputs_scan(node):
elif op_ins[idx] in all_ins: elif op_ins[idx] in all_ins:
# Check for identical other sequence # Check for identical other sequence
identical_seqs = [ identical_seqs = [
x for x in nw_outer if scan_utils.equal_computations([x], [node_inp]) x for x in nw_outer if equal_computations([x], [node_inp])
] ]
if identical_seqs: if identical_seqs:
index = node.inputs.index(identical_seqs[0]) - 1 index = node.inputs.index(identical_seqs[0]) - 1
...@@ -195,7 +196,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -195,7 +196,7 @@ def remove_constants_and_unused_inputs_scan(node):
identical_nonseq_idx = [ identical_nonseq_idx = [
i i
for (i, x) in enumerate(nw_outer_nonseq) for (i, x) in enumerate(nw_outer_nonseq)
if scan_utils.equal_computations([x], [nw_out]) if equal_computations([x], [nw_out])
] ]
if identical_nonseq_idx: if identical_nonseq_idx:
givens[nw_in] = nw_inner_nonseq[identical_nonseq_idx[0]] givens[nw_in] = nw_inner_nonseq[identical_nonseq_idx[0]]
...@@ -1904,9 +1905,7 @@ class ScanMerge(gof.Optimizer): ...@@ -1904,9 +1905,7 @@ class ScanMerge(gof.Optimizer):
return True return True
cond = node.op.outputs[-1] cond = node.op.outputs[-1]
rep_cond = rep.op.outputs[-1] rep_cond = rep.op.outputs[-1]
return scan_utils.equal_computations( return equal_computations([cond], [rep_cond], node.op.inputs, rep.op.inputs)
[cond], [rep_cond], node.op.inputs, rep.op.inputs
)
def apply(self, fgraph): def apply(self, fgraph):
# Collect all scan nodes ordered according to toposort # Collect all scan nodes ordered according to toposort
......
...@@ -635,143 +635,6 @@ def expand_empty(tensor_var, size): ...@@ -635,143 +635,6 @@ def expand_empty(tensor_var, size):
return ret return ret
def equal_computations(xs, ys, in_xs=None, in_ys=None):
    """Checks if Theano graphs represent the same computations.

    The two lists `xs`, `ys` should have the same number of entries. The
    function checks if for any corresponding pair `(x,y)` from `zip(xs,ys)`
    `x` and `y` represent the same computations on the same variables
    (unless equivalences are provided using `in_xs`, `in_ys`).

    If `in_xs` and `in_ys` are provided, then when comparing a node `x` with
    a node `y` they are automatically considered as equal if there is some
    index `i` such that `x == in_xs[i]` and `y == in_ys[i]`(and they both
    have the same type). Note that `x` and `y` can be in the list `xs` and
    `ys`, but also represent subgraphs of a computational graph in `xs`
    or `ys`.

    Returns True iff every pair of graphs compares equal, False otherwise.
    """
    assert len(xs) == len(ys)
    if in_xs is None:
        in_xs = []
    if in_ys is None:
        in_ys = []
    # Cheap structural pre-checks on the output pairs: owner presence,
    # output index within the owner, and variable type must all match.
    for x, y in zip(xs, ys):
        if x.owner and not y.owner:
            return False
        if y.owner and not x.owner:
            return False
        if x.owner:  # Check above tell that y.owner eval to True too.
            if x.owner.outputs.index(x) != y.owner.outputs.index(y):
                return False
        if x not in in_xs and x.type != y.type:
            return False
    # The equivalence lists must pair up one-to-one with matching types.
    if len(in_xs) != len(in_ys):
        return False
    for _x, _y in zip(in_xs, in_ys):
        if _x.type != _y.type:
            return False

    # `common` caches variable pairs already known to be equivalent;
    # `different` caches pairs already known to differ. Both are shared
    # (and mutated) by the recursive `compare_nodes` below.
    common = set(zip(in_xs, in_ys))
    different = set()
    for dx, dy in zip(xs, ys):
        # We checked above that both dx and dy have an owner or not
        if not dx.owner:
            # Owner-less variables are graph inputs: constants must compare
            # by value, other inputs must be the same variable (or paired
            # via `in_xs`/`in_ys`).
            if isinstance(dx, tensor.Constant) and isinstance(dy, tensor.Constant):
                if not dx.equals(dy):
                    return False
                else:
                    pass
            elif (dx, dy) not in common and dx != dy:
                return False

    # Explore the two graphs, in parallel, depth first, comparing the nodes
    # along the way for equality.
    def compare_nodes(nd_x, nd_y, common, different):
        """
        Compare two nodes to determine if they perform equal computation.
        This is done by comparing the ops, the number of inputs, outputs and
        by ensuring that the inputs themselves are the result of equal
        computation.

        NOTE : This function relies on the variable common to cache
        results to be more efficient.
        """
        if nd_x.op != nd_y.op:
            return False
        elif len(nd_x.inputs) != len(nd_y.inputs):
            return False
        elif len(nd_x.outputs) != len(nd_y.outputs):
            return False
        else:
            # Fast path: if every output pair is already cached as common,
            # the nodes are known equal; any pair cached as different fails.
            all_in_common = True
            for dx, dy in zip(nd_x.outputs, nd_y.outputs):
                if (dx, dy) in different:
                    return False
                if (dx, dy) not in common:
                    all_in_common = False
            if all_in_common:
                return True

            # Compare the individual inputs for equality
            for dx, dy in zip(nd_x.inputs, nd_y.inputs):
                if (dx, dy) not in common:
                    # Equality between the variables is unknown, compare
                    # their respective owners, if they have some
                    if (
                        dx.owner
                        and dy.owner
                        and dx.owner.outputs.index(dx) == dy.owner.outputs.index(dy)
                    ):
                        nodes_equal = compare_nodes(
                            dx.owner, dy.owner, common, different
                        )
                        if not nodes_equal:
                            # Negative result is cached so other paths that
                            # reach this pair can fail immediately.
                            different.add((dx, dy))
                            return False
                    # If both variables don't have an owner, then they are
                    # inputs and can be directly compared
                    elif dx.owner is None and dy.owner is None:
                        if dx != dy:
                            if isinstance(dx, tensor.Constant) and isinstance(
                                dy, tensor.Constant
                            ):
                                if not dx.equals(dy):
                                    return False
                            else:
                                return False
                    else:
                        # One has an owner and the other does not (or the
                        # output indices differ): graphs cannot match.
                        return False

            # If the code reaches this statement then the inputs are pair-wise
            # equivalent so the outputs of the current nodes are also
            # pair-wise equivalents
            for dx, dy in zip(nd_x.outputs, nd_y.outputs):
                common.add((dx, dy))

            return True

    # Validate that each xs[i], ys[i] pair represents the same computation
    for i in range(len(xs)):
        if xs[i].owner:
            # The case where pairs of x[i]s and y[i]s don't both have an owner
            # have already been addressed.
            is_equal = compare_nodes(xs[i].owner, ys[i].owner, common, different)
            if not is_equal:
                return False

    return True
def infer_shape(outs, inputs, input_shapes): def infer_shape(outs, inputs, input_shapes):
""" """
Compute the shape of the outputs given the shape of the inputs of a theano Compute the shape of the outputs given the shape of the inputs of a theano
...@@ -1413,7 +1276,7 @@ def forced_replace(out, x, y): ...@@ -1413,7 +1276,7 @@ def forced_replace(out, x, y):
if graph in visited: if graph in visited:
continue continue
visited.add(graph) visited.add(graph)
if equal_computations([graph], [x]): if gof.graph.equal_computations([graph], [x]):
to_replace.append((graph, y)) to_replace.append((graph, y))
elif graph.owner: elif graph.owner:
q.extendleft(graph.owner.inputs) q.extendleft(graph.owner.inputs)
......
...@@ -1633,7 +1633,7 @@ class ShapeFeature(object): ...@@ -1633,7 +1633,7 @@ class ShapeFeature(object):
# To be sure to cover all case, call equal_computation. # To be sure to cover all case, call equal_computation.
# Can't use theano.gof.graph.is_same_graph(dx, dy) # Can't use theano.gof.graph.is_same_graph(dx, dy)
# As it currently expect that dx and dy aren't in a FunctionGraph # As it currently expect that dx and dy aren't in a FunctionGraph
from theano.scan_module.scan_utils import equal_computations from theano.gof.graph import equal_computations
if not equal_computations([dx], [dy]): if not equal_computations([dx], [dy]):
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论