提交 78164b96 authored 作者: Frederic Bastien's avatar Frederic Bastien

pprint speed up. Don't clone state anymore and use single instance.

上级 b4a695a8
...@@ -351,11 +351,6 @@ class PrinterState(gof.utils.scratchpad): ...@@ -351,11 +351,6 @@ class PrinterState(gof.utils.scratchpad):
# printed many times # printed many times
self.memo = {} self.memo = {}
def clone(self, props=None, **more_props):
if props is None:
return PrinterState(self, **more_props)
return PrinterState(self, **dict(props, **more_props))
class OperatorPrinter: class OperatorPrinter:
...@@ -387,13 +382,16 @@ class OperatorPrinter: ...@@ -387,13 +382,16 @@ class OperatorPrinter:
input_strings = [] input_strings = []
max_i = len(node.inputs) - 1 max_i = len(node.inputs) - 1
for i, input in enumerate(node.inputs): for i, input in enumerate(node.inputs):
new_precedence = self.precedence
if (self.assoc == 'left' and i != 0 or self.assoc == 'right' and if (self.assoc == 'left' and i != 0 or self.assoc == 'right' and
i != max_i): i != max_i):
s = pprinter.process(input, pstate.clone( new_precedence += 1e-6
precedence=self.precedence + 1e-6)) old_precedence = getattr(pstate, 'precedence', None)
else: try:
s = pprinter.process(input, pstate.clone( pstate.precedence = new_precedence
precedence=self.precedence)) s = pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
input_strings.append(s) input_strings.append(s)
if len(input_strings) == 1: if len(input_strings) == 1:
s = self.operator + input_strings[0] s = self.operator + input_strings[0]
...@@ -429,8 +427,15 @@ class PatternPrinter: ...@@ -429,8 +427,15 @@ class PatternPrinter:
pattern, precedences = self.patterns[idx] pattern, precedences = self.patterns[idx]
precedences += (1000,) * len(node.inputs) precedences += (1000,) * len(node.inputs)
def pp_process(input, precedence): def pp_process(input, new_precedence):
return pprinter.process(input, pstate.clone(precedence=precedence)) try:
old_precedence = pstate.precedence
pstate.precedence = new_precedence
r = pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
return r
d = dict((str(i), x) d = dict((str(i), x)
for i, x in enumerate(pp_process(input, precedence) for i, x in enumerate(pp_process(input, precedence)
...@@ -456,9 +461,15 @@ class FunctionPrinter: ...@@ -456,9 +461,15 @@ class FunctionPrinter:
"not the result of an operation" % self.names) "not the result of an operation" % self.names)
idx = node.outputs.index(output) idx = node.outputs.index(output)
name = self.names[idx] name = self.names[idx]
r = "%s(%s)" % (name, ", ".join( new_precedense = -1000
[pprinter.process(input, pstate.clone(precedence=-1000)) try:
for input in node.inputs])) old_precedence = pstate.precedence
pstate.precedence = new_precedence
r = "%s(%s)" % (name, ", ".join(
[pprinter.process(input, pstate) for input in node.inputs]))
finally:
pstate.precedence = old_precedence
pstate.memo[output] = r pstate.memo[output] = r
return r return r
...@@ -479,35 +490,40 @@ class IgnorePrinter: ...@@ -479,35 +490,40 @@ class IgnorePrinter:
return r return r
class DefaultPrinter: class LeafPrinter:
def __init__(self):
self.leaf_printer = LeafPrinter()
def process(self, output, pstate): def process(self, output, pstate):
if output in pstate.memo: if output in pstate.memo:
return pstate.memo[output] return pstate.memo[output]
pprinter = pstate.pprinter if output.name in greek:
node = output.owner r = greek[output.name]
if node is None: else:
return self.leaf_printer.process(output, pstate) r = str(output)
r = "%s(%s)" % (str(node.op), ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs]))
pstate.memo[output] = r pstate.memo[output] = r
return r return r
leaf_printer = LeafPrinter()
class LeafPrinter: class DefaultPrinter:
def process(self, output, pstate): def process(self, output, pstate):
if output in pstate.memo: if output in pstate.memo:
return pstate.memo[output] return pstate.memo[output]
if output.name in greek: pprinter = pstate.pprinter
r = greek[output.name] node = output.owner
else: if node is None:
r = str(output) return leaf_printer.process(output, pstate)
new_precedence = -1000
try:
old_precedence = pstate.precedence
pstate.precedence = new_precedence
r = "%s(%s)" % (str(node.op), ", ".join(
[pprinter.process(input, pstate)
for input in node.inputs]))
finally:
pstate.precedence = old_precedence
pstate.memo[output] = r pstate.memo[output] = r
return r return r
default_printer = DefaultPrinter()
class PPrinter: class PPrinter:
...@@ -562,7 +578,7 @@ class PPrinter: ...@@ -562,7 +578,7 @@ class PPrinter:
else: else:
strings = [] strings = []
pprinter = self.clone_assign(lambda pstate, r: r.name is not None and pprinter = self.clone_assign(lambda pstate, r: r.name is not None and
r is not current, LeafPrinter()) r is not current, leaf_printer)
inv_updates = dict((b, a) for (a, b) in iteritems(updates)) inv_updates = dict((b, a) for (a, b) in iteritems(updates))
i = 1 i = 1
for node in gof.graph.io_toposort(list(inputs) + updates.keys(), for node in gof.graph.io_toposort(list(inputs) + updates.keys(),
...@@ -631,10 +647,7 @@ else: ...@@ -631,10 +647,7 @@ else:
pprint = PPrinter() pprint = PPrinter()
pprint.assign(lambda pstate, r: True, DefaultPrinter()) pprint.assign(lambda pstate, r: True, default_printer)
pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and
pstate.target is not r and r.name is not None,
LeafPrinter())
pp = pprint pp = pprint
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论