Unverified 提交 a9275c3d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #176 from brandonwillard/update-FunctionGraph-interface

Update `FunctionGraph` interface and add tests
import pickle
from theano import tensor as tt
from theano.gof.fg import FunctionGraph
import numpy as np
import pytest
from tests.gof.utils import MyVariable, MyVariable2, op1, op2, op3
from theano import change_flags
from theano.gof.fg import FunctionGraph, MissingInputError
from theano.gof.toolbox import BadOptimization
class TestFunctionGraph:
def test_pickle(self):
v = tt.vector()
func = FunctionGraph([v], [v + 1])
var1 = op1()
var2 = op2()
var3 = op1(var1)
var4 = op2(var3, var2)
func = FunctionGraph([var1, var2], [var4])
s = pickle.dumps(func)
pickle.loads(s)
new_func = pickle.loads(s)
assert all(type(a) == type(b) for a, b in zip(func.inputs, new_func.inputs))
assert all(type(a) == type(b) for a, b in zip(func.outputs, new_func.outputs))
assert all(
type(a.op) is type(b.op) # noqa: E721
for a, b in zip(func.apply_nodes, new_func.apply_nodes)
)
assert all(a.type == b.type for a, b in zip(func.variables, new_func.variables))
def test_validate_inputs(self):
var1 = op1()
var2 = op2()
with pytest.raises(TypeError):
FunctionGraph(var1, [var2])
with pytest.raises(TypeError):
FunctionGraph([var1], var2)
with pytest.raises(ValueError):
var3 = op1(var1)
FunctionGraph([var3], [var2], clone=False)
def test_init(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var1)
var4 = op2(var3, var2)
fg = FunctionGraph([var1, var2], [var3, var4], clone=False)
assert fg.inputs == [var1, var2]
assert fg.outputs == [var3, var4]
assert fg.apply_nodes == {var3.owner, var4.owner}
assert fg.update_mapping is None
assert fg.check_integrity() is None
assert fg.variables == {var1, var2, var3, var4}
assert fg.clients(var1) == [(var3.owner, 0)]
assert fg.clients(var2) == [(var4.owner, 1)]
assert fg.clients(var3) == [(var4.owner, 0), ("output", 0)]
assert fg.clients(var4) == [("output", 1)]
def test_remove_client(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
assert fg.variables == {var1, var2, var3, var4, var5}
assert fg.clients(var2) == [
(var3.owner, 0),
(var4.owner, 1),
(var5.owner, 1),
(var5.owner, 2),
]
fg.remove_client(var2, (var4.owner, 1))
assert fg.clients(var2) == [
(var3.owner, 0),
(var5.owner, 1),
(var5.owner, 2),
]
fg.remove_client(var1, (var3.owner, 1))
assert fg.clients(var1) == []
assert var4.owner in fg.apply_nodes
# This next `remove_client` should trigger a complete removal of `var4`'s
# variables and `Apply` node from the `FunctionGraph`.
#
# Also, notice that we already removed `var4` from `var2`'s client list
# above, so, when we completely remove `var4`, `fg.remove_client` will
# attempt to remove `(var4.owner, 1)` from `var2`'s client list again.
# This attempt would previously raise a `ValueError` exception, because
# the entry was not in the list.
fg.remove_client(var4, (var5.owner, 0), reason="testing")
assert var4.owner not in fg.apply_nodes
assert var4.owner.tag.removed_by == ["testing"]
assert not any(o in fg.variables for o in var4.owner.outputs)
def test_import_node(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var5 = MyVariable("var5")
var6 = op2(var5)
with pytest.raises(MissingInputError):
fg.import_node(var6.owner)
var6 = op2(var2)
assert not hasattr(var6.owner.tag, "imported_by")
fg.import_node(var6.owner)
assert hasattr(var6.owner.tag, "imported_by")
assert var6 in fg.variables
assert var6.owner in fg.apply_nodes
assert (var6.owner, 0) in var2.clients
def test_import_var(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
with pytest.raises(MissingInputError):
var0 = MyVariable("var0")
# We can't import a new `FunctionGraph` input (i.e. something
# without an owner)
fg.import_var(var0, "testing")
var5 = op2()
# We can import variables with owners
fg.import_var(var5, "testing")
assert var5 in fg.variables
assert var5.owner in fg.apply_nodes
with pytest.raises(TypeError, match="Computation graph contains.*"):
from theano.gof.null_type import NullType
fg.import_var(NullType()(), "testing")
def test_change_input(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var6 = MyVariable2("var6")
with pytest.raises(TypeError):
fg.change_input("output", 1, var6)
with pytest.raises(TypeError):
fg.change_input(var5.owner, 1, var6)
old_apply_nodes = set(fg.apply_nodes)
old_variables = set(fg.variables)
old_var5_clients = list(var5.clients)
# We're replacing with the same variable, so nothing should happen
fg.change_input(var5.owner, 1, var2)
assert old_apply_nodes == fg.apply_nodes
assert old_variables == fg.variables
assert old_var5_clients == var5.clients
# Perform a valid `Apply` node input change
fg.change_input(var5.owner, 1, var1)
assert var5.owner.inputs[1] is var1
assert (var5.owner, 1) not in var2.clients
@change_flags(compute_test_value="raise")
def test_replace_test_value(self):
var1 = MyVariable("var1")
var1.tag.test_value = 1
var2 = MyVariable("var2")
var2.tag.test_value = 2
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var4.tag.test_value = np.array([1, 2])
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
var6 = op3()
var6.tag.test_value = np.array(0)
assert var6.tag.test_value.shape != var4.tag.test_value.shape
with pytest.raises(AssertionError, match="The replacement.*"):
fg.replace(var4, var6)
def test_replace(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
with pytest.raises(Exception, match="Cannot replace.*"):
var4.fgraph = object()
# Trigger a `FunctionGraph` ownership error
fg.replace(var4, var1, verbose=True)
var4.fgraph = fg
with pytest.raises(BadOptimization):
var0 = MyVariable2("var0")
# The types don't match and one cannot be converted to the other
fg.replace(var3, var0)
# Test a basic replacement
fg.replace_all([(var3, var1)])
assert var3 not in fg.variables
assert fg.apply_nodes == {var4.owner, var5.owner}
assert var4.owner.inputs == [var1, var2]
def test_replace_circular(self):
"""`FunctionGraph` allows cycles--for better or worse."""
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
fg.replace_all([(var3, var4)])
# The following works (and is kind of gross), because `var4` has been
# mutated in-place
assert fg.apply_nodes == {var4.owner, var5.owner}
assert var4.owner.inputs == [var4, var2]
def test_replace_bad_state(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
with pytest.raises(MissingInputError):
var0 = MyVariable("var0")
var0.fgraph = object()
# FIXME TODO XXX: This breaks the state of the `FunctionGraph`,
# because it doesn't check for validity of the replacement *first*.
fg.replace(var1, var0, verbose=True)
def test_check_integrity(self):
var1 = MyVariable("var1")
var2 = MyVariable("var2")
var3 = op1(var2, var1)
var4 = op2(var3, var2)
var5 = op3(var4, var2, var2)
fg = FunctionGraph([var1, var2], [var3, var5], clone=False)
with pytest.raises(Exception, match="The nodes are .*"):
fg.apply_nodes.remove(var5.owner)
fg.check_integrity()
with pytest.raises(Exception, match="Inconsistent clients.*"):
fg.apply_nodes.add(var5.owner)
var2.clients.remove((var5.owner, 1))
fg.check_integrity()
var2.clients.append((var5.owner, 1))
with pytest.raises(Exception, match="The variables are.*"):
fg.variables.remove(var4)
fg.check_integrity()
fg.variables.add(var4)
with pytest.raises(Exception, match="Undeclared input.*"):
var6 = MyVariable2("var6")
var6.fgraph = fg
var6.clients = [(var5.owner, 3)]
fg.variables.add(var6)
var5.owner.inputs.append(var6)
fg.check_integrity()
fg.variables.remove(var6)
var5.owner.inputs.remove(var6)
# TODO: What if the index value is greater than 1? It will throw an
# `IndexError`, but that doesn't sound like anything we'd want.
with pytest.raises(Exception, match="Inconsistent clients list.*"):
var4.clients.append(("output", 1))
fg.check_integrity()
var4.clients.remove(("output", 1))
with pytest.raises(Exception, match="Client not in FunctionGraph.*"):
var4.clients.append((var6.owner, 0))
fg.check_integrity()
var4.clients.remove((var6.owner, 0))
with pytest.raises(Exception, match="Inconsistent clients list.*"):
var4.clients.append((var3.owner, 0))
fg.check_integrity()
import theano.tensor as tt
from tests.gof.utils import MyType, MyVariable, op1, op2, op3, op4, op5, op6, op_y, op_z
from theano.gof.fg import FunctionGraph
from theano.gof.graph import Apply, Constant, Variable
from theano.gof.graph import Apply, Constant
from theano.gof.op import Op
from theano.gof.opt import (
EquilibriumOptimizer,
......@@ -15,82 +16,11 @@ from theano.gof.opt import (
pre_greedy_local_optimizer,
theano,
)
from theano.gof.type import Type
from theano.tensor.opt import constant_folding
from theano.tensor.subtensor import AdvancedSubtensor
from theano.tensor.type_other import MakeSlice, SliceConstant, slicetype
def is_variable(x):
if not isinstance(x, Variable):
raise TypeError("not a Variable", x)
return x
class MyType(Type):
def filter(self, data):
return data
def __eq__(self, other):
return isinstance(other, MyType)
def __hash__(self):
return hash(MyType)
def MyVariable(name):
return Variable(MyType(), None, None, name=name)
class MyOp(Op):
def __init__(self, name, dmap=None, x=None):
self.name = name
if dmap is None:
dmap = {}
self.destroy_map = dmap
self.x = x
def make_node(self, *inputs):
inputs = list(map(is_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType()()]
return Apply(self, inputs, outputs)
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __eq__(self, other):
# rval = (self is other) or (isinstance(other, MyOp) and self.x is not None and self.x == other.x and self.name == other.name)
rval = (self is other) or (
isinstance(other, MyOp) and self.x is not None and self.x == other.x
)
return rval
def __hash__(self):
# return hash(self.x if self.x is not None else id(self)) ^ hash(self.name)
if self.x is not None:
return hash(self.x)
else:
return id(self)
op1 = MyOp("Op1")
op2 = MyOp("Op2")
op3 = MyOp("Op3")
op4 = MyOp("Op4")
op5 = MyOp("Op5")
op6 = MyOp("Op6")
op_d = MyOp("OpD", {0: [0]})
op_y = MyOp("OpY", x=1)
op_z = MyOp("OpZ", x=1)
def inputs():
x = MyVariable("x")
y = MyVariable("y")
......
import numpy as np
from theano.gof.graph import Apply, Variable
from theano.gof.op import Op
from theano.gof.type import Type
def is_variable(x):
if not isinstance(x, Variable):
raise TypeError("not a Variable", x)
return x
class MyType(Type):
def filter(self, data):
return data
def __eq__(self, other):
return isinstance(other, MyType)
def __hash__(self):
return hash(MyType)
class MyType2(Type):
def filter(self, data):
return data
def __eq__(self, other):
return isinstance(other, MyType)
def __hash__(self):
return hash(MyType)
def MyVariable(name):
return Variable(MyType(), None, None, name=name)
def MyVariable2(name):
return Variable(MyType2(), None, None, name=name)
class MyOp(Op):
def __init__(self, name, dmap=None, x=None):
self.name = name
if dmap is None:
dmap = {}
self.destroy_map = dmap
self.x = x
def make_node(self, *inputs):
inputs = list(map(is_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType()()]
return Apply(self, inputs, outputs)
def perform(self, node, inputs, outputs):
outputs[0] = np.array(inputs)
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __eq__(self, other):
# rval = (self is other) or (isinstance(other, MyOp) and self.x is not None and self.x == other.x and self.name == other.name)
rval = (self is other) or (
isinstance(other, MyOp) and self.x is not None and self.x == other.x
)
return rval
def __hash__(self):
# return hash(self.x if self.x is not None else id(self)) ^ hash(self.name)
if self.x is not None:
return hash(self.x)
else:
return id(self)
op1 = MyOp("Op1")
op2 = MyOp("Op2")
op3 = MyOp("Op3")
op4 = MyOp("Op4")
op5 = MyOp("Op5")
op6 = MyOp("Op6")
op_d = MyOp("OpD", {0: [0]})
op_y = MyOp("OpY", x=1)
op_z = MyOp("OpZ", x=1)
......@@ -15,9 +15,6 @@ from theano.gof.utils import TestValueError, get_variable_trace_string
from theano.misc.ordered_set import OrderedSet
NullType = None
class InconsistencyError(Exception):
"""
This exception should be thrown by listeners to FunctionGraph when the
......@@ -105,14 +102,16 @@ class FunctionGraph(utils.object2):
Parameters
----------
inputs : list of variables
inputs : list of theano.gof.graph.Variable
Inputs nodes of the graph, usually declared by the user
outputs : list of variables
outputs : list of theano.gof.graph.Variable
Outputs nodes of the graph.
clone : boolean
If true, we will clone the graph. This is useful to remove the
constant cache problem.
update_mapping : dictionary
features : list of theano.gof.toolbox.Feature
A list of features to be added to the `FunctionGraph`.
update_mapping : dict
Mapping between the inputs with updates and the outputs
corresponding to their updates.
"""
......@@ -120,6 +119,12 @@ class FunctionGraph(utils.object2):
if clone:
inputs, outputs = graph.clone(inputs, outputs)
if not isinstance(inputs, list):
raise TypeError("Argument `inputs` should be a list")
if not isinstance(outputs, list):
raise TypeError("Argument `outputs` should be a list")
self.execute_callbacks_time = 0
self.execute_callbacks_times = {}
......@@ -139,47 +144,71 @@ class FunctionGraph(utils.object2):
# outputs even if they aren't used in the graph.
self.variables = set()
self.inputs = list(inputs)
# TODO FIXME: We should *not* be using a list created elsewhere!
self.outputs = outputs
for f in features:
self.attach_feature(f)
self.attach_feature(toolbox.ReplaceValidate())
for input in self.inputs:
if input.owner is not None:
self.inputs = []
for in_var in inputs:
if in_var.owner is not None:
raise ValueError(
"One of the provided inputs is the output of"
"One of the provided inputs is the output of "
"an already existing node. "
"If that is okay, either discard that "
"input's owner or use graph.clone."
)
self.__setup_r__(input)
self.variables.add(input)
self.add_input(in_var, check=False)
for output in outputs:
self.__import_r__(output, reason="init")
self.import_var(output, reason="init")
for i, output in enumerate(outputs):
output.clients.append(("output", i))
self.profile = None
self.update_mapping = update_mapping
def add_input(self, input):
if input not in self.inputs:
self.inputs.append(input)
self.__setup_r__(input)
self.variables.add(input)
def __setup_r__(self, r):
if hasattr(r, "fgraph") and r.fgraph is not None and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r)
r.fgraph = self
r.clients = []
# self.execute_callbacks('on_setup_variable', r)
def __setup_node__(self, node):
# sets up node so it belongs to this fgraph
def add_input(self, var, check=True):
"""Add a new variable as an input to this `FunctionGraph`.
Parameters
----------
var : theano.gof.graph.Variable
"""
if check and var in self.inputs:
return
self.inputs.append(var)
self.setup_var(var)
self.variables.add(var)
def setup_var(self, var):
"""Set up a variable so it belongs to this `FunctionGraph`.
Parameters
----------
var : theano.gof.graph.Variable
"""
if hasattr(var, "fgraph") and var.fgraph is not None and var.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % var)
var.fgraph = self
var.clients = []
# self.execute_callbacks('on_setup_variable', var)
def setup_node(self, node):
"""Set up node so it belongs to this `FunctionGraph`.
Parameters
----------
node : theano.gof.graph.Apply
"""
if hasattr(node, "fgraph") and node.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % node)
if hasattr(node.op, "view_map") and not all(
......@@ -226,125 +255,141 @@ class FunctionGraph(utils.object2):
self.profile = None
self.update_mapping = None
# clients #
def clients(self, r):
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
Told differently, a list of (node,i) such that each node have
r as input at index i.
def clients(self, var):
"""Return a list of all the `(node, i)` pairs such that `node.inputs[i]` is `var`.
"""
return r.clients
Told differently, a `list` of `(node, i)` such that each node have
`var` as input at index `i`.
def __add_client__(self, r, new_client):
"""
Updates the list of clients of r with new_clients.
return var.clients
def add_client(self, var, new_client):
"""Update the clients of `var` with `new_clients`.
Parameters
----------
r
Variable.
new_client
(node, i) pair such that node.inputs[i] is r.
var : Variable.
new_client : (Apply, int)
A `(node, i)` pair such that `node.inputs[i]` is `var`.
"""
# Ne need to do the assert as it is always True. The logic
# that call __add_client__ is valid. When the client list is
# long, the check it time consuming, so we don't enable it by
# default.
# assert not new_client in r.clients
r.clients.append(new_client)
def __remove_client__(self, r, client_to_remove, reason=None):
"""
Removes all from the clients list of r.
var.clients.append(new_client)
def remove_client(self, var, client_to_remove, reason=None):
"""Recursively removes clients of a variable.
This is the main method to remove variable or apply node from
an FunctionGraph.
This is the main method to remove variables or `Apply` nodes from
a `FunctionGraph`.
Remove r from this fgraph if it don't have clients left. If it
have an owner and all the outputs of the owner have no
clients, it will be removed.
This will remove `var` from the `FunctionGraph` if it doesn't have any
clients remaining. If it has an owner and all the outputs of the owner
have no clients, it will also be removed.
Parameters
----------
r : Variable
The clients of r will be removed.
client_to_remove : (op, i) pair
(op, i) pair such that node.inputs[i] is not r anymore.
var : Variable
The clients of `var` that will be removed.
client_to_remove : pair of (Apply, int)
A `(node, i)` pair such that `node.inputs[i]` will no longer be
`var` in this `FunctionGraph`.
"""
l = [(r, client_to_remove)]
while l:
r, client_to_remove = l.pop()
r.clients.remove(client_to_remove)
# entry should be uniq in r. No need to assert it as it is
# already asserted in __add_client__.
# assert entry not in r.clients
if r.clients:
removal_stack = [(var, client_to_remove)]
while removal_stack:
var, client_to_remove = removal_stack.pop()
try:
var.clients.remove(client_to_remove)
except ValueError:
# In this case, the original `var` could've been removed from
# the current `var`'s client list before this call.
# There's nothing inherently wrong with that, so we continue as
# if it were removed here.
pass
if var.clients:
continue
# r have no more clients, so check if we need to remove it
# and its parent.
variable = r
if not variable.owner:
# A Constant or input without client. Remove it.
self.variables.remove(variable)
# This allow to quickly know if a var is still in the fgraph
# or not.
del variable.fgraph
# Now, `var` has no more clients, so check if we need to remove it
# and its `Apply` node
if not var.owner:
# The `var` is a `Constant` or an input without a client, so we
# remove it
self.variables.remove(var)
# This allows us to quickly determine if `var` is still in the
# `FunctionGraph`
# TODO: It's a poor approach; remove it
del var.fgraph
else:
apply_node = variable.owner
used = [output for output in apply_node.outputs if output.clients]
# If the apply node is not used and is not an output
if not used:
apply_node = var.owner
if not any(output.clients for output in apply_node.outputs):
# The `Apply` node is not used and is not an output, so we
# remove it and its outputs
if not hasattr(apply_node.tag, "removed_by"):
apply_node.tag.removed_by = []
apply_node.tag.removed_by.append(str(reason))
self.apply_nodes.remove(apply_node)
# del apply_node.fgraph
self.variables.difference_update(apply_node.outputs)
#
# for var in apply_node.outputs:
# del var.fgraph
self.variables.difference_update(apply_node.outputs)
self.execute_callbacks("on_prune", apply_node, reason)
for i, input in enumerate(apply_node.inputs):
l.append((input, (apply_node, i)))
for i, in_var in enumerate(apply_node.inputs):
removal_stack.append((in_var, (apply_node, i)))
def __import_r__(self, variable, reason):
"""
Import variables to this FunctionGraph and also their apply_node,
if those nodes are not in this graph.
def import_var(self, var, reason):
"""Import variables into this `FunctionGraph`.
This will also import the `variable`'s `Apply` node.
Parameters:
----------
reason
reason is the name of the optimization or operation in progress.
variable : theano.gof.graph.Variable
The variable to be imported.
reason : str
The name of the optimization or operation in progress.
"""
# Imports the owners of the variables
if variable.owner and variable.owner not in self.apply_nodes:
self.__import__(variable.owner, reason=reason)
if var.owner and var.owner not in self.apply_nodes:
self.import_node(var.owner, reason=reason)
elif (
variable.owner is None
and not isinstance(variable, graph.Constant)
and variable not in self.inputs
var.owner is None
and not isinstance(var, graph.Constant)
and var not in self.inputs
):
global NullType
if NullType is None:
from .null_type import NullType
if isinstance(variable.type, NullType):
from theano.gof.null_type import NullType
if isinstance(var.type, NullType):
raise TypeError(
"Computation graph contains a NaN. " + variable.type.why_null
"Computation graph contains a NaN. " + var.type.why_null
)
raise MissingInputError("Undeclared input", variable=variable)
if not getattr(variable, "fgraph", None) is self:
self.__setup_r__(variable)
self.variables.add(variable)
raise MissingInputError("Undeclared input", variable=var)
if not getattr(var, "fgraph", None) is self:
self.setup_var(var)
self.variables.add(var)
def __import__(self, apply_node, check=True, reason=None):
"""
Given an apply_node, recursively search from this node to know graph,
and then add all unknown variables and apply_nodes to this graph.
def import_node(self, apply_node, check=True, reason=None):
"""Recursively import everything between an `Apply` node and the `FunctionGraph`'s outputs.
Parameters:
----------
apply_node : theano.gof.graph.Apply
The node to be imported.
check : bool
Check that the inputs for the imported nodes are also present in
the `FunctionGraph`.
reason : str
The name of the optimization or operation in progress.
"""
node = apply_node
......@@ -358,13 +403,13 @@ class FunctionGraph(utils.object2):
for node in new_nodes:
if hasattr(node, "fgraph") and node.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % node)
for r in node.inputs:
if hasattr(r, "fgraph") and r.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % r)
for var in node.inputs:
if hasattr(var, "fgraph") and var.fgraph is not self:
raise Exception("%s is already owned by another fgraph" % var)
if (
r.owner is None
and not isinstance(r, graph.Constant)
and r not in self.inputs
var.owner is None
and not isinstance(var, graph.Constant)
and var not in self.inputs
):
# Standard error message
error_msg = (
......@@ -373,51 +418,60 @@ class FunctionGraph(utils.object2):
"provided and not given a value. Use the "
"Theano flag exception_verbosity='high', "
"for more information on this error."
% (node.inputs.index(r), str(node))
% (node.inputs.index(var), str(node))
)
raise MissingInputError(error_msg, variable=r)
raise MissingInputError(error_msg, variable=var)
for node in new_nodes:
assert node not in self.apply_nodes
self.__setup_node__(node)
self.setup_node(node)
self.apply_nodes.add(node)
if not hasattr(node.tag, "imported_by"):
node.tag.imported_by = []
node.tag.imported_by.append(str(reason))
for output in node.outputs:
self.__setup_r__(output)
self.setup_var(output)
self.variables.add(output)
for i, input in enumerate(node.inputs):
if input not in self.variables:
self.__setup_r__(input)
self.setup_var(input)
self.variables.add(input)
self.__add_client__(input, (node, i))
self.add_client(input, (node, i))
assert node.fgraph is self
self.execute_callbacks("on_import", node, reason)
# change input #
def change_input(self, node, i, new_r, reason=None):
"""
Changes node.inputs[i] to new_r.
def change_input(self, node, i, new_var, reason=None):
"""Change ``node.inputs[i]`` to `new_var`.
``new_var.type == old_var.type`` must be ``True``, where ``old_var`` is the
current value of ``node.inputs[i]`` which we want to replace.
new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace.
For each feature that has an `on_change_input` method, this method calls:
``feature.on_change_input(function_graph, node, i, old_var, new_var, reason)``
For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(function_graph, node, i, old_r, new_r, reason)
Parameters
----------
node : theano.gof.graph.Apply or str
The node for which an input is to be changed. If the value is
the string ``"output"`` then the ``self.outputs`` will be used
instead of ``node.inputs``.
i : int
The index in `node.inputs` that we want to change.
new_var : theano.gof.graph.Variable
The new variable to take the place of ``node.inputs[i]``.
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == "output":
r = self.outputs[i]
if not r.type == new_r.type:
if not r.type == new_var.type:
raise TypeError(
"The type of the replacement must be the"
" same as the type of the original Variable.",
r,
new_r,
new_var,
)
self.outputs[i] = new_r
self.outputs[i] = new_var
else:
if node.fgraph is not self:
raise Exception(
......@@ -425,51 +479,63 @@ class FunctionGraph(utils.object2):
" belong to this FunctionGraph" % node
)
r = node.inputs[i]
if not r.type == new_r.type:
if not r.type == new_var.type:
raise TypeError(
"The type of the replacement must be the"
" same as the type of the original Variable.",
r,
new_r,
new_var,
)
node.inputs[i] = new_r
node.inputs[i] = new_var
if r is new_r:
if r is new_var:
return
self.__import_r__(new_r, reason=reason)
self.__add_client__(new_r, (node, i))
self.__remove_client__(r, (node, i), reason=reason)
self.import_var(new_var, reason=reason)
self.add_client(new_var, (node, i))
self.remove_client(r, (node, i), reason=reason)
# Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the
# transaction will be reverted later.
self.execute_callbacks("on_change_input", node, i, r, new_r, reason=reason)
self.execute_callbacks("on_change_input", node, i, r, new_var, reason=reason)
# replace #
def replace(self, r, new_r, reason=None, verbose=None):
"""
This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead.
def replace(self, var, new_var, reason=None, verbose=None):
"""Replace a variable in the `FunctionGraph`.
This is the main interface to manipulate the subgraph in `FunctionGraph`.
For every node that uses `var` as input, makes it use `new_var` instead.
Parameters:
----------
var : theano.gof.graph.Variable
The variable to be replaced.
new_var : theano.gof.graph.Variable
The variable to replace `var`.
reason : str
The name of the optimization or operation in progress.
verbose : bool
Print `reason`, `var`, and `new_var`.
"""
if verbose is None:
verbose = config.optimizer_verbose
if verbose:
print(reason, r, new_r)
if hasattr(r, "fgraph") and r.fgraph is not self:
print(reason, var, new_var)
if hasattr(var, "fgraph") and var.fgraph is not self:
raise Exception(
"Cannot replace %s because it does not belong "
"to this FunctionGraph" % r,
"to this FunctionGraph" % var,
str(reason),
)
if r.type != new_r.type:
new_r2 = r.type.convert_variable(new_r)
if var.type != new_var.type:
new_var_2 = var.type.convert_variable(new_var)
# We still make sure that the type converts correctly
if new_r2 is None or new_r2.type != r.type:
if new_var_2 is None or new_var_2.type != var.type:
done = dict()
used_ids = dict()
old = theano.compile.debugmode.debugprint(
r,
var,
prefix=" ",
depth=6,
file=StringIO(),
......@@ -478,7 +544,7 @@ class FunctionGraph(utils.object2):
used_ids=used_ids,
).getvalue()
new = theano.compile.debugmode.debugprint(
new_r,
new_var,
prefix=" ",
depth=6,
file=StringIO(),
......@@ -487,16 +553,17 @@ class FunctionGraph(utils.object2):
used_ids=used_ids,
).getvalue()
raise toolbox.BadOptimization(
r,
new_r,
var,
new_var,
None,
None,
str(reason) + ". The type of the replacement must be the same.",
old,
new,
)
new_r = new_r2
if r not in self.variables:
new_var = new_var_2
if var not in self.variables:
# this variable isn't in the graph... don't raise an
# exception here, just return silently because it makes it
# easier to implement some optimizations for
......@@ -505,8 +572,8 @@ class FunctionGraph(utils.object2):
if theano.config.compute_test_value != "off":
try:
tval = theano.gof.op.get_test_value(r)
new_tval = theano.gof.op.get_test_value(new_r)
tval = theano.gof.op.get_test_value(var)
new_tval = theano.gof.op.get_test_value(new_var)
except TestValueError:
pass
else:
......@@ -518,27 +585,21 @@ class FunctionGraph(utils.object2):
"a shape different from the original variable's "
"test value. Original: %s, new: %s"
% (tval_shape, new_tval_shape),
r,
new_r,
var,
new_var,
str(reason),
)
for node, i in list(r.clients): # copy the client list for iteration
assert (node == "output" and self.outputs[i] is r) or (node.inputs[i] is r)
self.change_input(node, i, new_r, reason=reason)
# sometimes the following is triggered. If you understand why, please explain to James.
# He's curious... -JB20090331
# if len(r.clients) != 0:
# print >> sys.stderr, "WARNING: CLIENTS LEFT AFTER REPLACE", r, r.clients
for node, i in list(var.clients): # copy the client list for iteration
assert (node == "output" and self.outputs[i] is var) or (
node.inputs[i] is var
)
self.change_input(node, i, new_var, reason=reason)
def replace_all(self, pairs, reason=None):
"""
For every node that uses r as input, makes it use new_r instead
"""
for r, new_r in pairs:
self.replace(r, new_r, reason=reason)
"""Replace variables in the `FunctionGraph` according to `(var, new_var)` pairs in a list."""
for var, new_var in pairs:
self.replace(var, new_var, reason=reason)
def attach_feature(self, feature):
"""
......@@ -587,7 +648,6 @@ class FunctionGraph(utils.object2):
if detach is not None:
detach(self)
# callback utils #
def execute_callbacks(self, name, *args, **kwargs):
"""Execute callbacks
......@@ -625,7 +685,6 @@ class FunctionGraph(utils.object2):
d[feature] = fn(*args)
return d
# misc #
def toposort(self):
"""Toposort
......@@ -655,17 +714,16 @@ class FunctionGraph(utils.object2):
return order
def orderings(self):
"""
Return dict d s.t. d[node] is a list of nodes that must be evaluated
before node itself can be evaluated.
"""Return `dict` `d` s.t. `d[node]` is a list of nodes that must be evaluated before `node` itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that
all clients of any destroyed inputs have already computed their outputs.
the clients of any destroyed inputs have already computed their
outputs.
Notes
-----
This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
This only calls the `orderings()` function on all features. It does not
take care of computing the dependencies by itself.
"""
assert isinstance(self._features, list)
......@@ -769,7 +827,6 @@ class FunctionGraph(utils.object2):
def __repr__(self):
return self.__str__()
# clone #
def clone(self, check_integrity=True):
"""
Clone the graph and get a memo( a dict )that map old node to new node
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论