提交 897082b5 authored 作者: Frederic's avatar Frederic

pep8

上级 19092641
...@@ -4,17 +4,21 @@ Defines Linkers that deal with C implementations. ...@@ -4,17 +4,21 @@ Defines Linkers that deal with C implementations.
# Python imports # Python imports
from copy import copy from copy import copy
import re #for set_compiledir import re #for set_compiledir
import os, sys, StringIO import os
import StringIO
import sys
from itertools import izip from itertools import izip
if sys.version_info[:2] >= (2,5): if sys.version_info[:2] >= (2, 5):
import hashlib import hashlib
def hash_from_code(msg): def hash_from_code(msg):
return hashlib.md5(msg).hexdigest() return hashlib.md5(msg).hexdigest()
else: else:
import md5 import md5
def hash_from_code(msg): def hash_from_code(msg):
return md5.new(msg).hexdigest() return md5.new(msg).hexdigest()
...@@ -46,14 +50,13 @@ from compilelock import get_lock, release_lock ...@@ -46,14 +50,13 @@ from compilelock import get_lock, release_lock
import cmodule import cmodule
import logging import logging
_logger=logging.getLogger("theano.gof.cc") _logger = logging.getLogger("theano.gof.cc")
_logger.setLevel(logging.WARN) _logger.setLevel(logging.WARN)
from theano.gof.callcache import CallCache from theano.gof.callcache import CallCache
run_cthunk = None # Will be imported only when needed. run_cthunk = None # Will be imported only when needed.
def get_module_cache(init_args=None): def get_module_cache(init_args=None):
...@@ -63,37 +66,47 @@ def get_module_cache(init_args=None): ...@@ -63,37 +66,47 @@ def get_module_cache(init_args=None):
""" """
return cmodule.get_module_cache(config.compiledir, init_args=init_args) return cmodule.get_module_cache(config.compiledir, init_args=init_args)
_persistent_module_cache = None _persistent_module_cache = None
def get_persistent_module_cache(): def get_persistent_module_cache():
global _persistent_module_cache global _persistent_module_cache
if _persistent_module_cache is None: if _persistent_module_cache is None:
_persistent_module_cache = CallCache(os.path.join(config.compiledir, 'persistent_cache')) _persistent_module_cache = CallCache(os.path.join(config.compiledir,
'persistent_cache'))
return _persistent_module_cache return _persistent_module_cache
class CodeBlock: class CodeBlock:
"""WRITEME """WRITEME
Represents a computation unit composed of declare, behavior, and cleanup. Represents a computation unit composed of declare, behavior, and cleanup.
@ivar declare: C code that declares variables for use by the computation @ivar declare: C code that declares variables for use by the computation
@ivar behavior: C code that performs the computation @ivar behavior: C code that performs the computation
@ivar cleanup: C code that cleans up things allocated or incref-ed in behavior @ivar cleanup: C code that cleans up things allocated or incref-ed
in behavior
""" """
def __init__(self, declare, behavior, cleanup, sub): def __init__(self, declare, behavior, cleanup, sub):
""" """
Initialize a L{CodeBlock} with templatized declare, behavior and cleanup. Initialize a L{CodeBlock} with templatized declare, behavior
The sub parameter will be used in the other arguments' templates. sub and cleanup. The sub parameter will be used in the other
should contain a key called 'id' that maps to an identifier for this block. arguments' templates. sub should contain a key called 'id'
The identifier will be used to determine the failure code and a label that maps to an identifier for this block.
to jump to. It should also contain a key called 'failure_var' that contains The identifier will be used to determine the failure code and
the name of the variable that contains the error code. a label to jump to. It should also contain a key called
'failure_var' that contains the name of the variable that
contains the error code.
""" """
self.declare = declare self.declare = declare
self.behavior = behavior self.behavior = behavior
# the dummy is because gcc throws an error when a label's right next to a closing # the dummy is because gcc throws an error when a label's
# brace (maybe there's an ignore flag for that...) # right next to a closing brace (maybe there's an ignore flag
# we need the label even if cleanup is empty because the behavior block jumps there # for that...)
# on failure # we need the label even if cleanup is empty because the
self.cleanup = ("__label_%(id)i:\n"%sub + cleanup + "\ndouble __DUMMY_%(id)i;\n"%sub) #% sub # behavior block jumps there on failure
self.cleanup = ("__label_%(id)i:\n" % sub + cleanup +
"\ndouble __DUMMY_%(id)i;\n" % sub) # % sub
def failure_code(sub): def failure_code(sub):
...@@ -102,10 +115,10 @@ def failure_code(sub): ...@@ -102,10 +115,10 @@ def failure_code(sub):
def code_gen(blocks): def code_gen(blocks):
"""WRITEME """WRITEME From a list of L{CodeBlock} instances, returns a string
From a list of L{CodeBlock} instances, returns a string that executes them that executes them all in sequence. eg for C{(decl1, task1,
all in sequence. eg for C{(decl1, task1, cleanup1)} and C{(decl2, task2, cleanup2)} cleanup1)} and C{(decl2, task2, cleanup2)} the returned string
the returned string will be of the form:: will be of the form::
decl1 decl1
decl2 decl2
...@@ -181,10 +194,11 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -181,10 +194,11 @@ def struct_gen(args, struct_builders, blocks, sub):
args_names = ", ".join(args) args_names = ", ".join(args)
args_decl = ", ".join(["PyObject* %s" % arg for arg in args]) args_decl = ", ".join(["PyObject* %s" % arg for arg in args])
# The following code stores the exception data in __ERROR, which is a special # The following code stores the exception data in __ERROR, which
# field of the struct. __ERROR is a list of length 3 that holds the type, the # is a special field of the struct. __ERROR is a list of length 3
# value and the traceback. After storing the error, we return the failure code # that holds the type, the value and the traceback. After storing
# so we know which code block failed. # the error, we return the failure code so we know which code
# block failed.
do_return = """ do_return = """
if (%(failure_var)s) { if (%(failure_var)s) {
// When there is a failure, this code puts the exception // When there is a failure, this code puts the exception
...@@ -213,8 +227,8 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -213,8 +227,8 @@ def struct_gen(args, struct_builders, blocks, sub):
sub = dict(sub) sub = dict(sub)
sub.update(locals()) sub.update(locals())
# TODO: add some error checking to make sure storage_<x> are 1-element lists # TODO: add some error checking to make sure storage_<x> are
# and __ERROR is a 3-elements list. # 1-element lists and __ERROR is a 3-elements list.
struct_code = """ struct_code = """
struct %(name)s { struct %(name)s {
PyObject* __ERROR; PyObject* __ERROR;
...@@ -260,6 +274,7 @@ def get_nothing(r, name, sub): ...@@ -260,6 +274,7 @@ def get_nothing(r, name, sub):
"""WRITEME""" """WRITEME"""
return "" return ""
def get_c_declare(r, name, sub): def get_c_declare(r, name, sub):
"""WRITEME""" """WRITEME"""
pre = """ pre = """
...@@ -267,6 +282,7 @@ def get_c_declare(r, name, sub): ...@@ -267,6 +282,7 @@ def get_c_declare(r, name, sub):
""" % locals() """ % locals()
return pre + r.type.c_declare(name, sub) return pre + r.type.c_declare(name, sub)
def get_c_init(r, name, sub): def get_c_init(r, name, sub):
"""WRITEME""" """WRITEME"""
pre = "" """ pre = "" """
...@@ -275,6 +291,7 @@ def get_c_init(r, name, sub): ...@@ -275,6 +291,7 @@ def get_c_init(r, name, sub):
""" % locals() """ % locals()
return pre + r.type.c_init(name, sub) return pre + r.type.c_init(name, sub)
def get_c_extract(r, name, sub): def get_c_extract(r, name, sub):
"""WRITEME""" """WRITEME"""
pre = """ pre = """
...@@ -283,6 +300,7 @@ def get_c_extract(r, name, sub): ...@@ -283,6 +300,7 @@ def get_c_extract(r, name, sub):
""" % locals() """ % locals()
return pre + r.type.c_extract(name, sub) return pre + r.type.c_extract(name, sub)
def get_c_cleanup(r, name, sub): def get_c_cleanup(r, name, sub):
"""WRITEME""" """WRITEME"""
post = """ post = """
...@@ -290,6 +308,7 @@ def get_c_cleanup(r, name, sub): ...@@ -290,6 +308,7 @@ def get_c_cleanup(r, name, sub):
""" % locals() """ % locals()
return r.type.c_cleanup(name, sub) + post return r.type.c_cleanup(name, sub) + post
def get_c_sync(r, name, sub): def get_c_sync(r, name, sub):
"""WRITEME""" """WRITEME"""
return """ return """
...@@ -300,11 +319,13 @@ def get_c_sync(r, name, sub): ...@@ -300,11 +319,13 @@ def get_c_sync(r, name, sub):
PyList_SET_ITEM(storage_%(name)s, 0, py_%(name)s); PyList_SET_ITEM(storage_%(name)s, 0, py_%(name)s);
{Py_XDECREF(old);} {Py_XDECREF(old);}
} }
""" % dict(sync = r.type.c_sync(name, sub), name = name, **sub) """ % dict(sync=r.type.c_sync(name, sub), name=name, **sub)
def apply_policy(policy, r, name, sub): def apply_policy(policy, r, name, sub):
"""WRITEME """WRITEME
@param policy: list of functions that map a L{Variable} to a string, or a single such function @param policy: list of functions that map a L{Variable} to a string,
or a single such function
@type r: L{Variable} @type r: L{Variable}
@return: C{policy[0](r) + policy[1](r) + ...} @return: C{policy[0](r) + policy[1](r) + ...}
""" """
...@@ -316,7 +337,6 @@ def apply_policy(policy, r, name, sub): ...@@ -316,7 +337,6 @@ def apply_policy(policy, r, name, sub):
return policy(r, name, sub) return policy(r, name, sub)
def struct_variable_codeblocks(variable, policies, id, symbol_table, sub): def struct_variable_codeblocks(variable, policies, id, symbol_table, sub):
"""WRITEME """WRITEME
variable -> a Variable variable -> a Variable
...@@ -339,17 +359,20 @@ def struct_variable_codeblocks(variable, policies, id, symbol_table, sub): ...@@ -339,17 +359,20 @@ def struct_variable_codeblocks(variable, policies, id, symbol_table, sub):
sub['fail'] = failure_code(sub) sub['fail'] = failure_code(sub)
sub['py_ptr'] = "py_%s" % name sub['py_ptr'] = "py_%s" % name
sub['stor_ptr'] = "storage_%s" % name sub['stor_ptr'] = "storage_%s" % name
# struct_declare, struct_behavior, struct_cleanup, sub)
struct_builder = CodeBlock(*[apply_policy(policy, variable, name, sub) struct_builder = CodeBlock(*[apply_policy(policy, variable, name, sub)
for policy in policies[0]]+[sub]) # struct_declare, struct_behavior, struct_cleanup, sub) for policy in policies[0]]+[sub])
sub['id'] = id + 1 sub['id'] = id + 1
sub['fail'] = failure_code(sub) sub['fail'] = failure_code(sub)
sub['py_ptr'] = "py_%s" % name sub['py_ptr'] = "py_%s" % name
sub['stor_ptr'] = "storage_%s" % name sub['stor_ptr'] = "storage_%s" % name
# run_declare, run_behavior, run_cleanup, sub)
block = CodeBlock(*[apply_policy(policy, variable, name, sub) block = CodeBlock(*[apply_policy(policy, variable, name, sub)
for policy in policies[1]]+[sub]) # run_declare, run_behavior, run_cleanup, sub) for policy in policies[1]]+[sub])
return struct_builder, block return struct_builder, block
class CLinker(link.Linker): class CLinker(link.Linker):
"""WRITEME """WRITEME
...@@ -365,7 +388,7 @@ class CLinker(link.Linker): ...@@ -365,7 +388,7 @@ class CLinker(link.Linker):
def __init__(self): def __init__(self):
self.env = None self.env = None
def accept(self, env, no_recycling = []): def accept(self, env, no_recycling=[]):
"""WRITEME""" """WRITEME"""
if self.env is not None and self.env is not env: if self.env is not None and self.env is not env:
return type(self)().accept(env, no_recycling) return type(self)().accept(env, no_recycling)
...@@ -377,15 +400,21 @@ class CLinker(link.Linker): ...@@ -377,15 +400,21 @@ class CLinker(link.Linker):
def fetch_variables(self): def fetch_variables(self):
"""WRITEME """WRITEME
Fills the inputs, outputs, variables, orphans, temps and node_order fields. Fills the inputs, outputs, variables, orphans,
temps and node_order fields.
""" """
env = self.env env = self.env
self.inputs = env.inputs self.inputs = env.inputs
self.outputs = env.outputs self.outputs = env.outputs
self.variables = graph.variables(self.inputs, self.outputs) # list(env.variables) # list(env.variables)
self.variables = graph.variables(self.inputs, self.outputs)
# The orphans field is listified to ensure a consistent order. # The orphans field is listified to ensure a consistent order.
self.orphans = list(r for r in self.variables if isinstance(r, graph.Value) and r not in self.inputs) #list(env.orphans.difference(self.outputs)) #list(env.orphans.difference(self.outputs))
self.temps = list(set(self.variables).difference(self.inputs).difference(self.outputs).difference(self.orphans)) self.orphans = list(r for r in self.variables
if isinstance(r, graph.Value) and
r not in self.inputs)
self.temps = list(set(self.variables).difference(
self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = [] self.consts = []
self.node_order = env.toposort() self.node_order = env.toposort()
...@@ -429,7 +458,7 @@ class CLinker(link.Linker): ...@@ -429,7 +458,7 @@ class CLinker(link.Linker):
failure_var = "__failure" failure_var = "__failure"
id = 1 id = 1
sub = dict(failure_var = failure_var) sub = dict(failure_var=failure_var)
for variable in self.variables: for variable in self.variables:
...@@ -455,36 +484,49 @@ class CLinker(link.Linker): ...@@ -455,36 +484,49 @@ class CLinker(link.Linker):
continue continue
except (utils.MethodNotDefined, NotImplementedError): except (utils.MethodNotDefined, NotImplementedError):
pass pass
# orphans are not inputs so we'll just get fetch them when we initialize the struct and assume they stay the same # orphans are not inputs so we'll just get fetch them
# when we initialize the struct and assume they stay
# the same
policy = [[get_c_declare, get_c_extract, get_c_cleanup], policy = [[get_c_declare, get_c_extract, get_c_cleanup],
[get_nothing, get_nothing, get_nothing]] [get_nothing, get_nothing, get_nothing]]
elif variable in self.temps: elif variable in self.temps:
# temps don't need to be extracted from Python, so we call c_init rather than c_extract # temps don't need to be extracted from Python, so we
# they do not need to be relayed to Python, so we don't sync # call c_init rather than c_extract they do not need
# to be relayed to Python, so we don't sync
if variable.type.c_is_simple() or variable in no_recycling: if variable.type.c_is_simple() or variable in no_recycling:
policy = [[get_nothing, get_nothing, get_nothing], policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_init, get_c_cleanup]] [get_c_declare, get_c_init, get_c_cleanup]]
else: else:
# it is useful for complex temps to reuse storage at each run, so we only clean up in the destructor # it is useful for complex temps to reuse storage
# at each run, so we only clean up in the
# destructor
policy = [[get_c_declare, get_c_init, get_c_cleanup], policy = [[get_c_declare, get_c_init, get_c_cleanup],
[get_nothing, get_nothing, get_nothing]] [get_nothing, get_nothing, get_nothing]]
elif variable in self.outputs: elif variable in self.outputs:
# outputs don't need to be extracted from Python, so we call c_init rather than c_extract # outputs don't need to be extracted from Python, so
# we call c_init rather than c_extract
if variable.type.c_is_simple() or variable in no_recycling: if variable.type.c_is_simple() or variable in no_recycling:
policy = [[get_nothing, get_nothing, get_nothing], policy = [[get_nothing, get_nothing, get_nothing],
[get_c_declare, get_c_init, (get_c_sync, get_c_cleanup)]] [get_c_declare, get_c_init,
(get_c_sync, get_c_cleanup)]]
else: else:
# it is useful for complex outputs to reuse storage at each run, so we only clean up in the destructor # it is useful for complex outputs to reuse
# storage at each run, so we only clean up in the
# destructor
policy = [[get_c_declare, get_c_init, get_c_cleanup], policy = [[get_c_declare, get_c_init, get_c_cleanup],
[get_nothing, get_nothing, get_c_sync]] [get_nothing, get_nothing, get_c_sync]]
else: else:
raise Exception("what the fuck") raise Exception("what the fuck")
builder, block = struct_variable_codeblocks(variable, policy, id, symbol, sub) builder, block = struct_variable_codeblocks(variable, policy,
id, symbol, sub)
# each Variable generates two CodeBlocks, one to declare/initialize/destroy struct variables # each Variable generates two CodeBlocks, one to
# and the other to declare/extract/cleanup each time the function is run. # declare/initialize/destroy struct variables and the
# Typically, only one of the two actually does anything (see all the possible combinations above) # other to declare/extract/cleanup each time the function
# is run.
# Typically, only one of the two actually does anything
# (see all the possible combinations above)
init_tasks.append((variable, 'init', id)) init_tasks.append((variable, 'init', id))
init_blocks.append(builder) init_blocks.append(builder)
...@@ -496,19 +538,22 @@ class CLinker(link.Linker): ...@@ -496,19 +538,22 @@ class CLinker(link.Linker):
for node_num, node in enumerate(self.node_order): for node_num, node in enumerate(self.node_order):
# We populate sub with a mapping from the variable names specified by the op's c_var_names # We populate sub with a mapping from the variable names
# method to the actual variable names that we will use. # specified by the op's c_var_names method to the actual
# variable names that we will use.
## ivnames, ovnames = op.c_var_names() ## ivnames, ovnames = op.c_var_names()
sub = dict(failure_var = failure_var) sub = dict(failure_var = failure_var)
## for variable, vname in zip(op.inputs + op.outputs, ivnames + ovnames): ## for variable, vname in zip(op.inputs + op.outputs, ivnames + ovnames):
## sub[vname] = symbol[variable] ## sub[vname] = symbol[variable]
name = "node_%i" % node_num name = "node_%i" % node_num
isyms, osyms = [symbol[r] for r in node.inputs], [symbol[r] for r in node.outputs] isyms = [symbol[r] for r in node.inputs]
osyms = [symbol[r] for r in node.outputs]
# c_validate_update is deprecated # c_validate_update is deprecated
if hasattr(node.op, 'c_validate_update'): if hasattr(node.op, 'c_validate_update'):
raise Exception("c_validate_update is deprecated, move contents to c_code", node.op) raise Exception("c_validate_update is deprecated,"
" move contents to c_code", node.op)
# Make the CodeBlock for c_code # Make the CodeBlock for c_code
sub['id'] = id sub['id'] = id
...@@ -517,20 +562,23 @@ class CLinker(link.Linker): ...@@ -517,20 +562,23 @@ class CLinker(link.Linker):
op = node.op op = node.op
# type-specific support code # type-specific support code
try: try:
c_support_code_apply.append(op.c_support_code_apply(node, name)) c_support_code_apply.append(op.c_support_code_apply(node,
name))
except utils.MethodNotDefined: except utils.MethodNotDefined:
pass pass
else: else:
# The following will be executed if the "try" block succeeds # The following will be executed if the "try" block succeeds
assert isinstance(c_support_code_apply[-1], basestring), ( assert isinstance(c_support_code_apply[-1], basestring), (
str(node.op)+" didn't returned a string for c_support_code_apply") str(node.op) +
" didn't returned a string for c_support_code_apply")
# emit c_code # emit c_code
try: try:
behavior = op.c_code(node, name, isyms, osyms, sub) behavior = op.c_code(node, name, isyms, osyms, sub)
except utils.MethodNotDefined: except utils.MethodNotDefined:
raise NotImplementedError("%s cannot produce C code" % op) raise NotImplementedError("%s cannot produce C code" % op)
assert isinstance(behavior, basestring), str(node.op)+" didn't returned a string for c_code" assert isinstance(behavior, basestring), (
str(node.op) + " didn't returned a string for c_code")
try: try:
cleanup = op.c_code_cleanup(node, name, isyms, osyms, sub) cleanup = op.c_code_cleanup(node, name, isyms, osyms, sub)
...@@ -543,18 +591,24 @@ class CLinker(link.Linker): ...@@ -543,18 +591,24 @@ class CLinker(link.Linker):
tasks.append((node, 'code', id)) tasks.append((node, 'code', id))
id += 1 id += 1
# List of arg names for use in struct_gen. Note the call to uniq: duplicate inputs # List of arg names for use in struct_gen. Note the call to
# must only be passed once because they are mapped to the same name. # uniq: duplicate inputs must only be passed once because they
# Duplicates are defined by (a is b), rather than (a==b) since Constant instances can # are mapped to the same name. Duplicates are defined by (a
# is b), rather than (a==b) since Constant instances can
# compare equal to equivalent Constant instances. # compare equal to equivalent Constant instances.
args = [] args = []
args += ["storage_%s" % symbol[variable] for variable in utils.uniq(self.inputs + self.outputs + self.orphans)] args += ["storage_%s" % symbol[variable] for variable
in utils.uniq(self.inputs + self.outputs + self.orphans)]
struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var, name = "<<<<NAME>>>>")) struct_code = struct_gen(args, init_blocks, blocks,
dict(failure_var=failure_var,
name="<<<<NAME>>>>"))
# TODO: still needed? We do not use weave anymore. # TODO: still needed? We do not use weave anymore.
# The hash calculated on the code identifies it so weave can cache properly. # The hash calculated on the code identifies it so weave can
# (the hash has to be used outside of the support code because weave does not consider changes in the support code) # cache properly. (the hash has to be used outside of the
# support code because weave does not consider changes in the
# support code)
hash = hash_from_code(struct_code) hash = hash_from_code(struct_code)
struct_name = '__struct_compiled_op_%s' % hash struct_name = '__struct_compiled_op_%s' % hash
...@@ -582,7 +636,8 @@ class CLinker(link.Linker): ...@@ -582,7 +636,8 @@ class CLinker(link.Linker):
# List of indices that should be ignored when passing the arguments # List of indices that should be ignored when passing the arguments
# (basically, everything that the previous call to uniq eliminated) # (basically, everything that the previous call to uniq eliminated)
self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i] self.dupidx = [i for i, x in enumerate(all)
if all.count(x) > 1 and all.index(x) != i]
return self.struct_code return self.struct_code
def support_code(self): def support_code(self):
...@@ -595,9 +650,12 @@ class CLinker(link.Linker): ...@@ -595,9 +650,12 @@ class CLinker(link.Linker):
""" """
ret = [] ret = []
# generic support code # generic support code
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [
try: ret.append(x.c_support_code()) y.op for y in self.node_order]:
except utils.MethodNotDefined: pass try:
ret.append(x.c_support_code())
except utils.MethodNotDefined:
pass
return ret return ret
def compile_args(self): def compile_args(self):
...@@ -608,33 +666,43 @@ class CLinker(link.Linker): ...@@ -608,33 +666,43 @@ class CLinker(link.Linker):
This might contain duplicates. This might contain duplicates.
""" """
ret = ["-O3"] ret = ["-O3"]
# this is the param the -ffast-math activate. I put the explicitly as FillMissing must disable some of them. Putting -ffast-math would make it disable all other parameter at the same time. # this is the param the -ffast-math activate. I put the explicitly as
# FillMissing must disable some of them. Putting -ffast-math would
# make it disable all other parameter at the same time.
ret += ["-fno-math-errno", ret += ["-fno-math-errno",
#"-funsafe-math-optimizations", #"-funsafe-math-optimizations",
#"-fno-signaling-nans", #"-fno-signaling-nans",
#"-fcx-limited-range", #"-fcx-limited-range",
#"-fno-rounding-math", #"-fno-rounding-math",
#"-ffinite-math-only", #"-ffinite-math-only",
"-Wno-unused-label",#the current code generate label event if they are not used. Could use gcc attribute for those label only
"-Wno-unused-variable",#idem as the precedent #the current code generate label event if they are not used.
"-Wno-write-strings",#generated by our code generator... #Could use gcc attribute for those label only
"-Wno-unused-label",
"-Wno-unused-variable", # idem as the precedent
"-Wno-write-strings", # generated by our code generator...
] ]
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [
try: ret += x.c_compile_args() y.op for y in self.node_order]:
except utils.MethodNotDefined: pass try:
ret += x.c_compile_args()
except utils.MethodNotDefined:
pass
c_compiler = self.c_compiler() c_compiler = self.c_compiler()
ret += c_compiler.compile_args() ret += c_compiler.compile_args()
ret=list(set(ret))#to remove duplicate ret = list(set(ret)) # to remove duplicate
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
try: try:
for i in x.c_no_compile_args(): for i in x.c_no_compile_args():
try: try:
ret.remove(i) ret.remove(i)
except ValueError: except ValueError:
pass# in case the value is not there pass # in case the value is not there
except utils.MethodNotDefined: pass except utils.MethodNotDefined:
pass
return ret return ret
def headers(self): def headers(self):
...@@ -645,14 +713,18 @@ class CLinker(link.Linker): ...@@ -645,14 +713,18 @@ class CLinker(link.Linker):
The return value will not contain duplicates. The return value will not contain duplicates.
""" """
ret = [] ret = []
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [
try: ret += x.c_headers() y.op for y in self.node_order]:
except utils.MethodNotDefined: pass try:
ret += x.c_headers()
except utils.MethodNotDefined:
pass
return list(set(ret)) return list(set(ret))
def c_compiler(self): def c_compiler(self):
c_compiler = None c_compiler = None
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [
y.op for y in self.node_order]:
if hasattr(x, 'c_compiler'): if hasattr(x, 'c_compiler'):
x_compiler = x.c_compiler() x_compiler = x.c_compiler()
else: else:
...@@ -662,11 +734,13 @@ class CLinker(link.Linker): ...@@ -662,11 +734,13 @@ class CLinker(link.Linker):
c_compiler = x_compiler c_compiler = x_compiler
else: else:
if x_compiler and (x_compiler != c_compiler): if x_compiler and (x_compiler != c_compiler):
raise Exception('Nodes have requested specific different compilers', raise Exception('Nodes have requested specific'
(c_compiler, x_compiler)) ' different compilers',
(c_compiler, x_compiler))
if (c_compiler is None): if (c_compiler is None):
return cmodule.GCC_compiler return cmodule.GCC_compiler
else: return c_compiler else:
return c_compiler
def header_dirs(self): def header_dirs(self):
"""WRITEME """WRITEME
...@@ -676,9 +750,12 @@ class CLinker(link.Linker): ...@@ -676,9 +750,12 @@ class CLinker(link.Linker):
The return value will not contain duplicates. The return value will not contain duplicates.
""" """
ret = [] ret = []
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [
try: ret += x.c_header_dirs() y.op for y in self.node_order]:
except utils.MethodNotDefined: pass try:
ret += x.c_header_dirs()
except utils.MethodNotDefined:
pass
return list(set(ret)) return list(set(ret))
def libraries(self): def libraries(self):
...@@ -689,9 +766,12 @@ class CLinker(link.Linker): ...@@ -689,9 +766,12 @@ class CLinker(link.Linker):
The return value will not contain duplicates. The return value will not contain duplicates.
""" """
ret = [] ret = []
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [
try: ret += x.c_libraries() y.op for y in self.node_order]:
except utils.MethodNotDefined: pass try:
ret += x.c_libraries()
except utils.MethodNotDefined:
pass
return list(set(ret)) return list(set(ret))
def lib_dirs(self): def lib_dirs(self):
...@@ -702,12 +782,16 @@ class CLinker(link.Linker): ...@@ -702,12 +782,16 @@ class CLinker(link.Linker):
The return value will not contain duplicates. The return value will not contain duplicates.
""" """
ret = [] ret = []
for x in [y.type for y in self.variables] + [y.op for y in self.node_order]: for x in [y.type for y in self.variables] + [
try: ret += x.c_lib_dirs() y.op for y in self.node_order]:
except utils.MethodNotDefined: pass try:
ret += x.c_lib_dirs()
except utils.MethodNotDefined:
pass
return list(set(ret)) return list(set(ret))
def __compile__(self, input_storage = None, output_storage = None, keep_lock=False): def __compile__(self, input_storage=None,
output_storage=None, keep_lock=False):
"""WRITEME """WRITEME
Compiles this linker's env. Compiles this linker's env.
...@@ -737,33 +821,37 @@ class CLinker(link.Linker): ...@@ -737,33 +821,37 @@ class CLinker(link.Linker):
input_storage, input_storage,
output_storage, output_storage,
keep_lock=keep_lock) keep_lock=keep_lock)
return thunk, \ return (thunk,
[link.Container(input, storage) for input, storage in izip(self.env.inputs, input_storage)], \ [link.Container(input, storage) for input, storage in
[link.Container(output, storage, True) for output, storage in izip(self.env.outputs, output_storage)], \ izip(self.env.inputs, input_storage)],
error_storage [link.Container(output, storage, True) for output, storage in
izip(self.env.outputs, output_storage)],
error_storage)
def get_init_tasks(self): def get_init_tasks(self):
init_tasks = [] init_tasks = []
tasks = [] tasks = []
id=1 id = 1
for v in self.variables: for v in self.variables:
if v in self.consts: if v in self.consts:
continue continue
if v in self.orphans and isinstance(v, graph.Constant): if v in self.orphans and isinstance(v, graph.Constant):
try: try:
v.type.c_literal(v.data) #constant will be inlined, no need to get # constant will be inlined, no need to get
v.type.c_literal(v.data)
continue continue
except (utils.MethodNotDefined, NotImplementedError): except (utils.MethodNotDefined, NotImplementedError):
pass pass
init_tasks.append((v, 'init', id)) init_tasks.append((v, 'init', id))
tasks.append((v, 'get', id+1)) tasks.append((v, 'get', id + 1))
id += 2 id += 2
for node in self.node_order: for node in self.node_order:
tasks.append((node, 'code', id)) tasks.append((node, 'code', id))
id += 1 id += 1
return init_tasks, tasks return init_tasks, tasks
def make_thunk(self, input_storage = None, output_storage = None, keep_lock=False): def make_thunk(self, input_storage=None, output_storage=None,
keep_lock=False):
"""WRITEME """WRITEME
Compiles this linker's env and returns a function to perform the Compiles this linker's env and returns a function to perform the
computations, as well as lists of storage cells for both the computations, as well as lists of storage cells for both the
...@@ -787,16 +875,18 @@ class CLinker(link.Linker): ...@@ -787,16 +875,18 @@ class CLinker(link.Linker):
first_output = ostor[0].data first_output = ostor[0].data
""" """
init_tasks, tasks = self.get_init_tasks() init_tasks, tasks = self.get_init_tasks()
cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage, cthunk, in_storage, out_storage, error_storage = self.__compile__(
keep_lock=keep_lock) input_storage, output_storage,
res = _CThunk(cthunk, init_tasks, tasks, error_storage), in_storage, out_storage keep_lock=keep_lock)
return res
res = _CThunk(cthunk, init_tasks, tasks, error_storage)
return res, in_storage, out_storage
def cmodule_key(self): def cmodule_key(self):
"""Return a complete hashable signature of the module we compiled. """Return a complete hashable signature of the module we compiled.
This function must have the property that no two programs that compute different things This function must have the property that no two programs that
yield the same key. compute different things yield the same key.
The key returned by this function is of the form (version, signature) The key returned by this function is of the form (version, signature)
The signature has the following form: The signature has the following form:
...@@ -817,8 +907,9 @@ class CLinker(link.Linker): ...@@ -817,8 +907,9 @@ class CLinker(link.Linker):
It is followed by elements for every node in the It is followed by elements for every node in the
topological ordering of `self.env`. topological ordering of `self.env`.
If the Op of any Apply in the Env does not have c_code_cache_ok()==True, then this If the Op of any Apply in the Env does not have
function raises a KeyError exception. c_code_cache_ok()==True, then this function raises a KeyError
exception.
Input Signature Input Signature
--------------- ---------------
...@@ -865,6 +956,7 @@ class CLinker(link.Linker): ...@@ -865,6 +956,7 @@ class CLinker(link.Linker):
libraries=self.libraries(), libraries=self.libraries(),
header_dirs=self.header_dirs(), header_dirs=self.header_dirs(),
) )
@staticmethod @staticmethod
def cmodule_key_(env, no_recycling, compile_args=[], libraries=[], def cmodule_key_(env, no_recycling, compile_args=[], libraries=[],
header_dirs=[], insert_config_md5=True): header_dirs=[], insert_config_md5=True):
...@@ -876,13 +968,14 @@ class CLinker(link.Linker): ...@@ -876,13 +968,14 @@ class CLinker(link.Linker):
#set of variables that have been computed by nodes we have #set of variables that have been computed by nodes we have
# seen 'so far' in the loop below # seen 'so far' in the loop below
env_computed_set = set() env_computed_set = set()
env_inputs_dict = dict((i, (-1, pos)) for pos, i in enumerate(env.inputs)) env_inputs_dict = dict((i, (-1, pos)) for pos, i in
enumerate(env.inputs))
constant_ids = dict() constant_ids = dict()
op_pos = {} # Apply -> topological position op_pos = {} # Apply -> topological position
# First we put the header, compile_args, library names and config md5 # First we put the header, compile_args, library names and config md5
# into the signature. # into the signature.
sig = ['CLinker.cmodule_key'] # will be cast to tuple on return sig = ['CLinker.cmodule_key'] # will be cast to tuple on return
if compile_args is not None: if compile_args is not None:
# We must sort it as the order from a set is not guaranteed. # We must sort it as the order from a set is not guaranteed.
# In particular, 2 sets with the same content can give different # In particular, 2 sets with the same content can give different
...@@ -912,6 +1005,7 @@ class CLinker(link.Linker): ...@@ -912,6 +1005,7 @@ class CLinker(link.Linker):
sig.append('md5: <omitted>') sig.append('md5: <omitted>')
error_on_play = [False] error_on_play = [False]
def in_sig(i, topological_pos, i_idx): def in_sig(i, topological_pos, i_idx):
# assert that every input to every node is one of' # assert that every input to every node is one of'
# - an env input # - an env input
...@@ -920,7 +1014,7 @@ class CLinker(link.Linker): ...@@ -920,7 +1014,7 @@ class CLinker(link.Linker):
# It is important that a variable (i) # It is important that a variable (i)
# yield a 'position' that reflects its role in code_gen() # yield a 'position' that reflects its role in code_gen()
if isinstance(i, graph.Constant): #orphans if isinstance(i, graph.Constant): # orphans
if id(i) not in constant_ids: if id(i) not in constant_ids:
isig = (i.signature(), topological_pos, i_idx) isig = (i.signature(), topological_pos, i_idx)
# If the Theano constant provides a strong hash # If the Theano constant provides a strong hash
...@@ -933,7 +1027,8 @@ class CLinker(link.Linker): ...@@ -933,7 +1027,8 @@ class CLinker(link.Linker):
isig = (isig[0].theano_hash(), topological_pos, i_idx) isig = (isig[0].theano_hash(), topological_pos, i_idx)
try: try:
hash(isig) hash(isig)
except Exception: #generic constants don't have a hashable signature except Exception:
#generic constants don't have a hashable signature
error_on_play[0] = True error_on_play[0] = True
return None return None
constant_ids[id(i)] = isig constant_ids[id(i)] = isig
...@@ -941,20 +1036,22 @@ class CLinker(link.Linker): ...@@ -941,20 +1036,22 @@ class CLinker(link.Linker):
isig = constant_ids[id(i)] isig = constant_ids[id(i)]
#print 'SIGNATURE', i.signature() #print 'SIGNATURE', i.signature()
#return i.signature() #return i.signature()
elif i in env_inputs_dict: #inputs elif i in env_inputs_dict: # inputs
isig = env_inputs_dict[i] isig = env_inputs_dict[i]
else: else:
if i.owner is None: if i.owner is None:
assert all( all(out is not None for out in o.outputs) for o in order) assert all(all(out is not None for out in o.outputs)
assert all( input.owner is None for input in env.inputs) for o in order)
raise Exception('what is this?', (i, type(i), i.clients, env)) assert all(input.owner is None for input in env.inputs)
raise Exception('what is this?', (i, type(i), i.clients,
env))
if i in env.outputs: if i in env.outputs:
isig = (op_pos[i.owner], # outputs isig = (op_pos[i.owner], # outputs
i.owner.outputs.index(i), i.owner.outputs.index(i),
env.outputs.index(i)) env.outputs.index(i))
else: else:
isig = (op_pos[i.owner], i.owner.outputs.index(i)) # temps isig = (op_pos[i.owner], i.owner.outputs.index(i)) # temps
return (isig, i in no_recycling) return (isig, i in no_recycling)
version = [] version = []
...@@ -973,7 +1070,7 @@ class CLinker(link.Linker): ...@@ -973,7 +1070,7 @@ class CLinker(link.Linker):
sig.append(( sig.append((
node.op, node.op,
tuple((i.type, in_sig(i, node_pos, ipos)) tuple((i.type, in_sig(i, node_pos, ipos))
for ipos,i in enumerate(node.inputs)), for ipos, i in enumerate(node.inputs)),
tuple(o in no_recycling for o in node.outputs))) tuple(o in no_recycling for o in node.outputs)))
if error_on_play[0]: if error_on_play[0]:
...@@ -1026,7 +1123,9 @@ class CLinker(link.Linker): ...@@ -1026,7 +1123,9 @@ class CLinker(link.Linker):
if compiler_name == 'NVCC_compiler' and config.lib.amdlibm: if compiler_name == 'NVCC_compiler' and config.lib.amdlibm:
# This lib does not work correctly with nvcc in device code. # This lib does not work correctly with nvcc in device code.
# and newer version of g++ as 4.5.1. # and newer version of g++ as 4.5.1.
# example of errors: "/usr/lib/gcc/x86_64-redhat-linux/4.5.1/include/mmintrin.h(49): error: identifier "__builtin_ia32_emms" is undefined" # example of errors: "/usr/lib/gcc/x86_64-redhat-linux/4.5.1/
# include/mmintrin.h(49): error: identifier
# "__builtin_ia32_emms" is undefined"
if '<amdlibm.h>' in mod.includes: if '<amdlibm.h>' in mod.includes:
mod.includes.remove('<amdlibm.h>') mod.includes.remove('<amdlibm.h>')
...@@ -1057,7 +1156,8 @@ class CLinker(link.Linker): ...@@ -1057,7 +1156,8 @@ class CLinker(link.Linker):
yield module yield module
def build_dynamic_module(self): def build_dynamic_module(self):
"""Return a cmodule.DynamicModule instance full of the code for our env. """Return a cmodule.DynamicModule instance full of the code
for our env.
""" """
self.code_gen() self.code_gen()
module_name = self.hash module_name = self.hash
...@@ -1065,13 +1165,16 @@ class CLinker(link.Linker): ...@@ -1065,13 +1165,16 @@ class CLinker(link.Linker):
mod = cmodule.DynamicModule(module_name) mod = cmodule.DynamicModule(module_name)
# The code of instantiate # The code of instantiate
code = self.instantiate_code(1+len(self.args)) #the 1 is for error_storage # the 1 is for error_storage
instantiate = cmodule.ExtFunction('instantiate', code, method=cmodule.METH_VARARGS) code = self.instantiate_code(1 + len(self.args))
instantiate = cmodule.ExtFunction('instantiate', code,
method=cmodule.METH_VARARGS)
#['error_storage'] + argnames, #['error_storage'] + argnames,
#local_dict = d, #local_dict = d,
#global_dict = {}) #global_dict = {})
# Static methods that can run and destroy the struct built by instantiate. # Static methods that can run and destroy the struct built by
# instantiate.
static = """ static = """
int %(struct_name)s_executor(%(struct_name)s* self) { int %(struct_name)s_executor(%(struct_name)s* self) {
return self->run(); return self->run();
...@@ -1086,7 +1189,7 @@ class CLinker(link.Linker): ...@@ -1086,7 +1189,7 @@ class CLinker(link.Linker):
//printf("done cleanup\\n"); //printf("done cleanup\\n");
//fflush(stdout); //fflush(stdout);
} }
""" % dict(struct_name = self.struct_name) """ % dict(struct_name=self.struct_name)
# We add all the support code, compile args, headers and libs we need. # We add all the support code, compile args, headers and libs we need.
for support_code in self.support_code() + self.c_support_code_apply: for support_code in self.support_code() + self.c_support_code_apply:
...@@ -1100,7 +1203,8 @@ class CLinker(link.Linker): ...@@ -1100,7 +1203,8 @@ class CLinker(link.Linker):
return mod return mod
def cthunk_factory(self, error_storage, in_storage, out_storage, keep_lock=False): def cthunk_factory(self, error_storage, in_storage, out_storage,
keep_lock=False):
"""WRITEME """WRITEME
error_storage -> list of length 3 error_storage -> list of length 3
in_storage -> list of lists of length 1, one per input in_storage -> list of lists of length 1, one per input
...@@ -1120,18 +1224,22 @@ class CLinker(link.Linker): ...@@ -1120,18 +1224,22 @@ class CLinker(link.Linker):
# If we can't get a key, then forget the cache mechanism. # If we can't get a key, then forget the cache mechanism.
module = self.compile_cmodule() module = self.compile_cmodule()
else: else:
module = get_module_cache().module_from_key(key=key, fn=self.compile_cmodule_by_step, keep_lock=keep_lock) module = get_module_cache().module_from_key(
key=key, fn=self.compile_cmodule_by_step, keep_lock=keep_lock)
vars = self.inputs + self.outputs + self.orphans vars = self.inputs + self.outputs + self.orphans
# List of indices that should be ignored when passing the arguments # List of indices that should be ignored when passing the arguments
# (basically, everything that the previous call to uniq eliminated) # (basically, everything that the previous call to uniq eliminated)
dupidx = [i for i, x in enumerate(vars) if vars.count(x) > 1 and vars.index(x) != i] dupidx = [i for i, x in enumerate(vars)
if vars.count(x) > 1 and vars.index(x) != i]
out_storage = [x for i, x in enumerate(out_storage) if (i+len(in_storage)) not in dupidx] out_storage = [x for i, x in enumerate(out_storage)
if (i + len(in_storage)) not in dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx] in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx]
orphd = [[orphan.data] for orphan in self.orphans] orphd = [[orphan.data] for orphan in self.orphans]
ret = module.instantiate(error_storage, *(in_storage + out_storage + orphd)) ret = module.instantiate(error_storage, *(in_storage + out_storage +
orphd))
return ret return ret
...@@ -1150,6 +1258,7 @@ class CLinker(link.Linker): ...@@ -1150,6 +1258,7 @@ class CLinker(link.Linker):
print >> code, " return thunk; }" print >> code, " return thunk; }"
return code.getvalue() return code.getvalue()
class _CThunk(object): class _CThunk(object):
""" """
A thunk with a C implementation A thunk with a C implementation
...@@ -1181,7 +1290,8 @@ class _CThunk(object): ...@@ -1181,7 +1290,8 @@ class _CThunk(object):
n = len(self.init_tasks) n = len(self.init_tasks)
# note that the failure code is distributed in two lists # note that the failure code is distributed in two lists
if failure_code < 2 * n: if failure_code < 2 * n:
return [self.init_tasks, self.tasks][failure_code % 2][failure_code/2] return [self.init_tasks, self.tasks][
failure_code % 2][failure_code / 2]
else: else:
return self.tasks[failure_code - n] return self.tasks[failure_code - n]
...@@ -1199,19 +1309,16 @@ class _CThunk(object): ...@@ -1199,19 +1309,16 @@ class _CThunk(object):
exc_value = exc_type(_exc_value, task, task.outputs) exc_value = exc_type(_exc_value, task, task.outputs)
else: else:
exc_value = exc_type(_exc_value, task) exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared # this can be used to retrieve the location the Op was declared
exc_value.__thunk_trace__ = trace
except Exception: except Exception:
print >> sys.stderr, 'ERROR retrieving error_storage', self.error_storage print >> sys.stderr, 'ERROR retrieving error_storage',
print >> sys.stderr, self.error_storage
raise raise
raise exc_type, exc_value, exc_trace raise exc_type, exc_value, exc_trace
class OpWiseCLinker(link.LocalLinker): class OpWiseCLinker(link.LocalLinker):
"""WRITEME """WRITEME
Uses CLinker on the individual Ops that comprise an env and loops Uses CLinker on the individual Ops that comprise an env and loops
...@@ -1227,27 +1334,30 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1227,27 +1334,30 @@ class OpWiseCLinker(link.LocalLinker):
If a Variable is in no_recycling, CLinker will clear the output storage If a Variable is in no_recycling, CLinker will clear the output storage
associated to it prior to computation (to avoid reusing it). associated to it prior to computation (to avoid reusing it).
:note: This is in a sense the 'default' linker for Theano. The overhead of using the :note: This is in a sense the 'default' linker for Theano. The
OpWiseCLinker as compared with the CLinker is only noticeable for graphs of very small overhead of using the OpWiseCLinker as compared with the CLinker
tensors (such as 20 elements or less) is only noticeable for graphs of very small tensors (such as 20
elements or less)
""" """
__cache__ = {} __cache__ = {}
def __init__(self, def __init__(self,
fallback_on_perform = True, fallback_on_perform=True,
allow_gc = True, allow_gc=True,
nice_errors = True): nice_errors=True):
self.env = None self.env = None
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
self.nice_errors = nice_errors self.nice_errors = nice_errors
self.allow_gc = allow_gc self.allow_gc = allow_gc
def accept(self, env, no_recycling = []): def accept(self, env, no_recycling=[]):
if self.env is not None and self.env is not env: if self.env is not None and self.env is not env:
return type(self)(self.fallback_on_perform).accept(env, no_recycling) return type(self)(self.fallback_on_perform).accept(env,
#raise Exception("Cannot accept from a Linker that is already tied to another Env.") no_recycling)
#raise Exception("Cannot accept from a Linker that is
#already tied to another Env.")
self.env = env self.env = env
self.no_recycling = no_recycling self.no_recycling = no_recycling
return self return self
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论