提交 995359de authored 作者: Frederic's avatar Frederic

pep8

上级 66cb3bf6
"""WRITEME"""
from copy import copy
import sys
import traceback
import theano
from theano.gof import utils
from theano.gof import graph
from theano.gof.type import Type
import sys, traceback
from copy import copy
__excepthook = sys.excepthook
def log_thunk_trace(value, f=sys.stderr):
"""Log theano's diagnostic stack trace for an exception
raised by raise_with_op.
......@@ -24,15 +26,15 @@ def log_thunk_trace(value, f=sys.stderr):
if trace2 is None:
write("Could not find where this Op was defined.")
write(" * You might have instantiated this Op "
"directly instead of using a constructor.")
"directly instead of using a constructor.")
write(" * The Op you constructed might have been"
" optimized. Try turning off optimizations.")
" optimized. Try turning off optimizations.")
elif trace2:
write("Definition in: ")
for line in traceback.format_list(trace2):
write(line)
write("For the full definition stack trace set"
" the Theano flags traceback.limit to -1")
" the Theano flags traceback.limit to -1")
def thunk_hook(type, value, trace):
......@@ -110,6 +112,7 @@ def raise_with_op(op, exc_info=None):
raise_with_op.print_thunk_trace = False
class Linker(object):
"""WRITEME"""
......@@ -132,10 +135,11 @@ class Linker(object):
print new_e.data # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
"""
raise utils.MethodNotDefined("make_thunk", type(self), self.__class__.__name__)
raise utils.MethodNotDefined("make_thunk", type(self),
self.__class__.__name__)
## DELETEME ##
def make_function(self, unpack_single = True, **kwargs):
def make_function(self, unpack_single=True, **kwargs):
"""
Returns a function that takes values corresponding to the inputs of the
fgraph used by this L{Linker} and returns values corresponding the the outputs
......@@ -155,6 +159,7 @@ class Linker(object):
length 1 will be returned.
"""
thunk, inputs, outputs = self.make_thunk(**kwargs)
def execute(*args):
def e_arity(takes, got):
return 'Function call takes exactly %i %s (%i given)' \
......@@ -165,7 +170,8 @@ class Linker(object):
variable.data = arg
thunk()
if unpack_single:
return utils.to_return_values([variable.data for variable in outputs])
return utils.to_return_values([variable.data
for variable in outputs])
else:
return [variable.data for variable in outputs]
execute.thunk = thunk
......@@ -177,13 +183,14 @@ class Linker(object):
def schedule(self, fgraph):
return fgraph.toposort()
#TODO: Move this class to the compile module, where it is used (and for which it exists).
class Container(object):
"""This class joins a variable with its computed value.
It is used in linkers, especially for the inputs and outputs of a Function.
"""
def __init__(self, r, storage, readonly=False, strict=False,
allow_downcast=None, name=None):
allow_downcast=None, name=None):
"""WRITEME
:Parameters:
......@@ -228,8 +235,10 @@ class Container(object):
kwargs['strict'] = True
if self.allow_downcast is not None:
kwargs['allow_downcast'] = self.allow_downcast
if hasattr(self.type,'filter_inplace'):
self.storage[0] = self.type.filter_inplace(value, self.storage[0], **kwargs)
if hasattr(self.type, 'filter_inplace'):
self.storage[0] = self.type.filter_inplace(value,
self.storage[0],
**kwargs)
else:
self.storage[0] = self.type.filter(value, **kwargs)
......@@ -238,8 +247,10 @@ class Container(object):
raise
data = property(__get__, __set__)
value = property(__get__, __set__)
def __str__(self):
return "<" + str(self.storage[0]) + ">"
def __repr__(self):
return "<" + repr(self.storage[0]) + ">"
......@@ -302,6 +313,7 @@ def map_storage(fgraph, order, input_storage, output_storage):
return input_storage, output_storage, storage_map
def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
no_recycling=None, profiler=None, nice_errors=True):
"""WRITEME
......@@ -331,18 +343,20 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
if len(thunks) != len(order):
raise ValueError('Length of thunks and order must match',
(len(thunks), len(order)))
(len(thunks), len(order)))
if post_thunk_old_storage:
if len(thunks) != len(post_thunk_old_storage):
raise ValueError('Length of thunks and post_thunk_old_storage must match',
(len(thunks), len(post_thunk_old_storage)))
raise ValueError(
'Length of thunks and post_thunk_old_storage must match',
(len(thunks), len(post_thunk_old_storage)))
def streamline_default_f():
for x in no_recycling:
x[0] = None
try:
for thunk, node, old_storage in zip(thunks, order, post_thunk_old_storage):
for thunk, node, old_storage in zip(thunks, order,
post_thunk_old_storage):
thunk()
for old_s in old_storage:
old_s[0] = None
......@@ -351,6 +365,7 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
f = streamline_default_f
elif nice_errors:
thunk_node_list = zip(thunks, order)
def streamline_nice_errors_f():
for x in no_recycling:
x[0] = None
......@@ -371,16 +386,18 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
f = streamline_fast_f
return f
class LocalLinker(Linker):
"""WRITEME
Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node.
"""
def make_thunk(self, profiler = None, input_storage = None, output_storage = None):
return self.make_all(profiler = profiler,
input_storage = input_storage,
output_storage = output_storage)[:3]
def make_thunk(self, profiler=None, input_storage=None,
output_storage=None):
return self.make_all(profiler=profiler,
input_storage=input_storage,
output_storage=output_storage)[:3]
def make_all(self, profiler, input_storage, output_storage):
# By convention, subclasses of LocalLinker should implement this function!
......@@ -391,7 +408,9 @@ class LocalLinker(Linker):
# 3. output storage
# 4. thunks: list of nodes' functions in the order they will be run by the function in (1)
# 5. order: list of nodes, in the order they will be run by the function in (1)
raise utils.MethodNotDefined("make_all", type(self), self.__class__.__name__)
raise utils.MethodNotDefined("make_all", type(self),
self.__class__.__name__)
def gc_helper(node_list):
"""
......@@ -413,6 +432,7 @@ def gc_helper(node_list):
computed.add(output)
return computed, last_user
class PerformLinker(LocalLinker):
"""WRITEME
......@@ -445,7 +465,7 @@ class PerformLinker(LocalLinker):
self.no_recycling = no_recycling
return self
def make_all(self, profiler = None, input_storage = None, output_storage = None):
def make_all(self, profiler=None, input_storage=None, output_storage=None):
"""
:param profiler: WRITEME
:param input_storage: WRITEME
......@@ -460,7 +480,6 @@ class PerformLinker(LocalLinker):
input_storage, output_storage, storage_map = map_storage(fgraph, order, input_storage, output_storage)
compute_map = {}
for k in storage_map:
compute_map[k] = [k.owner is None]
......@@ -475,9 +494,9 @@ class PerformLinker(LocalLinker):
try:
node.op._op_use_c_code = False
thunks += [node.op.make_thunk(node,
storage_map,
compute_map,
no_recycling)]
storage_map,
compute_map,
no_recycling)]
finally:
node.op._op_use_c_code = old_value
......@@ -511,6 +530,7 @@ class PerformLinker(LocalLinker):
[Container(output, storage, True) for output, storage in zip(fgraph.outputs, output_storage)], \
thunks, order
def add_clear_storage(f, computed, storage_map):
def clear_storage():
for c in computed:
......@@ -572,8 +592,8 @@ class WrapLinker(Linker):
is called. In this case, we want the wrapped linkers to be copied too.
"""
other = self.__class__(
linkers=[copy(l) for l in self.linkers],
wrapper=self.wrapper)
linkers=[copy(l) for l in self.linkers],
wrapper=self.wrapper)
return other
def accept(self, fgraph, no_recycling=None):
......@@ -591,11 +611,13 @@ class WrapLinker(Linker):
if no_recycling is None:
no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph:
return type(self)(self.linkers, self.wrapper).accept(fgraph, no_recycling)
return type(self)(self.linkers, self.wrapper).accept(fgraph,
no_recycling)
self.fgraph = fgraph
self.no_recycling = no_recycling
self.linkers = [linker.accept(fgraph, no_recycling) for linker in self.linkers]
self.linkers = [linker.accept(fgraph, no_recycling)
for linker in self.linkers]
return self
def pre(self, f, inputs, order, thunk_groups):
......@@ -614,7 +636,8 @@ class WrapLinker(Linker):
order_list0 = order_lists[0]
for order_list in order_lists[1:]:
if not order_list0 == order_list:
raise Exception("All linkers to WrapLinker should execute operations in the same order.")
raise Exception(
"All linkers to WrapLinker should execute operations in the same order.")
inputs0 = input_lists[0]
outputs0 = output_lists[0]
......@@ -631,13 +654,15 @@ class WrapLinker(Linker):
wrapper = self.wrapper
pre = self.pre
def f():
for inputs in input_lists[1:]:
for input1, input2 in zip(inputs0, inputs):
input2.storage[0] = copy(input1.storage[0])
for x in to_reset:
x[0] = None
pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
pre(self, [input.data for input in input_lists[0]],
order, thunk_groups)
for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
try:
wrapper(i, node, *thunks)
......@@ -647,6 +672,7 @@ class WrapLinker(Linker):
return f, inputs0, outputs0
def WrapLinkerMany(linkers, wrappers):
"""
Variant on WrapLinker that runs a series of wrapper functions instead of
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论