提交 6800a4ff authored 作者: Frederic's avatar Frederic

pep8

上级 aa775b96
...@@ -5,14 +5,12 @@ import logging ...@@ -5,14 +5,12 @@ import logging
import sys import sys
import time import time
import link import link
import traceback
from theano.gof.python25 import all from theano.gof.python25 import all
import theano import theano
config = theano.config config = theano.config
from theano.configparser import config, AddConfigVar, BoolParam from theano.configparser import config, AddConfigVar, BoolParam
from theano import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -33,13 +31,13 @@ class VM(object): ...@@ -33,13 +31,13 @@ class VM(object):
number of times thunks[i] was called in the course of computations number of times thunks[i] was called in the course of computations
performed by call_with_timers(). performed by call_with_timers().
call_times - list of floats, one for each thunk. call_times[i] is the amount call_times - list of floats, one for each thunk. call_times[i] is
of runtime spent on thunks[i] in the course of computations performed by the amount of runtime spent on thunks[i] in the course of
call_with_timers(). computations performed by call_with_timers().
need_update_inputs - bool. True indicates that Function.__call__ must need_update_inputs - bool. True indicates that Function.__call__
implement the feedback from output storage to input storage. False means must implement the feedback from output storage to input
it *must not* repeat that feedback. storage. False means it *must not* repeat that feedback.
""" """
def __init__(self, nodes, thunks, pre_call_clear): def __init__(self, nodes, thunks, pre_call_clear):
...@@ -58,8 +56,8 @@ class VM(object): ...@@ -58,8 +56,8 @@ class VM(object):
self.nodes = nodes self.nodes = nodes
self.thunks = thunks self.thunks = thunks
self.pre_call_clear = pre_call_clear self.pre_call_clear = pre_call_clear
self.call_counts = [0]*len(nodes) self.call_counts = [0] * len(nodes)
self.call_times = [0]*len(nodes) self.call_times = [0] * len(nodes)
self.time_thunks = False self.time_thunks = False
# This variable (self.need_update_inputs) is overshadowed by # This variable (self.need_update_inputs) is overshadowed by
...@@ -88,14 +86,15 @@ class VM(object): ...@@ -88,14 +86,15 @@ class VM(object):
def update_profile(self, profile): def update_profile(self, profile):
# accumulate into the profile object # accumulate into the profile object
for node, thunk, t, c in zip(self.nodes, self.thunks, self.call_times, self.call_counts): for node, thunk, t, c in zip(self.nodes, self.thunks,
profile.apply_time.setdefault(node,0.0) self.call_times, self.call_counts):
profile.apply_time.setdefault(node, 0.0)
profile.apply_time[node] += t profile.apply_time[node] += t
profile.apply_callcount.setdefault(node,0) profile.apply_callcount.setdefault(node, 0)
profile.apply_callcount[node] += c profile.apply_callcount[node] += c
profile.apply_cimpl[node] = hasattr(thunk,'cthunk') profile.apply_cimpl[node] = hasattr(thunk, 'cthunk')
# clear the timer info out of the buffers # clear the timer info out of the buffers
for i in xrange(len(self.call_times)): for i in xrange(len(self.call_times)):
...@@ -113,7 +112,8 @@ class Loop(VM): ...@@ -113,7 +112,8 @@ class Loop(VM):
for cont in self.pre_call_clear: for cont in self.pre_call_clear:
cont[0] = None cont[0] = None
try: try:
for i, (thunk, node) in enumerate(zip(self.thunks, self.nodes)): for i, (thunk, node) in enumerate(zip(self.thunks,
self.nodes)):
t0 = time.time() t0 = time.time()
thunk() thunk()
t1 = time.time() t1 = time.time()
...@@ -141,13 +141,16 @@ class LoopGC(VM): ...@@ -141,13 +141,16 @@ class LoopGC(VM):
self.post_thunk_clear = post_thunk_clear self.post_thunk_clear = post_thunk_clear
if not (len(nodes) == len(thunks) == len(post_thunk_clear)): if not (len(nodes) == len(thunks) == len(post_thunk_clear)):
raise ValueError() raise ValueError()
def __call__(self): def __call__(self):
if self.time_thunks: if self.time_thunks:
for cont in self.pre_call_clear: for cont in self.pre_call_clear:
cont[0] = None cont[0] = None
try: try:
i = 0 i = 0
for thunk, node, old_storage in zip(self.thunks, self.nodes, self.post_thunk_clear): for thunk, node, old_storage in zip(self.thunks,
self.nodes,
self.post_thunk_clear):
t0 = time.time() t0 = time.time()
thunk() thunk()
t1 = time.time() t1 = time.time()
...@@ -162,7 +165,8 @@ class LoopGC(VM): ...@@ -162,7 +165,8 @@ class LoopGC(VM):
for cont in self.pre_call_clear: for cont in self.pre_call_clear:
cont[0] = None cont[0] = None
try: try:
for thunk, node, old_storage in zip(self.thunks, self.nodes, self.post_thunk_clear): for thunk, node, old_storage in zip(self.thunks, self.nodes,
self.post_thunk_clear):
thunk() thunk()
for old_s in old_storage: for old_s in old_storage:
old_s[0] = None old_s[0] = None
...@@ -200,8 +204,8 @@ class Stack(VM): ...@@ -200,8 +204,8 @@ class Stack(VM):
for i, node in enumerate(self.nodes): for i, node in enumerate(self.nodes):
node_idx[node] = i node_idx[node] = i
self.apply_time[node] = 0 self.apply_time[node] = 0
self.outputs_size[node] = [] self.outputs_size[node] = []
node.destroy_dependencies = [] node.destroy_dependencies = []
if node in ords: if node in ords:
for prereq in ords[node]: for prereq in ords[node]:
...@@ -217,9 +221,9 @@ class Stack(VM): ...@@ -217,9 +221,9 @@ class Stack(VM):
if cl[0] is not 'output': if cl[0] is not 'output':
ls += cl[0].outputs ls += cl[0].outputs
dependencies[k] += ls dependencies[k] += ls
if config.profile: if config.profile:
self.memory_size_map = {"nt8": 1, "t16": 2, "t32": 4, "t64": 8, "128": 16} self.memory_size_map = {"nt8": 1, "t16": 2, "t32": 4,
"t64": 8, "128": 16}
atexit.register(self.atexit_print_all) atexit.register(self.atexit_print_all)
def run_thunk_of_node(self, node): def run_thunk_of_node(self, node):
...@@ -257,11 +261,13 @@ class Stack(VM): ...@@ -257,11 +261,13 @@ class Stack(VM):
last_apply_stack_len = -1 last_apply_stack_len = -1
ls = [] ls = []
while apply_stack: while apply_stack:
# Make sure something happened last time round. # Make sure something happened last time round. This is
# This is just a safety check to make sure the op is written correctly # just a safety check to make sure the op is written
# apply_stack should either decrease in length by one (a thunk successfully applied), or # correctly apply_stack should either decrease in length
# increase in length (added dependencies over and above the original). # by one (a thunk successfully applied), or increase in
# NB: this doesn't catch cycles (would be too expensive/slow), just stalls. # length (added dependencies over and above the original).
# NB: this doesn't catch cycles (would be too expensive/slow),
# just stalls.
apply_stack_len = len(apply_stack) apply_stack_len = len(apply_stack)
assert apply_stack_len != last_apply_stack_len assert apply_stack_len != last_apply_stack_len
last_apply_stack_len = apply_stack_len last_apply_stack_len = apply_stack_len
...@@ -289,8 +295,8 @@ class Stack(VM): ...@@ -289,8 +295,8 @@ class Stack(VM):
if not thunks[self.node_idx[current_apply]].lazy: if not thunks[self.node_idx[current_apply]].lazy:
# Check if all inputs are in place # Check if all inputs are in place
# If so compute thunk and remove it from the apply_stack # If so compute thunk and remove it from the apply_stack
# If not leave it in, and add to the apply_stack those that will # If not leave it in, and add to the apply_stack those
# produce you those inputs # that will produce you those inputs
if computed_ins and not computed_outs: if computed_ins and not computed_outs:
try: try:
...@@ -302,22 +308,26 @@ class Stack(VM): ...@@ -302,22 +308,26 @@ class Stack(VM):
# ?? What about inplace .. if the op is inplace # ?? What about inplace .. if the op is inplace
# you don't actually ask for more memory! # you don't actually ask for more memory!
size = [] size = []
for (idx,o) in enumerate( for (idx, o) in enumerate(
thunks[self.node_idx[current_apply]].outputs): thunks[self.node_idx[
if not hasattr(o[0],'size'): current_apply]].outputs):
if not hasattr(o[0], 'size'):
size.append(-1) size.append(-1)
continue continue
s=o[0].size s = o[0].size
dtype = str(o[0].dtype) dtype = str(o[0].dtype)
dtype2 = dtype[-3:] dtype2 = dtype[-3:]
s *= self.memory_size_map[dtype2] # KeyError here: couldn't determine the dtype memory size # KeyError here: couldn't determine
# the dtype memory size
s *= self.memory_size_map[dtype2]
size.append(s) size.append(s)
self.outputs_size[current_apply] = size self.outputs_size[current_apply] = size
except Exception: except Exception:
raise_with_op(current_apply) raise_with_op(current_apply)
for o in current_apply.outputs: for o in current_apply.outputs:
compute_map[o][0] = 1 compute_map[o][0] = 1
# Garbage Collection -> check if anybody else uses this input # Garbage Collection -> check if anybody else uses
# this input
if self.allow_gc: if self.allow_gc:
for i in current_apply.inputs: for i in current_apply.inputs:
if (dependencies[i] and i.owner if (dependencies[i] and i.owner
...@@ -332,8 +342,11 @@ class Stack(VM): ...@@ -332,8 +342,11 @@ class Stack(VM):
elif not computed_ins: elif not computed_ins:
apply_stack.append(current_apply) apply_stack.append(current_apply)
apply_stack.extend(inp.owner for inp in current_apply.inputs if inp.owner) apply_stack.extend(inp.owner for inp
apply_stack.extend(inp.owner for inp in current_apply.destroy_dependencies if inp.owner) in current_apply.inputs if inp.owner)
apply_stack.extend(inp.owner for inp
in current_apply.destroy_dependencies
if inp.owner)
elif not computed_outs: elif not computed_outs:
# Try and run it to see if it works # Try and run it to see if it works
...@@ -346,22 +359,26 @@ class Stack(VM): ...@@ -346,22 +359,26 @@ class Stack(VM):
if requires: if requires:
for r in requires: for r in requires:
# We are not done with this op .. # We are not done with this op .. so we added
# so we added back and see to get the inputs we are missing # back and see to get the inputs we are
# missing
apply_stack.append(current_apply) apply_stack.append(current_apply)
if current_apply.inputs[r].owner: if current_apply.inputs[r].owner:
apply_stack.append(current_apply.inputs[r].owner) apply_stack.append(current_apply.inputs[r].owner)
else: else:
if config.profile: if config.profile:
size = [] size = []
for (idx,o) in enumerate(thunks[self.node_idx[current_apply]].outputs): for (idx, o) in enumerate(thunks[
self.node_idx[current_apply]].outputs):
if not hasattr(o[0], 'size'): if not hasattr(o[0], 'size'):
size.append(-1) size.append(-1)
continue continue
s=o[0].size s=o[0].size
dtype = str(o[0].dtype) dtype = str(o[0].dtype)
dtype2 = dtype[-2:] dtype2 = dtype[-2:]
s *= self.memory_size_map[dtype2] # KeyError here: couldn't determine the dtype memory size # KeyError here: couldn't determine the
# dtype memory size
s *= self.memory_size_map[dtype2]
size.append(s) size.append(s)
self.outputs_size[current_apply] = size self.outputs_size[current_apply] = size
if self.allow_gc: if self.allow_gc:
...@@ -379,6 +396,7 @@ class Stack(VM): ...@@ -379,6 +396,7 @@ class Stack(VM):
try: try:
import lazylinker_c import lazylinker_c
class CVM(lazylinker_c.CLazyLinker, VM): class CVM(lazylinker_c.CLazyLinker, VM):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
lazylinker_c.CLazyLinker.__init__(self, *args, **kwargs) lazylinker_c.CLazyLinker.__init__(self, *args, **kwargs)
...@@ -394,9 +412,9 @@ class VM_Linker(link.LocalLinker): ...@@ -394,9 +412,9 @@ class VM_Linker(link.LocalLinker):
def __init__(self, allow_gc=True, use_cloop=False, callback=None): def __init__(self, allow_gc=True, use_cloop=False, callback=None):
""" """
allow_gc - force the virtual machine to clean up unnecessary references, allow_gc - force the virtual machine to clean up unnecessary
in order to allow garbage collection on intermediate values during references, in order to allow garbage collection on
computation of a function. intermediate values during computation of a function.
use_cloop - use the C-based virtual machine if possible use_cloop - use the C-based virtual machine if possible
...@@ -411,9 +429,10 @@ class VM_Linker(link.LocalLinker): ...@@ -411,9 +429,10 @@ class VM_Linker(link.LocalLinker):
self.callback = callback self.callback = callback
self.updated_vars = {} self.updated_vars = {}
def accept(self, env, no_recycling = []): def accept(self, env, no_recycling=[]):
""" """
:param env: a PerformLinker can have accepted one Env instance at a time. :param env: a PerformLinker can have accepted one Env instance
at a time.
:param no_recycling: WRITEME :param no_recycling: WRITEME
...@@ -464,9 +483,9 @@ class VM_Linker(link.LocalLinker): ...@@ -464,9 +483,9 @@ class VM_Linker(link.LocalLinker):
nodes_idx_inv = {} nodes_idx_inv = {}
vars_idx_inv = {} vars_idx_inv = {}
for (node,i) in nodes_idx.items(): for (node, i) in nodes_idx.items():
nodes_idx_inv[i] = node nodes_idx_inv[i] = node
for (var,i) in vars_idx.items(): for (var, i) in vars_idx.items():
vars_idx_inv[i] = var vars_idx_inv[i] = var
# put storage_map and compute_map into a int-based scheme # put storage_map and compute_map into a int-based scheme
...@@ -496,8 +515,8 @@ class VM_Linker(link.LocalLinker): ...@@ -496,8 +515,8 @@ class VM_Linker(link.LocalLinker):
base_input_output_list.extend(outputs_idx) base_input_output_list.extend(outputs_idx)
# build the var owner array # build the var owner array
var_owner = [None]*len(vars_idx) var_owner = [None] * len(vars_idx)
for (var,i) in vars_idx.items(): for (var, i) in vars_idx.items():
if var.owner: if var.owner:
var_owner[i] = nodes_idx[var.owner] var_owner[i] = nodes_idx[var.owner]
...@@ -511,18 +530,18 @@ class VM_Linker(link.LocalLinker): ...@@ -511,18 +530,18 @@ class VM_Linker(link.LocalLinker):
for i, node in enumerate(nodes): for i, node in enumerate(nodes):
node_output_size.append(0) node_output_size.append(0)
prereq_var_idxs = [] prereq_var_idxs = []
for prereq_node in ords.get(node,[]): for prereq_node in ords.get(node, []):
prereq_var_idxs.extend( prereq_var_idxs.extend(
[vars_idx[v] for v in prereq_node.outputs]) [vars_idx[v] for v in prereq_node.outputs])
prereq_var_idxs = list(set(prereq_var_idxs)) prereq_var_idxs = list(set(prereq_var_idxs))
prereq_var_idxs.sort() # TODO: why sort? prereq_var_idxs.sort() # TODO: why sort?
node_prereqs.append(prereq_var_idxs) node_prereqs.append(prereq_var_idxs)
update_storage = [] update_storage = []
for (ivar, ovar) in updated_vars.items(): for (ivar, ovar) in updated_vars.items():
if ivar != ovar: if ivar != ovar:
update_storage.append(vars_idx[ivar]) #dst update_storage.append(vars_idx[ivar]) # dst
update_storage.append(vars_idx[ovar]) #src update_storage.append(vars_idx[ovar]) # src
c0 = sys.getrefcount(node_n_inputs) c0 = sys.getrefcount(node_n_inputs)
vm = CVM( vm = CVM(
...@@ -530,8 +549,8 @@ class VM_Linker(link.LocalLinker): ...@@ -530,8 +549,8 @@ class VM_Linker(link.LocalLinker):
thunks, thunks,
pre_call_clear, pre_call_clear,
allow_gc=self.allow_gc, allow_gc=self.allow_gc,
call_counts=[0]*len(nodes), call_counts=[0] * len(nodes),
call_times=[0.0]*len(nodes), call_times=[0.0] * len(nodes),
compute_map_list=compute_map_list, compute_map_list=compute_map_list,
storage_map_list=storage_map_list, storage_map_list=storage_map_list,
base_input_output_list=base_input_output_list, base_input_output_list=base_input_output_list,
...@@ -569,7 +588,7 @@ class VM_Linker(link.LocalLinker): ...@@ -569,7 +588,7 @@ class VM_Linker(link.LocalLinker):
) )
return vm return vm
def make_all(self, profiler = None, input_storage = None, def make_all(self, profiler=None, input_storage=None,
output_storage = None, output_storage = None,
): ):
env = self.env env = self.env
...@@ -617,4 +636,3 @@ class VM_Linker(link.LocalLinker): ...@@ -617,4 +636,3 @@ class VM_Linker(link.LocalLinker):
for output, storage in zip(env.outputs, output_storage)], for output, storage in zip(env.outputs, output_storage)],
thunks, thunks,
order) order)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论