提交 1dabf854 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5062 from nouiz/printer

Printer
...@@ -367,15 +367,15 @@ class FunctionGraph(utils.object2): ...@@ -367,15 +367,15 @@ class FunctionGraph(utils.object2):
reason reason
reason is the name of the optimization or operation in progress. reason is the name of the optimization or operation in progress.
""" """
global NullType
if NullType is None:
from .null_type import NullType
# Imports the owners of the variables # Imports the owners of the variables
if variable.owner and variable.owner not in self.apply_nodes: if variable.owner and variable.owner not in self.apply_nodes:
self.__import__(variable.owner, reason=reason) self.__import__(variable.owner, reason=reason)
if (variable.owner is None and elif (variable.owner is None and
not isinstance(variable, graph.Constant) and not isinstance(variable, graph.Constant) and
variable not in self.inputs): variable not in self.inputs):
global NullType
if NullType is None:
from .null_type import NullType
if isinstance(variable.type, NullType): if isinstance(variable.type, NullType):
raise TypeError("Computation graph contains a NaN. " + raise TypeError("Computation graph contains a NaN. " +
variable.type.why_null) variable.type.why_null)
......
...@@ -340,7 +340,7 @@ class PrinterState(gof.utils.scratchpad): ...@@ -340,7 +340,7 @@ class PrinterState(gof.utils.scratchpad):
def __init__(self, props=None, **more_props): def __init__(self, props=None, **more_props):
if props is None: if props is None:
props = {} props = {}
if isinstance(props, gof.utils.scratchpad): elif isinstance(props, gof.utils.scratchpad):
self.__update__(props) self.__update__(props)
else: else:
self.__dict__.update(props) self.__dict__.update(props)
...@@ -351,11 +351,6 @@ class PrinterState(gof.utils.scratchpad): ...@@ -351,11 +351,6 @@ class PrinterState(gof.utils.scratchpad):
# printed many times # printed many times
self.memo = {} self.memo = {}
def clone(self, props=None, **more_props):
if props is None:
props = {}
return PrinterState(self, **dict(props, **more_props))
class OperatorPrinter: class OperatorPrinter:
...@@ -387,13 +382,16 @@ class OperatorPrinter: ...@@ -387,13 +382,16 @@ class OperatorPrinter:
input_strings = [] input_strings = []
max_i = len(node.inputs) - 1 max_i = len(node.inputs) - 1
for i, input in enumerate(node.inputs): for i, input in enumerate(node.inputs):
new_precedence = self.precedence
if (self.assoc == 'left' and i != 0 or self.assoc == 'right' and if (self.assoc == 'left' and i != 0 or self.assoc == 'right' and
i != max_i): i != max_i):
s = pprinter.process(input, pstate.clone( new_precedence += 1e-6
precedence=self.precedence + 1e-6)) try:
else: old_precedence = getattr(pstate, 'precedence', None)
s = pprinter.process(input, pstate.clone( pstate.precedence = new_precedence
precedence=self.precedence)) s = pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
input_strings.append(s) input_strings.append(s)
if len(input_strings) == 1: if len(input_strings) == 1:
s = self.operator + input_strings[0] s = self.operator + input_strings[0]
...@@ -429,8 +427,15 @@ class PatternPrinter: ...@@ -429,8 +427,15 @@ class PatternPrinter:
pattern, precedences = self.patterns[idx] pattern, precedences = self.patterns[idx]
precedences += (1000,) * len(node.inputs) precedences += (1000,) * len(node.inputs)
def pp_process(input, precedence): def pp_process(input, new_precedence):
return pprinter.process(input, pstate.clone(precedence=precedence)) try:
old_precedence = getattr(pstate, 'precedence', None)
pstate.precedence = new_precedence
r = pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
return r
d = dict((str(i), x) d = dict((str(i), x)
for i, x in enumerate(pp_process(input, precedence) for i, x in enumerate(pp_process(input, precedence)
...@@ -456,9 +461,15 @@ class FunctionPrinter: ...@@ -456,9 +461,15 @@ class FunctionPrinter:
"not the result of an operation" % self.names) "not the result of an operation" % self.names)
idx = node.outputs.index(output) idx = node.outputs.index(output)
name = self.names[idx] name = self.names[idx]
r = "%s(%s)" % (name, ", ".join( new_precedence = -1000
[pprinter.process(input, pstate.clone(precedence=-1000)) try:
for input in node.inputs])) old_precedence = getattr(pstate, 'precedence', None)
pstate.precedence = new_precedence
r = "%s(%s)" % (name, ", ".join(
[pprinter.process(input, pstate) for input in node.inputs]))
finally:
pstate.precedence = old_precedence
pstate.memo[output] = r pstate.memo[output] = r
return r return r
...@@ -479,35 +490,40 @@ class IgnorePrinter: ...@@ -479,35 +490,40 @@ class IgnorePrinter:
return r return r
class DefaultPrinter: class LeafPrinter:
def __init__(self):
self.leaf_printer = LeafPrinter()
def process(self, output, pstate): def process(self, output, pstate):
if output in pstate.memo: if output in pstate.memo:
return pstate.memo[output] return pstate.memo[output]
pprinter = pstate.pprinter if output.name in greek:
node = output.owner r = greek[output.name]
if node is None: else:
return self.leaf_printer.process(output, pstate) r = str(output)
r = "%s(%s)" % (str(node.op), ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs]))
pstate.memo[output] = r pstate.memo[output] = r
return r return r
leaf_printer = LeafPrinter()
class LeafPrinter: class DefaultPrinter:
def process(self, output, pstate): def process(self, output, pstate):
if output in pstate.memo: if output in pstate.memo:
return pstate.memo[output] return pstate.memo[output]
if output.name in greek: pprinter = pstate.pprinter
r = greek[output.name] node = output.owner
else: if node is None:
r = str(output) return leaf_printer.process(output, pstate)
new_precedence = -1000
try:
old_precedence = getattr(pstate, 'precedence', None)
pstate.precedence = new_precedence
r = "%s(%s)" % (str(node.op), ", ".join(
[pprinter.process(input, pstate)
for input in node.inputs]))
finally:
pstate.precedence = old_precedence
pstate.memo[output] = r pstate.memo[output] = r
return r return r
default_printer = DefaultPrinter()
class PPrinter: class PPrinter:
...@@ -562,7 +578,7 @@ class PPrinter: ...@@ -562,7 +578,7 @@ class PPrinter:
else: else:
strings = [] strings = []
pprinter = self.clone_assign(lambda pstate, r: r.name is not None and pprinter = self.clone_assign(lambda pstate, r: r.name is not None and
r is not current, LeafPrinter()) r is not current, leaf_printer)
inv_updates = dict((b, a) for (a, b) in iteritems(updates)) inv_updates = dict((b, a) for (a, b) in iteritems(updates))
i = 1 i = 1
for node in gof.graph.io_toposort(list(inputs) + updates.keys(), for node in gof.graph.io_toposort(list(inputs) + updates.keys(),
...@@ -631,10 +647,7 @@ else: ...@@ -631,10 +647,7 @@ else:
pprint = PPrinter() pprint = PPrinter()
pprint.assign(lambda pstate, r: True, DefaultPrinter()) pprint.assign(lambda pstate, r: True, default_printer)
pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and
pstate.target is not r and r.name is not None,
LeafPrinter())
pp = pprint pp = pprint
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论