提交 9260f716 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/gof/unify.py

上级 54160cd1
""" """
If you have two expressions If you have two expressions containing unification variables, these expressions
containing unification variables, these expressions can be "unified" can be "unified" if there exists an assignment to all unification variables
if there exists an assignment to all unification variables such that such that the two expressions are equal.
the two expressions are equal. For instance, [5, A, B] and [A, C, 9]
can be unified if A=C=5 and B=9, yielding [5, 5, 9]. [5, [A, B]] and For instance, [5, A, B] and [A, C, 9] can be unified if A=C=5 and B=9,
[A, [1, 2]] cannot be unified because there is no value for A that yielding [5, 5, 9].
satisfies the constraints. That's useful for pattern matching. [5, [A, B]] and [A, [1, 2]] cannot be unified because there is no value for A
that satisfies the constraints. That's useful for pattern matching.
""" """
from __future__ import print_function from __future__ import print_function
from copy import copy from copy import copy
...@@ -26,12 +28,15 @@ class Variable: ...@@ -26,12 +28,15 @@ class Variable:
Behavior for unifying various types of variables should be added as Behavior for unifying various types of variables should be added as
overloadings of the 'unify' function. overloadings of the 'unify' function.
Note: there are two Variable classes in theano and this is the Notes
more rarely used one. -----
There are two Variable classes in theano and this is the more rarely used
one.
This class is used internally by the PatternSub optimization, This class is used internally by the PatternSub optimization,
and possibly other subroutines that have to perform graph queries. and possibly other subroutines that have to perform graph queries.
If that doesn't sound like what you're doing, the Variable class you If that doesn't sound like what you're doing, the Variable class you
want is probably theano.gof.graph.Variable want is probably theano.gof.graph.Variable.
""" """
def __init__(self, name="?"): def __init__(self, name="?"):
self.name = name self.name = name
...@@ -48,14 +53,18 @@ class Variable: ...@@ -48,14 +53,18 @@ class Variable:
class FreeVariable(Variable): class FreeVariable(Variable):
""" """
This Variable can take any value. This Variable can take any value.
""" """
pass pass
class BoundVariable(Variable): class BoundVariable(Variable):
""" """
This Variable is bound to a value accessible via the value field. This Variable is bound to a value accessible via the value field.
""" """
def __init__(self, name, value): def __init__(self, name, value):
self.name = name self.name = name
self.value = value self.value = value
...@@ -65,7 +74,9 @@ class OrVariable(Variable): ...@@ -65,7 +74,9 @@ class OrVariable(Variable):
""" """
This Variable could be any value from a finite list of values, This Variable could be any value from a finite list of values,
accessible via the options field. accessible via the options field.
""" """
def __init__(self, name, options): def __init__(self, name, options):
self.name = name self.name = name
self.options = options self.options = options
...@@ -75,7 +86,9 @@ class NotVariable(Variable): ...@@ -75,7 +86,9 @@ class NotVariable(Variable):
""" """
This Variable can take any value but a finite amount of forbidden This Variable can take any value but a finite amount of forbidden
values, accessible via the not_options field. values, accessible via the not_options field.
""" """
def __init__(self, name, not_options): def __init__(self, name, not_options):
self.name = name self.name = name
self.not_options = not_options self.not_options = not_options
...@@ -84,10 +97,12 @@ class NotVariable(Variable): ...@@ -84,10 +97,12 @@ class NotVariable(Variable):
class VariableInList: # not a subclass of Variable class VariableInList: # not a subclass of Variable
""" """
This special kind of variable is matched against a list and unifies This special kind of variable is matched against a list and unifies
an inner Variable to an OrVariable of the values in the list. For an inner Variable to an OrVariable of the values in the list.
example, if we unify VariableInList(FreeVariable('x')) to [1,2,3], For example, if we unify VariableInList(FreeVariable('x')) to [1,2,3],
the 'x' variable is unified to an OrVariable('?', [1,2,3]). the 'x' variable is unified to an OrVariable('?', [1,2,3]).
""" """
def __init__(self, variable): def __init__(self, variable):
self.variable = variable self.variable = variable
...@@ -120,13 +135,17 @@ class Unification: ...@@ -120,13 +135,17 @@ class Unification:
""" """
This class represents a possible unification of a group of variables This class represents a possible unification of a group of variables
with each other or with tangible values. with each other or with tangible values.
"""
def __init__(self, inplace=False):
""" Parameters
----------
inplace : bool
If inplace is False, the merge method will return a new Unification If inplace is False, the merge method will return a new Unification
that is independent from the previous one (which allows backtracking). that is independent from the previous one (which allows backtracking).
"""
"""
def __init__(self, inplace=False):
self.unif = {} self.unif = {}
self.inplace = inplace self.inplace = inplace
...@@ -134,6 +153,7 @@ class Unification: ...@@ -134,6 +153,7 @@ class Unification:
""" """
Links all the specified vars to a Variable that represents their Links all the specified vars to a Variable that represents their
unification. unification.
""" """
if self.inplace: if self.inplace:
U = self U = self
...@@ -163,6 +183,7 @@ class Unification: ...@@ -163,6 +183,7 @@ class Unification:
""" """
For a variable v, returns a Variable that represents the tightest For a variable v, returns a Variable that represents the tightest
set of possible values it can take. set of possible values it can take.
""" """
return self.unif.get(v, (v, None))[0] return self.unif.get(v, (v, None))[0]
...@@ -172,23 +193,25 @@ class Unification: ...@@ -172,23 +193,25 @@ class Unification:
def unify_walk(a, b, U): def unify_walk(a, b, U):
""" """
unify_walk(a, b, U) returns an Unification where a and b are unified, given the unify_walk(a, b, U) returns an Unification where a and b are unified,
unification that already exists in the Unification U. If the unification fails, given the unification that already exists in the Unification U. If the
it returns False. unification fails, it returns False.
There are two ways to expand the functionality of unify_walk. The first way is: There are two ways to expand the functionality of unify_walk. The first way
is:
@comm_guard(type_of_a, type_of_b) @comm_guard(type_of_a, type_of_b)
def unify_walk(a, b, U): def unify_walk(a, b, U):
... ...
A function defined as such will be executed whenever the types of a and b A function defined as such will be executed whenever the types of a and b
match the declaration. Note that comm_guard automatically guarantees that match the declaration. Note that comm_guard automatically guarantees that
your function is commutative: it will try to match the types of a, b or b, a. your function is commutative: it will try to match the types of a, b or
It is recommended to define unify_walk in that fashion for new types of Variable b, a.
because different types of Variable interact a lot with each other, e.g. It is recommended to define unify_walk in that fashion for new types of
when unifying an OrVariable with a NotVariable, etc. You can return the Variable because different types of Variable interact a lot with each other,
special marker FALL_THROUGH to indicate that you want to relay execution e.g. when unifying an OrVariable with a NotVariable, etc. You can return
to the next match of the type signature. The definitions of unify_walk are tried the special marker FALL_THROUGH to indicate that you want to relay execution
in the reverse order of their declaration. to the next match of the type signature. The definitions of unify_walk are
tried in the reverse order of their declaration.
Another way is to override __unify_walk__ in an user-defined class. Another way is to override __unify_walk__ in an user-defined class.
...@@ -209,7 +232,8 @@ def unify_walk(a, b, U): ...@@ -209,7 +232,8 @@ def unify_walk(a, b, U):
@comm_guard(FreeVariable, ANY_TYPE) @comm_guard(FreeVariable, ANY_TYPE)
def unify_walk(fv, o, U): def unify_walk(fv, o, U):
""" """
FreeV is unified to BoundVariable(other_object) FreeV is unified to BoundVariable(other_object).
""" """
v = BoundVariable("?", o) v = BoundVariable("?", o)
return U.merge(v, fv) return U.merge(v, fv)
...@@ -218,7 +242,8 @@ def unify_walk(fv, o, U): ...@@ -218,7 +242,8 @@ def unify_walk(fv, o, U):
@comm_guard(BoundVariable, ANY_TYPE) @comm_guard(BoundVariable, ANY_TYPE)
def unify_walk(bv, o, U): def unify_walk(bv, o, U):
""" """
The unification succeed iff BV.value == other_object The unification succeed iff BV.value == other_object.
""" """
if bv.value == o: if bv.value == o:
return U return U
...@@ -229,7 +254,8 @@ def unify_walk(bv, o, U): ...@@ -229,7 +254,8 @@ def unify_walk(bv, o, U):
@comm_guard(OrVariable, ANY_TYPE) @comm_guard(OrVariable, ANY_TYPE)
def unify_walk(ov, o, U): def unify_walk(ov, o, U):
""" """
The unification succeeds iff other_object in OrV.options The unification succeeds iff other_object in OrV.options.
""" """
if o in ov.options: if o in ov.options:
v = BoundVariable("?", o) v = BoundVariable("?", o)
...@@ -241,7 +267,8 @@ def unify_walk(ov, o, U): ...@@ -241,7 +267,8 @@ def unify_walk(ov, o, U):
@comm_guard(NotVariable, ANY_TYPE) @comm_guard(NotVariable, ANY_TYPE)
def unify_walk(nv, o, U): def unify_walk(nv, o, U):
""" """
The unification succeeds iff other_object not in NV.not_options The unification succeeds iff other_object not in NV.not_options.
""" """
if o in nv.not_options: if o in nv.not_options:
return False return False
...@@ -254,6 +281,7 @@ def unify_walk(nv, o, U): ...@@ -254,6 +281,7 @@ def unify_walk(nv, o, U):
def unify_walk(fv, v, U): def unify_walk(fv, v, U):
""" """
Both variables are unified. Both variables are unified.
""" """
v = U[v] v = U[v]
return U.merge(v, fv) return U.merge(v, fv)
...@@ -262,7 +290,8 @@ def unify_walk(fv, v, U): ...@@ -262,7 +290,8 @@ def unify_walk(fv, v, U):
@comm_guard(BoundVariable, Variable) @comm_guard(BoundVariable, Variable)
def unify_walk(bv, v, U): def unify_walk(bv, v, U):
""" """
V is unified to BV.value V is unified to BV.value.
""" """
return unify_walk(v, bv.value, U) return unify_walk(v, bv.value, U)
...@@ -271,6 +300,7 @@ def unify_walk(bv, v, U): ...@@ -271,6 +300,7 @@ def unify_walk(bv, v, U):
def unify_walk(a, b, U): def unify_walk(a, b, U):
""" """
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2)) OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
""" """
opt = intersection(a.options, b.options) opt = intersection(a.options, b.options)
if not opt: if not opt:
...@@ -286,6 +316,7 @@ def unify_walk(a, b, U): ...@@ -286,6 +316,7 @@ def unify_walk(a, b, U):
def unify_walk(a, b, U): def unify_walk(a, b, U):
""" """
NV(list1) == NV(list2) == NV(union(list1, list2)) NV(list1) == NV(list2) == NV(union(list1, list2))
""" """
opt = union(a.not_options, b.not_options) opt = union(a.not_options, b.not_options)
v = NotVariable("?", opt) v = NotVariable("?", opt)
...@@ -296,6 +327,7 @@ def unify_walk(a, b, U): ...@@ -296,6 +327,7 @@ def unify_walk(a, b, U):
def unify_walk(o, n, U): def unify_walk(o, n, U):
""" """
OrV(list1) == NV(list2) == OrV(list1 \ list2) OrV(list1) == NV(list2) == OrV(list1 \ list2)
""" """
opt = [x for x in o.options if x not in n.not_options] opt = [x for x in o.options if x not in n.not_options]
if not opt: if not opt:
...@@ -311,6 +343,7 @@ def unify_walk(o, n, U): ...@@ -311,6 +343,7 @@ def unify_walk(o, n, U):
def unify_walk(vil, l, U): def unify_walk(vil, l, U):
""" """
Unifies VIL's inner Variable to OrV(list). Unifies VIL's inner Variable to OrV(list).
""" """
v = vil.variable v = vil.variable
ov = OrVariable("?", l) ov = OrVariable("?", l)
...@@ -321,6 +354,7 @@ def unify_walk(vil, l, U): ...@@ -321,6 +354,7 @@ def unify_walk(vil, l, U):
def unify_walk(l1, l2, U): def unify_walk(l1, l2, U):
""" """
Tries to unify each corresponding pair of elements from l1 and l2. Tries to unify each corresponding pair of elements from l1 and l2.
""" """
if len(l1) != len(l2): if len(l1) != len(l2):
return False return False
...@@ -335,6 +369,7 @@ def unify_walk(l1, l2, U): ...@@ -335,6 +369,7 @@ def unify_walk(l1, l2, U):
def unify_walk(d1, d2, U): def unify_walk(d1, d2, U):
""" """
Tries to unify values of corresponding keys. Tries to unify values of corresponding keys.
""" """
for (k1, v1) in iteritems(d1): for (k1, v1) in iteritems(d1):
if k1 in d2: if k1 in d2:
...@@ -349,6 +384,7 @@ def unify_walk(a, b, U): ...@@ -349,6 +384,7 @@ def unify_walk(a, b, U):
""" """
Checks for the existence of the __unify_walk__ method for one of Checks for the existence of the __unify_walk__ method for one of
the objects. the objects.
""" """
if (not isinstance(a, Variable) and if (not isinstance(a, Variable) and
not isinstance(b, Variable) and not isinstance(b, Variable) and
...@@ -364,6 +400,7 @@ def unify_walk(v, o, U): ...@@ -364,6 +400,7 @@ def unify_walk(v, o, U):
This simply checks if the Var has an unification in U and uses it This simply checks if the Var has an unification in U and uses it
instead of the Var. If the Var is already its tighest unification, instead of the Var. If the Var is already its tighest unification,
falls through. falls through.
""" """
best_v = U[v] best_v = U[v]
if v is not best_v: if v is not best_v:
...@@ -447,6 +484,7 @@ def unify_merge(v, o, U): ...@@ -447,6 +484,7 @@ def unify_merge(v, o, U):
This simply checks if the Var has an unification in U and uses it This simply checks if the Var has an unification in U and uses it
instead of the Var. If the Var is already its tighest unification, instead of the Var. If the Var is already its tighest unification,
falls through. falls through.
""" """
best_v = U[v] best_v = U[v]
if v is not best_v: if v is not best_v:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论