提交 e1d46639 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove theano.gof.unify-specific objects from theano.gof.utils

上级 9a32adb3
...@@ -13,10 +13,117 @@ that satisfies the constraints. That's useful for pattern matching. ...@@ -13,10 +13,117 @@ that satisfies the constraints. That's useful for pattern matching.
from copy import copy from copy import copy
from functools import partial from functools import partial
from theano.gof.utils import ANY_TYPE, FALL_THROUGH, comm_guard
class Keyword:
def __init__(self, name, nonzero=True):
self.name = name
self.nonzero = nonzero
def __nonzero__(self):
# Python 2.x
return self.__bool__()
def __bool__(self):
# Python 3.x
return self.nonzero
def __str__(self):
return f"<{self.name}>"
def __repr__(self):
return f"<{self.name}>"
ABORT = Keyword("ABORT", False)
RETRY = Keyword("RETRY", False)
FAILURE = Keyword("FAILURE", False)
simple_types = (int, str, float, bool, type(None), Keyword)
################################ ANY_TYPE = Keyword("ANY_TYPE")
FALL_THROUGH = Keyword("FALL_THROUGH")
def comm_guard(type1, type2):
def wrap(f):
old_f = f.__globals__[f.__name__]
def new_f(arg1, arg2, *rest):
if (type1 is ANY_TYPE or isinstance(arg1, type1)) and (
type2 is ANY_TYPE or isinstance(arg2, type2)
):
pass
elif (type1 is ANY_TYPE or isinstance(arg2, type1)) and (
type2 is ANY_TYPE or isinstance(arg1, type2)
):
arg1, arg2 = arg2, arg1
else:
return old_f(arg1, arg2, *rest)
variable = f(arg1, arg2, *rest)
if variable is FALL_THROUGH:
return old_f(arg1, arg2, *rest)
else:
return variable
new_f.__name__ = f.__name__
def typename(type):
if isinstance(type, Keyword):
return str(type)
elif isinstance(type, (tuple, list)):
return "(" + ", ".join([x.__name__ for x in type]) + ")"
else:
return type.__name__
new_f.__doc__ = (
str(old_f.__doc__)
+ "\n"
+ ", ".join([typename(type) for type in (type1, type2)])
+ "\n"
+ str(f.__doc__ or "")
)
return new_f
return wrap
def type_guard(type1):
def wrap(f):
old_f = f.__globals__[f.__name__]
def new_f(arg1, *rest):
if type1 is ANY_TYPE or isinstance(arg1, type1):
variable = f(arg1, *rest)
if variable is FALL_THROUGH:
return old_f(arg1, *rest)
else:
return variable
else:
return old_f(arg1, *rest)
new_f.__name__ = f.__name__
def typename(type):
if isinstance(type, Keyword):
return str(type)
elif isinstance(type, (tuple, list)):
return "(" + ", ".join([x.__name__ for x in type]) + ")"
else:
return type.__name__
new_f.__doc__ = (
str(old_f.__doc__)
+ "\n"
+ ", ".join([typename(type) for type in (type1,)])
+ "\n"
+ str(f.__doc__ or "")
)
return new_f
return wrap
class Variable: class Variable:
...@@ -111,9 +218,6 @@ class VariableInList: # not a subclass of Variable ...@@ -111,9 +218,6 @@ class VariableInList: # not a subclass of Variable
self.variable = variable self.variable = variable
################################
_all = {} _all = {}
...@@ -133,9 +237,6 @@ OrV = partial(var_lookup, OrVariable) ...@@ -133,9 +237,6 @@ OrV = partial(var_lookup, OrVariable)
NV = partial(var_lookup, NotVariable) NV = partial(var_lookup, NotVariable)
################################
class Unification: class Unification:
""" """
This class represents a possible unification of a group of variables This class represents a possible unification of a group of variables
...@@ -191,9 +292,6 @@ class Unification: ...@@ -191,9 +292,6 @@ class Unification:
return self.unif.get(v, (v, None))[0] return self.unif.get(v, (v, None))[0]
################################
def unify_walk(a, b, U): def unify_walk(a, b, U):
""" """
unify_walk(a, b, U) returns an Unification where a and b are unified, unify_walk(a, b, U) returns an Unification where a and b are unified,
...@@ -416,9 +514,6 @@ def unify_walk(v, o, U): ...@@ -416,9 +514,6 @@ def unify_walk(v, o, U):
return FALL_THROUGH # call the next version of unify_walk that matches the type signature return FALL_THROUGH # call the next version of unify_walk that matches the type signature
################################
class FVar: class FVar:
def __init__(self, fn, *args): def __init__(self, fn, *args):
self.fn = fn self.fn = fn
...@@ -428,9 +523,6 @@ class FVar: ...@@ -428,9 +523,6 @@ class FVar:
return self.fn(*[unify_build(arg, u) for arg in self.args]) return self.fn(*[unify_build(arg, u) for arg in self.args])
################################
def unify_merge(a, b, U): def unify_merge(a, b, U):
return a return a
...@@ -503,54 +595,13 @@ def unify_merge(v, o, U): ...@@ -503,54 +595,13 @@ def unify_merge(v, o, U):
return FALL_THROUGH # call the next version of unify_walk that matches the type signature return FALL_THROUGH # call the next version of unify_walk that matches the type signature
################################
def unify_build(x, U): def unify_build(x, U):
return unify_merge(x, x, U) return unify_merge(x, x, U)
################################
def unify(a, b): def unify(a, b):
U = unify_walk(a, b, Unification()) U = unify_walk(a, b, Unification())
if not U: if not U:
return None, False return None, False
else: else:
return unify_merge(a, b, U), U return unify_merge(a, b, U), U
################################
if __name__ == "__main__":
vx = NotVariable("x", ["big", "bones"])
vy = OrVariable("y", ["hello", "big"])
vz = V("z")
va = V("a")
vl = VariableInList(vz)
pattern1 = dict(hey=vx, ulala=va, a=1)
pattern2 = dict(hey=vy, ulala=10, b=2)
# pattern1 = ["hello", "big", "bones"]
# pattern2 = vl
# pattern1 = [vx]#, "big", "bones"]
# pattern2 = [vy]#, vy, vz]
U = unify_walk(pattern1, pattern2, Unification())
if U:
print(U[va])
print(U[vx])
print(U[vy])
print(U[vz])
print(unify_merge(pattern1, pattern2, U))
else:
print("no match")
U = unify_walk((1, 2), (va, va), Unification())
print(U[va])
...@@ -461,118 +461,6 @@ def toposort(prereqs_d): ...@@ -461,118 +461,6 @@ def toposort(prereqs_d):
return seq return seq
class Keyword:
def __init__(self, name, nonzero=True):
self.name = name
self.nonzero = nonzero
def __nonzero__(self):
# Python 2.x
return self.__bool__()
def __bool__(self):
# Python 3.x
return self.nonzero
def __str__(self):
return f"<{self.name}>"
def __repr__(self):
return f"<{self.name}>"
ABORT = Keyword("ABORT", False)
RETRY = Keyword("RETRY", False)
FAILURE = Keyword("FAILURE", False)
simple_types = (int, str, float, bool, type(None), Keyword)
ANY_TYPE = Keyword("ANY_TYPE")
FALL_THROUGH = Keyword("FALL_THROUGH")
def comm_guard(type1, type2):
def wrap(f):
old_f = f.__globals__[f.__name__]
def new_f(arg1, arg2, *rest):
if (type1 is ANY_TYPE or isinstance(arg1, type1)) and (
type2 is ANY_TYPE or isinstance(arg2, type2)
):
pass
elif (type1 is ANY_TYPE or isinstance(arg2, type1)) and (
type2 is ANY_TYPE or isinstance(arg1, type2)
):
arg1, arg2 = arg2, arg1
else:
return old_f(arg1, arg2, *rest)
variable = f(arg1, arg2, *rest)
if variable is FALL_THROUGH:
return old_f(arg1, arg2, *rest)
else:
return variable
new_f.__name__ = f.__name__
def typename(type):
if isinstance(type, Keyword):
return str(type)
elif isinstance(type, (tuple, list)):
return "(" + ", ".join([x.__name__ for x in type]) + ")"
else:
return type.__name__
new_f.__doc__ = (
str(old_f.__doc__)
+ "\n"
+ ", ".join([typename(type) for type in (type1, type2)])
+ "\n"
+ str(f.__doc__ or "")
)
return new_f
return wrap
def type_guard(type1):
def wrap(f):
old_f = f.__globals__[f.__name__]
def new_f(arg1, *rest):
if type1 is ANY_TYPE or isinstance(arg1, type1):
variable = f(arg1, *rest)
if variable is FALL_THROUGH:
return old_f(arg1, *rest)
else:
return variable
else:
return old_f(arg1, *rest)
new_f.__name__ = f.__name__
def typename(type):
if isinstance(type, Keyword):
return str(type)
elif isinstance(type, (tuple, list)):
return "(" + ", ".join([x.__name__ for x in type]) + ")"
else:
return type.__name__
new_f.__doc__ = (
str(old_f.__doc__)
+ "\n"
+ ", ".join([typename(type) for type in (type1,)])
+ "\n"
+ str(f.__doc__ or "")
)
return new_f
return wrap
def flatten(a): def flatten(a):
""" """
Recursively flatten tuple, list and set in a list. Recursively flatten tuple, list and set in a list.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论