提交 995359de authored 作者: Frederic's avatar Frederic

pep8

上级 66cb3bf6
"""WRITEME""" """WRITEME"""
from copy import copy
import sys
import traceback
import theano import theano
from theano.gof import utils from theano.gof import utils
from theano.gof import graph from theano.gof import graph
from theano.gof.type import Type from theano.gof.type import Type
import sys, traceback
from copy import copy
__excepthook = sys.excepthook __excepthook = sys.excepthook
def log_thunk_trace(value, f=sys.stderr): def log_thunk_trace(value, f=sys.stderr):
"""Log theano's diagnostic stack trace for an exception """Log theano's diagnostic stack trace for an exception
raised by raise_with_op. raised by raise_with_op.
...@@ -110,6 +112,7 @@ def raise_with_op(op, exc_info=None): ...@@ -110,6 +112,7 @@ def raise_with_op(op, exc_info=None):
raise_with_op.print_thunk_trace = False raise_with_op.print_thunk_trace = False
class Linker(object): class Linker(object):
"""WRITEME""" """WRITEME"""
...@@ -132,10 +135,11 @@ class Linker(object): ...@@ -132,10 +135,11 @@ class Linker(object):
print new_e.data # 3.0 print new_e.data # 3.0
print e.data # 3.0 iff inplace == True (else unknown) print e.data # 3.0 iff inplace == True (else unknown)
""" """
raise utils.MethodNotDefined("make_thunk", type(self), self.__class__.__name__) raise utils.MethodNotDefined("make_thunk", type(self),
self.__class__.__name__)
## DELETEME ## ## DELETEME ##
def make_function(self, unpack_single = True, **kwargs): def make_function(self, unpack_single=True, **kwargs):
""" """
Returns a function that takes values corresponding to the inputs of the Returns a function that takes values corresponding to the inputs of the
fgraph used by this L{Linker} and returns values corresponding the the outputs fgraph used by this L{Linker} and returns values corresponding the the outputs
...@@ -155,6 +159,7 @@ class Linker(object): ...@@ -155,6 +159,7 @@ class Linker(object):
length 1 will be returned. length 1 will be returned.
""" """
thunk, inputs, outputs = self.make_thunk(**kwargs) thunk, inputs, outputs = self.make_thunk(**kwargs)
def execute(*args): def execute(*args):
def e_arity(takes, got): def e_arity(takes, got):
return 'Function call takes exactly %i %s (%i given)' \ return 'Function call takes exactly %i %s (%i given)' \
...@@ -165,7 +170,8 @@ class Linker(object): ...@@ -165,7 +170,8 @@ class Linker(object):
variable.data = arg variable.data = arg
thunk() thunk()
if unpack_single: if unpack_single:
return utils.to_return_values([variable.data for variable in outputs]) return utils.to_return_values([variable.data
for variable in outputs])
else: else:
return [variable.data for variable in outputs] return [variable.data for variable in outputs]
execute.thunk = thunk execute.thunk = thunk
...@@ -177,6 +183,7 @@ class Linker(object): ...@@ -177,6 +183,7 @@ class Linker(object):
def schedule(self, fgraph): def schedule(self, fgraph):
return fgraph.toposort() return fgraph.toposort()
#TODO: Move this class to the compile module, where it is used (and for which it exists). #TODO: Move this class to the compile module, where it is used (and for which it exists).
class Container(object): class Container(object):
"""This class joins a variable with its computed value. """This class joins a variable with its computed value.
...@@ -228,8 +235,10 @@ class Container(object): ...@@ -228,8 +235,10 @@ class Container(object):
kwargs['strict'] = True kwargs['strict'] = True
if self.allow_downcast is not None: if self.allow_downcast is not None:
kwargs['allow_downcast'] = self.allow_downcast kwargs['allow_downcast'] = self.allow_downcast
if hasattr(self.type,'filter_inplace'): if hasattr(self.type, 'filter_inplace'):
self.storage[0] = self.type.filter_inplace(value, self.storage[0], **kwargs) self.storage[0] = self.type.filter_inplace(value,
self.storage[0],
**kwargs)
else: else:
self.storage[0] = self.type.filter(value, **kwargs) self.storage[0] = self.type.filter(value, **kwargs)
...@@ -238,8 +247,10 @@ class Container(object): ...@@ -238,8 +247,10 @@ class Container(object):
raise raise
data = property(__get__, __set__) data = property(__get__, __set__)
value = property(__get__, __set__) value = property(__get__, __set__)
def __str__(self): def __str__(self):
return "<" + str(self.storage[0]) + ">" return "<" + str(self.storage[0]) + ">"
def __repr__(self): def __repr__(self):
return "<" + repr(self.storage[0]) + ">" return "<" + repr(self.storage[0]) + ">"
...@@ -302,6 +313,7 @@ def map_storage(fgraph, order, input_storage, output_storage): ...@@ -302,6 +313,7 @@ def map_storage(fgraph, order, input_storage, output_storage):
return input_storage, output_storage, storage_map return input_storage, output_storage, storage_map
def streamline(fgraph, thunks, order, post_thunk_old_storage=None, def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
no_recycling=None, profiler=None, nice_errors=True): no_recycling=None, profiler=None, nice_errors=True):
"""WRITEME """WRITEME
...@@ -335,14 +347,16 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None, ...@@ -335,14 +347,16 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
if post_thunk_old_storage: if post_thunk_old_storage:
if len(thunks) != len(post_thunk_old_storage): if len(thunks) != len(post_thunk_old_storage):
raise ValueError('Length of thunks and post_thunk_old_storage must match', raise ValueError(
'Length of thunks and post_thunk_old_storage must match',
(len(thunks), len(post_thunk_old_storage))) (len(thunks), len(post_thunk_old_storage)))
def streamline_default_f(): def streamline_default_f():
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
try: try:
for thunk, node, old_storage in zip(thunks, order, post_thunk_old_storage): for thunk, node, old_storage in zip(thunks, order,
post_thunk_old_storage):
thunk() thunk()
for old_s in old_storage: for old_s in old_storage:
old_s[0] = None old_s[0] = None
...@@ -351,6 +365,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None, ...@@ -351,6 +365,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
f = streamline_default_f f = streamline_default_f
elif nice_errors: elif nice_errors:
thunk_node_list = zip(thunks, order) thunk_node_list = zip(thunks, order)
def streamline_nice_errors_f(): def streamline_nice_errors_f():
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
...@@ -371,16 +386,18 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None, ...@@ -371,16 +386,18 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
f = streamline_fast_f f = streamline_fast_f
return f return f
class LocalLinker(Linker): class LocalLinker(Linker):
"""WRITEME """WRITEME
Useful base class for L{Linker}s which keep all nodes in the graph, and run a Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node. thunk associated with each node.
""" """
def make_thunk(self, profiler = None, input_storage = None, output_storage = None): def make_thunk(self, profiler=None, input_storage=None,
return self.make_all(profiler = profiler, output_storage=None):
input_storage = input_storage, return self.make_all(profiler=profiler,
output_storage = output_storage)[:3] input_storage=input_storage,
output_storage=output_storage)[:3]
def make_all(self, profiler, input_storage, output_storage): def make_all(self, profiler, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function! # By convention, subclasses of LocalLinker should implement this function!
...@@ -391,7 +408,9 @@ class LocalLinker(Linker): ...@@ -391,7 +408,9 @@ class LocalLinker(Linker):
# 3. output storage # 3. output storage
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1) # 4. thunks: list of nodes' functions in the order they will be run by the function in (1)
# 5. order: list of nodes, in the order they will be run by the function in (1) # 5. order: list of nodes, in the order they will be run by the function in (1)
raise utils.MethodNotDefined("make_all", type(self), self.__class__.__name__) raise utils.MethodNotDefined("make_all", type(self),
self.__class__.__name__)
def gc_helper(node_list): def gc_helper(node_list):
""" """
...@@ -413,6 +432,7 @@ def gc_helper(node_list): ...@@ -413,6 +432,7 @@ def gc_helper(node_list):
computed.add(output) computed.add(output)
return computed, last_user return computed, last_user
class PerformLinker(LocalLinker): class PerformLinker(LocalLinker):
"""WRITEME """WRITEME
...@@ -445,7 +465,7 @@ class PerformLinker(LocalLinker): ...@@ -445,7 +465,7 @@ class PerformLinker(LocalLinker):
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
def make_all(self, profiler = None, input_storage = None, output_storage = None): def make_all(self, profiler=None, input_storage=None, output_storage=None):
""" """
:param profiler: WRITEME :param profiler: WRITEME
:param input_storage: WRITEME :param input_storage: WRITEME
...@@ -460,7 +480,6 @@ class PerformLinker(LocalLinker): ...@@ -460,7 +480,6 @@ class PerformLinker(LocalLinker):
input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage) input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage)
compute_map = {} compute_map = {}
for k in storage_map: for k in storage_map:
compute_map[k] = [k.owner is None] compute_map[k] = [k.owner is None]
...@@ -511,6 +530,7 @@ class PerformLinker(LocalLinker): ...@@ -511,6 +530,7 @@ class PerformLinker(LocalLinker):
[Container(output, storage, True) for output, storage in zip(fgraph.outputs, output_storage)], \ [Container(output, storage, True) for output, storage in zip(fgraph.outputs, output_storage)], \
thunks, order thunks, order
def add_clear_storage(f, computed, storage_map): def add_clear_storage(f, computed, storage_map):
def clear_storage(): def clear_storage():
for c in computed: for c in computed:
...@@ -591,11 +611,13 @@ class WrapLinker(Linker): ...@@ -591,11 +611,13 @@ class WrapLinker(Linker):
if no_recycling is None: if no_recycling is None:
no_recycling = [] no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph: if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)(self.linkers, self.wrapper).accept(fgraph, no_recycling) return type(self)(self.linkers, self.wrapper).accept(fgraph,
no_recycling)
self.fgraph = fgraph self.fgraph = fgraph
self.no_recycling = no_recycling self.no_recycling = no_recycling
self.linkers = [linker.accept(fgraph, no_recycling) for linker in self.linkers] self.linkers = [linker.accept(fgraph, no_recycling)
for linker in self.linkers]
return self return self
def pre(self, f, inputs, order, thunk_groups): def pre(self, f, inputs, order, thunk_groups):
...@@ -614,7 +636,8 @@ class WrapLinker(Linker): ...@@ -614,7 +636,8 @@ class WrapLinker(Linker):
order_list0 = order_lists[0] order_list0 = order_lists[0]
for order_list in order_lists[1:]: for order_list in order_lists[1:]:
if not order_list0 == order_list: if not order_list0 == order_list:
raise Exception("All linkers to WrapLinker should execute operations in the same order.") raise Exception(
"All linkers to WrapLinker should execute operations in the same order.")
inputs0 = input_lists[0] inputs0 = input_lists[0]
outputs0 = output_lists[0] outputs0 = output_lists[0]
...@@ -631,13 +654,15 @@ class WrapLinker(Linker): ...@@ -631,13 +654,15 @@ class WrapLinker(Linker):
wrapper = self.wrapper wrapper = self.wrapper
pre = self.pre pre = self.pre
def f(): def f():
for inputs in input_lists[1:]: for inputs in input_lists[1:]:
for input1, input2 in zip(inputs0, inputs): for input1, input2 in zip(inputs0, inputs):
input2.storage[0] = copy(input1.storage[0]) input2.storage[0] = copy(input1.storage[0])
for x in to_reset: for x in to_reset:
x[0] = None x[0] = None
pre(self, [input.data for input in input_lists[0]], order, thunk_groups) pre(self, [input.data for input in input_lists[0]],
order, thunk_groups)
for i, (thunks, node) in enumerate(zip(thunk_groups, order)): for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
try: try:
wrapper(i, node, *thunks) wrapper(i, node, *thunks)
...@@ -647,6 +672,7 @@ class WrapLinker(Linker): ...@@ -647,6 +672,7 @@ class WrapLinker(Linker):
return f, inputs0, outputs0 return f, inputs0, outputs0
def WrapLinkerMany(linkers, wrappers): def WrapLinkerMany(linkers, wrappers):
""" """
Variant on WrapLinker that runs a series of wrapper functions instead of Variant on WrapLinker that runs a series of wrapper functions instead of
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论