提交 84a66c9f authored 作者: Olivier Breuleux's avatar Olivier Breuleux

improved PatternOptimizer

上级 3451e700
......@@ -165,7 +165,7 @@ class Value(Result):
return self.name
return "<" + str(self.data) + ">" #+ "::" + str(self.type)
def clone(self):
return self.__class__(self.type, self.data)
return self.__class__(self.type, self.data, self.name)
def __set_owner(self, value):
if value is not None:
raise ValueError("Value instances cannot have an owner.")
......
......@@ -352,26 +352,26 @@ class PatternSub(LocalOptimizer):
"""
if node.op != self.op:
return False
def match(pattern, expr, u, first = False):
def match(pattern, expr, u, allow_multiple_clients = False):
if isinstance(pattern, (list, tuple)):
if expr.owner is None:
return False
if not (expr.owner.op == pattern[0]) or (not self.allow_multiple_clients and not first and len(expr.clients) > 1):
if not (expr.owner.op == pattern[0]) or (not allow_multiple_clients and len(expr.clients) > 1):
return False
if len(pattern) - 1 != len(expr.owner.inputs):
return False
for p, v in zip(pattern[1:], expr.owner.inputs):
u = match(p, v, u)
u = match(p, v, u, self.allow_multiple_clients)
if not u:
return False
elif isinstance(pattern, dict):
try:
real_pattern = pattern['pattern']
constraint = pattern['constraint']
except KeyError:
raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern)
raise KeyError("Malformed pattern: %s (expected key 'pattern')" % pattern)
constraint = pattern.get('constraint', lambda expr: True)
if constraint(expr):
return match(real_pattern, expr, u, False)
return match(real_pattern, expr, u, pattern.get('allow_multiple_clients', False))
else:
return False
elif isinstance(pattern, str):
......@@ -393,7 +393,7 @@ class PatternSub(LocalOptimizer):
elif isinstance(pattern, str):
return u[unify.Var(pattern)]
else:
return pattern
return pattern.clone()
u = match(self.in_pattern, node.out, unify.Unification(), True)
if u:
......@@ -408,7 +408,7 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)):
return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict):
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint']))
return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern.get('constraint', 'no conditions')))
else:
return str(pattern)
return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
......
......@@ -757,7 +757,7 @@ class Subtensor(Op):
rest = inputs[1:]
return [SetSubtensor(self.idx_list)(zeros_like(x), gz, *rest)] + [None] * len(rest)
def __eq__(self, others):
def __eq__(self, other):
return type(self) == type(other) and self.idx_list == other.idx_list
def __hash__(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论