提交 9260f716 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/gof/unify.py

上级 54160cd1
"""
If you have two expressions
containing unification variables, these expressions can be "unified"
if there exists an assignment to all unification variables such that
the two expressions are equal. For instance, [5, A, B] and [A, C, 9]
can be unified if A=C=5 and B=9, yielding [5, 5, 9]. [5, [A, B]] and
[A, [1, 2]] cannot be unified because there is no value for A that
satisfies the constraints. That's useful for pattern matching.
If you have two expressions containing unification variables, these expressions
can be "unified" if there exists an assignment to all unification variables
such that the two expressions are equal.
For instance, [5, A, B] and [A, C, 9] can be unified if A=C=5 and B=9,
yielding [5, 5, 9].
[5, [A, B]] and [A, [1, 2]] cannot be unified because there is no value for A
that satisfies the constraints. That's useful for pattern matching.
"""
from __future__ import print_function
from copy import copy
......@@ -26,12 +28,15 @@ class Variable:
Behavior for unifying various types of variables should be added as
overloadings of the 'unify' function.
Note: there are two Variable classes in theano and this is the
more rarely used one.
Notes
-----
There are two Variable classes in theano and this is the more rarely used
one.
This class is used internally by the PatternSub optimization,
and possibly other subroutines that have to perform graph queries.
If that doesn't sound like what you're doing, the Variable class you
want is probably theano.gof.graph.Variable
want is probably theano.gof.graph.Variable.
"""
def __init__(self, name="?"):
self.name = name
......@@ -48,14 +53,18 @@ class Variable:
class FreeVariable(Variable):
"""
This Variable can take any value.
"""
pass
class BoundVariable(Variable):
"""
This Variable is bound to a value accessible via the value field.
"""
def __init__(self, name, value):
self.name = name
self.value = value
......@@ -65,7 +74,9 @@ class OrVariable(Variable):
"""
This Variable could be any value from a finite list of values,
accessible via the options field.
"""
def __init__(self, name, options):
self.name = name
self.options = options
......@@ -75,7 +86,9 @@ class NotVariable(Variable):
"""
This Variable can take any value but a finite amount of forbidden
values, accessible via the not_options field.
"""
def __init__(self, name, not_options):
self.name = name
self.not_options = not_options
......@@ -84,10 +97,12 @@ class NotVariable(Variable):
class VariableInList: # not a subclass of Variable
"""
This special kind of variable is matched against a list and unifies
an inner Variable to an OrVariable of the values in the list. For
example, if we unify VariableInList(FreeVariable('x')) to [1,2,3],
an inner Variable to an OrVariable of the values in the list.
For example, if we unify VariableInList(FreeVariable('x')) to [1,2,3],
the 'x' variable is unified to an OrVariable('?', [1,2,3]).
"""
def __init__(self, variable):
self.variable = variable
......@@ -120,13 +135,17 @@ class Unification:
"""
This class represents a possible unification of a group of variables
with each other or with tangible values.
"""
def __init__(self, inplace=False):
"""
Parameters
----------
inplace : bool
If inplace is False, the merge method will return a new Unification
that is independent from the previous one (which allows backtracking).
"""
"""
def __init__(self, inplace=False):
self.unif = {}
self.inplace = inplace
......@@ -134,6 +153,7 @@ class Unification:
"""
Links all the specified vars to a Variable that represents their
unification.
"""
if self.inplace:
U = self
......@@ -163,6 +183,7 @@ class Unification:
"""
For a variable v, returns a Variable that represents the tightest
set of possible values it can take.
"""
return self.unif.get(v, (v, None))[0]
......@@ -172,23 +193,25 @@ class Unification:
def unify_walk(a, b, U):
"""
unify_walk(a, b, U) returns an Unification where a and b are unified, given the
unification that already exists in the Unification U. If the unification fails,
it returns False.
unify_walk(a, b, U) returns an Unification where a and b are unified,
given the unification that already exists in the Unification U. If the
unification fails, it returns False.
There are two ways to expand the functionality of unify_walk. The first way is:
There are two ways to expand the functionality of unify_walk. The first way
is:
@comm_guard(type_of_a, type_of_b)
def unify_walk(a, b, U):
...
A function defined as such will be executed whenever the types of a and b
match the declaration. Note that comm_guard automatically guarantees that
your function is commutative: it will try to match the types of a, b or b, a.
It is recommended to define unify_walk in that fashion for new types of Variable
because different types of Variable interact a lot with each other, e.g.
when unifying an OrVariable with a NotVariable, etc. You can return the
special marker FALL_THROUGH to indicate that you want to relay execution
to the next match of the type signature. The definitions of unify_walk are tried
in the reverse order of their declaration.
your function is commutative: it will try to match the types of a, b or
b, a.
It is recommended to define unify_walk in that fashion for new types of
Variable because different types of Variable interact a lot with each other,
e.g. when unifying an OrVariable with a NotVariable, etc. You can return
the special marker FALL_THROUGH to indicate that you want to relay execution
to the next match of the type signature. The definitions of unify_walk are
tried in the reverse order of their declaration.
Another way is to override __unify_walk__ in an user-defined class.
......@@ -209,7 +232,8 @@ def unify_walk(a, b, U):
@comm_guard(FreeVariable, ANY_TYPE)
def unify_walk(fv, o, U):
"""
FreeV is unified to BoundVariable(other_object)
FreeV is unified to BoundVariable(other_object).
"""
v = BoundVariable("?", o)
return U.merge(v, fv)
......@@ -218,7 +242,8 @@ def unify_walk(fv, o, U):
@comm_guard(BoundVariable, ANY_TYPE)
def unify_walk(bv, o, U):
"""
The unification succeed iff BV.value == other_object
The unification succeed iff BV.value == other_object.
"""
if bv.value == o:
return U
......@@ -229,7 +254,8 @@ def unify_walk(bv, o, U):
@comm_guard(OrVariable, ANY_TYPE)
def unify_walk(ov, o, U):
"""
The unification succeeds iff other_object in OrV.options
The unification succeeds iff other_object in OrV.options.
"""
if o in ov.options:
v = BoundVariable("?", o)
......@@ -241,7 +267,8 @@ def unify_walk(ov, o, U):
@comm_guard(NotVariable, ANY_TYPE)
def unify_walk(nv, o, U):
"""
The unification succeeds iff other_object not in NV.not_options
The unification succeeds iff other_object not in NV.not_options.
"""
if o in nv.not_options:
return False
......@@ -254,6 +281,7 @@ def unify_walk(nv, o, U):
def unify_walk(fv, v, U):
"""
Both variables are unified.
"""
v = U[v]
return U.merge(v, fv)
......@@ -262,7 +290,8 @@ def unify_walk(fv, v, U):
@comm_guard(BoundVariable, Variable)
def unify_walk(bv, v, U):
"""
V is unified to BV.value
V is unified to BV.value.
"""
return unify_walk(v, bv.value, U)
......@@ -271,6 +300,7 @@ def unify_walk(bv, v, U):
def unify_walk(a, b, U):
"""
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
"""
opt = intersection(a.options, b.options)
if not opt:
......@@ -286,6 +316,7 @@ def unify_walk(a, b, U):
def unify_walk(a, b, U):
"""
NV(list1) == NV(list2) == NV(union(list1, list2))
"""
opt = union(a.not_options, b.not_options)
v = NotVariable("?", opt)
......@@ -296,6 +327,7 @@ def unify_walk(a, b, U):
def unify_walk(o, n, U):
"""
OrV(list1) == NV(list2) == OrV(list1 \ list2)
"""
opt = [x for x in o.options if x not in n.not_options]
if not opt:
......@@ -311,6 +343,7 @@ def unify_walk(o, n, U):
def unify_walk(vil, l, U):
"""
Unifies VIL's inner Variable to OrV(list).
"""
v = vil.variable
ov = OrVariable("?", l)
......@@ -321,6 +354,7 @@ def unify_walk(vil, l, U):
def unify_walk(l1, l2, U):
"""
Tries to unify each corresponding pair of elements from l1 and l2.
"""
if len(l1) != len(l2):
return False
......@@ -335,6 +369,7 @@ def unify_walk(l1, l2, U):
def unify_walk(d1, d2, U):
"""
Tries to unify values of corresponding keys.
"""
for (k1, v1) in iteritems(d1):
if k1 in d2:
......@@ -349,6 +384,7 @@ def unify_walk(a, b, U):
"""
Checks for the existence of the __unify_walk__ method for one of
the objects.
"""
if (not isinstance(a, Variable) and
not isinstance(b, Variable) and
......@@ -364,6 +400,7 @@ def unify_walk(v, o, U):
This simply checks if the Var has an unification in U and uses it
instead of the Var. If the Var is already its tighest unification,
falls through.
"""
best_v = U[v]
if v is not best_v:
......@@ -447,6 +484,7 @@ def unify_merge(v, o, U):
This simply checks if the Var has an unification in U and uses it
instead of the Var. If the Var is already its tighest unification,
falls through.
"""
best_v = U[v]
if v is not best_v:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论