提交 75fa1ff0 authored 作者: Frederic Bastien's avatar Frederic Bastien

Traverse the graph only once in theano.printing.pprint

上级 d5384756
......@@ -345,6 +345,11 @@ class PrinterState(gof.utils.scratchpad):
else:
self.__dict__.update(props)
self.__dict__.update(more_props)
# A dict from the object to print to its string
# representation. If it is a dag and not a tree, it allow to
# parse each node of the graph only once. They will still be
# printed many times
self.memo = {}
def clone(self, props=None, **more_props):
if props is None:
......@@ -361,6 +366,8 @@ class OperatorPrinter:
assert self.assoc in VALID_ASSOC
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter
node = output.owner
if node is None:
......@@ -393,9 +400,11 @@ class OperatorPrinter:
else:
s = (" %s " % self.operator).join(input_strings)
if parenthesize:
return "(%s)" % s
r = "(%s)" % s
else:
return s
r = s
pstate.memo[output] = r
return r
class PatternPrinter:
......@@ -409,6 +418,8 @@ class PatternPrinter:
self.patterns.append((pattern[0], pattern[1:]))
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter
node = output.owner
if node is None:
......@@ -425,7 +436,9 @@ class PatternPrinter:
for i, x in enumerate(pp_process(input, precedence)
for input, precedence in
zip(node.inputs, precedences)))
return pattern % d
r = pattern % d
pstate.memo[output] = r
return r
class FunctionPrinter:
......@@ -434,6 +447,8 @@ class FunctionPrinter:
self.names = names
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter
node = output.owner
if node is None:
......@@ -441,40 +456,27 @@ class FunctionPrinter:
"not the result of an operation" % self.names)
idx = node.outputs.index(output)
name = self.names[idx]
return "%s(%s)" % (name, ", ".join(
r = "%s(%s)" % (name, ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs]))
class MemberPrinter:
def __init__(self, *names):
self.names = names
def process(self, output, pstate):
pprinter = pstate.pprinter
node = output.owner
if node is None:
raise TypeError("function %s cannot represent a variable that is"
" not the result of an operation" % self.function)
idx = node.outputs.index(output)
name = self.names[idx]
input = node.inputs[0]
return "%s.%s" % (pprinter.process(input,
pstate.clone(precedence=1000)),
name)
pstate.memo[output] = r
return r
class IgnorePrinter:
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter
node = output.owner
if node is None:
raise TypeError("function %s cannot represent a variable that is"
" not the result of an operation" % self.function)
input = node.inputs[0]
return "%s" % pprinter.process(input, pstate)
r = "%s" % pprinter.process(input, pstate)
pstate.memo[output] = r
return r
class DefaultPrinter:
......@@ -482,22 +484,30 @@ class DefaultPrinter:
def __init__(self):
pass
def process(self, r, pstate):
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter
node = r.owner
node = output.owner
if node is None:
return LeafPrinter().process(r, pstate)
return "%s(%s)" % (str(node.op), ", ".join(
return LeafPrinter().process(output, pstate)
r = "%s(%s)" % (str(node.op), ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs]))
pstate.memo[output] = r
return r
class LeafPrinter:
def process(self, r, pstate):
if r.name in greek:
return greek[r.name]
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
if output.name in greek:
r = greek[output.name]
else:
return str(r)
r = str(output)
pstate.memo[output] = r
return r
class PPrinter:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论