提交 75fa1ff0 authored 作者: Frederic Bastien's avatar Frederic Bastien

Traverse the graph only once in theano.printing.pprint

上级 d5384756
...@@ -345,6 +345,11 @@ class PrinterState(gof.utils.scratchpad): ...@@ -345,6 +345,11 @@ class PrinterState(gof.utils.scratchpad):
else: else:
self.__dict__.update(props) self.__dict__.update(props)
self.__dict__.update(more_props) self.__dict__.update(more_props)
# A dict from the object to print to its string
# representation. If it is a dag and not a tree, it allow to
# parse each node of the graph only once. They will still be
# printed many times
self.memo = {}
def clone(self, props=None, **more_props): def clone(self, props=None, **more_props):
if props is None: if props is None:
...@@ -361,6 +366,8 @@ class OperatorPrinter: ...@@ -361,6 +366,8 @@ class OperatorPrinter:
assert self.assoc in VALID_ASSOC assert self.assoc in VALID_ASSOC
def process(self, output, pstate): def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter pprinter = pstate.pprinter
node = output.owner node = output.owner
if node is None: if node is None:
...@@ -393,9 +400,11 @@ class OperatorPrinter: ...@@ -393,9 +400,11 @@ class OperatorPrinter:
else: else:
s = (" %s " % self.operator).join(input_strings) s = (" %s " % self.operator).join(input_strings)
if parenthesize: if parenthesize:
return "(%s)" % s r = "(%s)" % s
else: else:
return s r = s
pstate.memo[output] = r
return r
class PatternPrinter: class PatternPrinter:
...@@ -409,6 +418,8 @@ class PatternPrinter: ...@@ -409,6 +418,8 @@ class PatternPrinter:
self.patterns.append((pattern[0], pattern[1:])) self.patterns.append((pattern[0], pattern[1:]))
def process(self, output, pstate): def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter pprinter = pstate.pprinter
node = output.owner node = output.owner
if node is None: if node is None:
...@@ -425,7 +436,9 @@ class PatternPrinter: ...@@ -425,7 +436,9 @@ class PatternPrinter:
for i, x in enumerate(pp_process(input, precedence) for i, x in enumerate(pp_process(input, precedence)
for input, precedence in for input, precedence in
zip(node.inputs, precedences))) zip(node.inputs, precedences)))
return pattern % d r = pattern % d
pstate.memo[output] = r
return r
class FunctionPrinter: class FunctionPrinter:
...@@ -434,6 +447,8 @@ class FunctionPrinter: ...@@ -434,6 +447,8 @@ class FunctionPrinter:
self.names = names self.names = names
def process(self, output, pstate): def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter pprinter = pstate.pprinter
node = output.owner node = output.owner
if node is None: if node is None:
...@@ -441,40 +456,27 @@ class FunctionPrinter: ...@@ -441,40 +456,27 @@ class FunctionPrinter:
"not the result of an operation" % self.names) "not the result of an operation" % self.names)
idx = node.outputs.index(output) idx = node.outputs.index(output)
name = self.names[idx] name = self.names[idx]
return "%s(%s)" % (name, ", ".join( r = "%s(%s)" % (name, ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000)) [pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs])) for input in node.inputs]))
pstate.memo[output] = r
return r
class MemberPrinter:
def __init__(self, *names):
self.names = names
def process(self, output, pstate):
pprinter = pstate.pprinter
node = output.owner
if node is None:
raise TypeError("function %s cannot represent a variable that is"
" not the result of an operation" % self.function)
idx = node.outputs.index(output)
name = self.names[idx]
input = node.inputs[0]
return "%s.%s" % (pprinter.process(input,
pstate.clone(precedence=1000)),
name)
class IgnorePrinter: class IgnorePrinter:
def process(self, output, pstate): def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter pprinter = pstate.pprinter
node = output.owner node = output.owner
if node is None: if node is None:
raise TypeError("function %s cannot represent a variable that is" raise TypeError("function %s cannot represent a variable that is"
" not the result of an operation" % self.function) " not the result of an operation" % self.function)
input = node.inputs[0] input = node.inputs[0]
return "%s" % pprinter.process(input, pstate) r = "%s" % pprinter.process(input, pstate)
pstate.memo[output] = r
return r
class DefaultPrinter: class DefaultPrinter:
...@@ -482,22 +484,30 @@ class DefaultPrinter: ...@@ -482,22 +484,30 @@ class DefaultPrinter:
def __init__(self): def __init__(self):
pass pass
def process(self, r, pstate): def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
pprinter = pstate.pprinter pprinter = pstate.pprinter
node = r.owner node = output.owner
if node is None: if node is None:
return LeafPrinter().process(r, pstate) return LeafPrinter().process(output, pstate)
return "%s(%s)" % (str(node.op), ", ".join( r = "%s(%s)" % (str(node.op), ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000)) [pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs])) for input in node.inputs]))
pstate.memo[output] = r
return r
class LeafPrinter: class LeafPrinter:
def process(self, r, pstate): def process(self, output, pstate):
if r.name in greek: if output in pstate.memo:
return greek[r.name] return pstate.memo[output]
if output.name in greek:
r = greek[output.name]
else: else:
return str(r) r = str(output)
pstate.memo[output] = r
return r
class PPrinter: class PPrinter:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论