提交 1dabf854 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5062 from nouiz/printer

Printer
......@@ -367,15 +367,15 @@ class FunctionGraph(utils.object2):
reason
reason is the name of the optimization or operation in progress.
"""
global NullType
if NullType is None:
from .null_type import NullType
# Imports the owners of the variables
if variable.owner and variable.owner not in self.apply_nodes:
self.__import__(variable.owner, reason=reason)
if (variable.owner is None and
elif (variable.owner is None and
not isinstance(variable, graph.Constant) and
variable not in self.inputs):
global NullType
if NullType is None:
from .null_type import NullType
if isinstance(variable.type, NullType):
raise TypeError("Computation graph contains a NaN. " +
variable.type.why_null)
......
......@@ -340,7 +340,7 @@ class PrinterState(gof.utils.scratchpad):
def __init__(self, props=None, **more_props):
if props is None:
props = {}
if isinstance(props, gof.utils.scratchpad):
elif isinstance(props, gof.utils.scratchpad):
self.__update__(props)
else:
self.__dict__.update(props)
......@@ -351,11 +351,6 @@ class PrinterState(gof.utils.scratchpad):
# printed many times
self.memo = {}
def clone(self, props=None, **more_props):
if props is None:
props = {}
return PrinterState(self, **dict(props, **more_props))
class OperatorPrinter:
......@@ -387,13 +382,16 @@ class OperatorPrinter:
input_strings = []
max_i = len(node.inputs) - 1
for i, input in enumerate(node.inputs):
new_precedence = self.precedence
if (self.assoc == 'left' and i != 0 or self.assoc == 'right' and
i != max_i):
s = pprinter.process(input, pstate.clone(
precedence=self.precedence + 1e-6))
else:
s = pprinter.process(input, pstate.clone(
precedence=self.precedence))
new_precedence += 1e-6
try:
old_precedence = getattr(pstate, 'precedence', None)
pstate.precedence = new_precedence
s = pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
input_strings.append(s)
if len(input_strings) == 1:
s = self.operator + input_strings[0]
......@@ -429,8 +427,15 @@ class PatternPrinter:
pattern, precedences = self.patterns[idx]
precedences += (1000,) * len(node.inputs)
def pp_process(input, precedence):
return pprinter.process(input, pstate.clone(precedence=precedence))
def pp_process(input, new_precedence):
try:
old_precedence = getattr(pstate, 'precedence', None)
pstate.precedence = new_precedence
r = pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
return r
d = dict((str(i), x)
for i, x in enumerate(pp_process(input, precedence)
......@@ -456,9 +461,15 @@ class FunctionPrinter:
"not the result of an operation" % self.names)
idx = node.outputs.index(output)
name = self.names[idx]
r = "%s(%s)" % (name, ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs]))
new_precedence = -1000
try:
old_precedence = getattr(pstate, 'precedence', None)
pstate.precedence = new_precedence
r = "%s(%s)" % (name, ", ".join(
[pprinter.process(input, pstate) for input in node.inputs]))
finally:
pstate.precedence = old_precedence
pstate.memo[output] = r
return r
......@@ -479,35 +490,40 @@ class IgnorePrinter:
return r
class DefaultPrinter:
def __init__(self):
self.leaf_printer = LeafPrinter()
class LeafPrinter:
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter
node = output.owner
if node is None:
return self.leaf_printer.process(output, pstate)
r = "%s(%s)" % (str(node.op), ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs]))
if output.name in greek:
r = greek[output.name]
else:
r = str(output)
pstate.memo[output] = r
return r
leaf_printer = LeafPrinter()
class LeafPrinter:
class DefaultPrinter:
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
if output.name in greek:
r = greek[output.name]
else:
r = str(output)
pprinter = pstate.pprinter
node = output.owner
if node is None:
return leaf_printer.process(output, pstate)
new_precedence = -1000
try:
old_precedence = getattr(pstate, 'precedence', None)
pstate.precedence = new_precedence
r = "%s(%s)" % (str(node.op), ", ".join(
[pprinter.process(input, pstate)
for input in node.inputs]))
finally:
pstate.precedence = old_precedence
pstate.memo[output] = r
return r
default_printer = DefaultPrinter()
class PPrinter:
......@@ -562,7 +578,7 @@ class PPrinter:
else:
strings = []
pprinter = self.clone_assign(lambda pstate, r: r.name is not None and
r is not current, LeafPrinter())
r is not current, leaf_printer)
inv_updates = dict((b, a) for (a, b) in iteritems(updates))
i = 1
for node in gof.graph.io_toposort(list(inputs) + updates.keys(),
......@@ -631,10 +647,7 @@ else:
pprint = PPrinter()
pprint.assign(lambda pstate, r: True, DefaultPrinter())
pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and
pstate.target is not r and r.name is not None,
LeafPrinter())
pprint.assign(lambda pstate, r: True, default_printer)
pp = pprint
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论