提交 e7cb39cb authored 作者: Frederic's avatar Frederic

pep8

上级 76d02f64
...@@ -107,8 +107,9 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -107,8 +107,9 @@ def debugprint(obj, depth=-1, print_type=False,
scan_ops = [] scan_ops = []
for r in results_to_print: for r in results_to_print:
#Add the parent scan op to the list as well # Add the parent scan op to the list as well
if hasattr(r.owner, 'op') and isinstance(r.owner.op, theano.scan_module.scan_op.Scan): if (hasattr(r.owner, 'op') and
isinstance(r.owner.op, theano.scan_module.scan_op.Scan)):
scan_ops.append(r) scan_ops.append(r)
debugmode.debugprint(r, depth=depth, done=done, print_type=print_type, debugmode.debugprint(r, depth=depth, done=done, print_type=print_type,
...@@ -122,7 +123,8 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -122,7 +123,8 @@ def debugprint(obj, depth=-1, print_type=False,
for s in scan_ops: for s in scan_ops:
print >> file, "" print >> file, ""
debugmode.debugprint(s, depth=depth, done=done, print_type=print_type, debugmode.debugprint(s, depth=depth, done=done,
print_type=print_type,
file=_file, ids=ids, file=_file, ids=ids,
scan_ops=scan_ops, stop_on_name=stop_on_name) scan_ops=scan_ops, stop_on_name=stop_on_name)
if hasattr(s.owner.op, 'fn'): if hasattr(s.owner.op, 'fn'):
...@@ -135,10 +137,12 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -135,10 +137,12 @@ def debugprint(obj, depth=-1, print_type=False,
if isinstance(i.owner.op, theano.scan_module.scan_op.Scan): if isinstance(i.owner.op, theano.scan_module.scan_op.Scan):
scan_ops.append(i) scan_ops.append(i)
debugmode.debugprint(r=i, prefix=new_prefix, depth=depth, done=done, debugmode.debugprint(r=i, prefix=new_prefix,
depth=depth, done=done,
print_type=print_type, file=file, print_type=print_type, file=file,
ids=ids, stop_on_name=stop_on_name, ids=ids, stop_on_name=stop_on_name,
prefix_child=new_prefix_child, scan_ops=scan_ops) prefix_child=new_prefix_child,
scan_ops=scan_ops)
if file is _file: if file is _file:
return file return file
...@@ -269,10 +273,10 @@ class OperatorPrinter: ...@@ -269,10 +273,10 @@ class OperatorPrinter:
if (self.assoc == 'left' and i != 0 or self.assoc == 'right' if (self.assoc == 'left' and i != 0 or self.assoc == 'right'
and i != max_i): and i != max_i):
s = pprinter.process(input, pstate.clone( s = pprinter.process(input, pstate.clone(
precedence=self.precedence + 1e-6)) precedence=self.precedence + 1e-6))
else: else:
s = pprinter.process(input, pstate.clone( s = pprinter.process(input, pstate.clone(
precedence=self.precedence)) precedence=self.precedence))
input_strings.append(s) input_strings.append(s)
if len(input_strings) == 1: if len(input_strings) == 1:
s = self.operator + input_strings[0] s = self.operator + input_strings[0]
...@@ -327,8 +331,8 @@ class FunctionPrinter: ...@@ -327,8 +331,8 @@ class FunctionPrinter:
idx = node.outputs.index(output) idx = node.outputs.index(output)
name = self.names[idx] name = self.names[idx]
return "%s(%s)" % (name, ", ".join( return "%s(%s)" % (name, ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000)) [pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs])) for input in node.inputs]))
class MemberPrinter: class MemberPrinter:
...@@ -374,8 +378,8 @@ class DefaultPrinter: ...@@ -374,8 +378,8 @@ class DefaultPrinter:
if node is None: if node is None:
return LeafPrinter().process(r, pstate) return LeafPrinter().process(r, pstate)
return "%s(%s)" % (str(node.op), ", ".join( return "%s(%s)" % (str(node.op), ", ".join(
[pprinter.process(input, pstate.clone(precedence=-1000)) [pprinter.process(input, pstate.clone(precedence=-1000))
for input in node.inputs])) for input in node.inputs]))
class LeafPrinter: class LeafPrinter:
...@@ -442,7 +446,7 @@ class PPrinter: ...@@ -442,7 +446,7 @@ class PPrinter:
if output in inv_updates: if output in inv_updates:
name = str(inv_updates[output]) name = str(inv_updates[output])
strings.append((i + 1000, "%s <- %s" % ( strings.append((i + 1000, "%s <- %s" % (
name, pprinter.process(output)))) name, pprinter.process(output))))
i += 1 i += 1
if output.name is not None or output in outputs: if output.name is not None or output in outputs:
if output.name is None: if output.name is None:
...@@ -514,13 +518,13 @@ Print to the terminal a math-like expression. ...@@ -514,13 +518,13 @@ Print to the terminal a math-like expression.
# colors not used: orange, amber#FFBF00, purple, pink, # colors not used: orange, amber#FFBF00, purple, pink,
# used by default: green, blue, grey, red # used by default: green, blue, grey, red
default_colorCodes = {'GpuFromHost': 'red', default_colorCodes = {'GpuFromHost': 'red',
'HostFromGpu': 'red', 'HostFromGpu': 'red',
'Scan': 'yellow', 'Scan': 'yellow',
'Shape': 'cyan', 'Shape': 'cyan',
'IfElse': 'magenta', 'IfElse': 'magenta',
'Elemwise': '#FFAABB', # dark pink 'Elemwise': '#FFAABB', # dark pink
'Subtensor': '#FFAAFF', # purple 'Subtensor': '#FFAAFF', # purple
'Alloc': '#FFAA22'} # orange 'Alloc': '#FFAA22'} # orange
def pydotprint(fct, outfile=None, def pydotprint(fct, outfile=None,
...@@ -633,7 +637,7 @@ def pydotprint(fct, outfile=None, ...@@ -633,7 +637,7 @@ def pydotprint(fct, outfile=None,
topo = fct.toposort() topo = fct.toposort()
if not pydot_imported: if not pydot_imported:
raise RuntimeError("Failed to import pydot. You must install pydot" raise RuntimeError("Failed to import pydot. You must install pydot"
" for `pydotprint` to work.") " for `pydotprint` to work.")
return return
g = pd.Dot() g = pd.Dot()
...@@ -696,8 +700,8 @@ def pydotprint(fct, outfile=None, ...@@ -696,8 +700,8 @@ def pydotprint(fct, outfile=None,
varstr = (input_update[var].variable.name + " UPDATE " varstr = (input_update[var].variable.name + " UPDATE "
+ str(var.type)) + str(var.type))
else: else:
#a var id is needed as otherwise var with the same type will be # a var id is needed as otherwise var with the same type will be
#merged in the graph. # merged in the graph.
varstr = str(var.type) varstr = str(var.type)
if (varstr in all_strings) or with_ids: if (varstr in all_strings) or with_ids:
idx = ' id=' + str(len(var_str)) idx = ' id=' + str(len(var_str))
...@@ -726,7 +730,7 @@ def pydotprint(fct, outfile=None, ...@@ -726,7 +730,7 @@ def pydotprint(fct, outfile=None,
prof_str = '' prof_str = ''
if mode: if mode:
time = mode.profile_stats[fct].apply_time.get(node, 0) time = mode.profile_stats[fct].apply_time.get(node, 0)
#second, % total time in profiler, %fct time in profiler # second, % total time in profiler, %fct time in profiler
if mode.local_time == 0: if mode.local_time == 0:
pt = 0 pt = 0
else: else:
...@@ -738,7 +742,7 @@ def pydotprint(fct, outfile=None, ...@@ -738,7 +742,7 @@ def pydotprint(fct, outfile=None,
prof_str = ' (%.3fs,%.3f%%,%.3f%%)' % (time, pt, pf) prof_str = ' (%.3fs,%.3f%%,%.3f%%)' % (time, pt, pf)
elif profile: elif profile:
time = profile.apply_time.get(node, 0) time = profile.apply_time.get(node, 0)
#second, %fct time in profiler # second, %fct time in profiler
if profile.fct_callcount == 0: if profile.fct_callcount == 0:
pf = 0 pf = 0
else: else:
...@@ -788,7 +792,7 @@ def pydotprint(fct, outfile=None, ...@@ -788,7 +792,7 @@ def pydotprint(fct, outfile=None,
nw_node = pd.Node(astr, shape=apply_shape) nw_node = pd.Node(astr, shape=apply_shape)
elif high_contrast: elif high_contrast:
nw_node = pd.Node(astr, style='filled', fillcolor=use_color, nw_node = pd.Node(astr, style='filled', fillcolor=use_color,
shape=apply_shape) shape=apply_shape)
else: else:
nw_node = pd.Node(astr, color=use_color, shape=apply_shape) nw_node = pd.Node(astr, color=use_color, shape=apply_shape)
g.add_node(nw_node) g.add_node(nw_node)
...@@ -819,7 +823,7 @@ def pydotprint(fct, outfile=None, ...@@ -819,7 +823,7 @@ def pydotprint(fct, outfile=None,
elif var.name or not compact: elif var.name or not compact:
g.add_edge(pd.Edge(varstr, astr, label=label)) g.add_edge(pd.Edge(varstr, astr, label=label))
else: else:
#no name, so we don't make a var ellipse # no name, so we don't make a var ellipse
g.add_edge(pd.Edge(apply_name(var.owner), astr, label=label)) g.add_edge(pd.Edge(apply_name(var.owner), astr, label=label))
for id, var in enumerate(node.outputs): for id, var in enumerate(node.outputs):
...@@ -902,7 +906,7 @@ def pydotprint_variables(vars, ...@@ -902,7 +906,7 @@ def pydotprint_variables(vars,
''' '''
warnings.warn("pydotprint_variables() is deprecated." warnings.warn("pydotprint_variables() is deprecated."
" Use pydotprint() instead.") " Use pydotprint() instead.")
if colorCodes is None: if colorCodes is None:
colorCodes = default_colorCodes colorCodes = default_colorCodes
...@@ -986,7 +990,7 @@ def pydotprint_variables(vars, ...@@ -986,7 +990,7 @@ def pydotprint_variables(vars,
g.add_node(pd.Node(varastr)) g.add_node(pd.Node(varastr))
elif high_contrast: elif high_contrast:
g.add_node(pd.Node(varastr, style='filled', g.add_node(pd.Node(varastr, style='filled',
fillcolor='green')) fillcolor='green'))
else: else:
g.add_node(pd.Node(varastr, color='green')) g.add_node(pd.Node(varastr, color='green'))
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论