提交 816a83e3 authored 作者: nouiz's avatar nouiz

Merge pull request #1271 from lamblin/arch_in_cache

Always put arch bitwidth in cache
...@@ -364,7 +364,7 @@ import theano and print the config variable, as in: ...@@ -364,7 +364,7 @@ import theano and print the config variable, as in:
.. attribute:: compiledir_format .. attribute:: compiledir_format
Default: "compiledir_%(platform)s-%(processor)s-%(python_version)s" Default: "compiledir_%(platform)s-%(processor)s-%(python_version)s-%(python_bitwidth)s"
This is a Python format string that specifies the subdirectory This is a Python format string that specifies the subdirectory
of ``config.base_compiledir`` in which to store platform-dependent of ``config.base_compiledir`` in which to store platform-dependent
......
...@@ -548,7 +548,7 @@ class Test_pfunc(unittest.TestCase): ...@@ -548,7 +548,7 @@ class Test_pfunc(unittest.TestCase):
def test_default_updates_input(self): def test_default_updates_input(self):
x = shared(0) x = shared(0)
y = shared(1) y = shared(1)
if theano.gof.cmodule.python_int_bitwidth() == 32: if theano.gof.python_int_bitwidth() == 32:
a = iscalar('a') a = iscalar('a')
else: else:
a = lscalar('a') a = lscalar('a')
......
...@@ -18,7 +18,7 @@ class Test_SharedVariable(unittest.TestCase): ...@@ -18,7 +18,7 @@ class Test_SharedVariable(unittest.TestCase):
assert shared(7, dtype='float64').type == Scalar('float64') assert shared(7, dtype='float64').type == Scalar('float64')
else: else:
if theano.gof.cmodule.python_int_bitwidth() == 32: if theano.gof.python_int_bitwidth() == 32:
assert shared(7).type == theano.tensor.iscalar, shared(7).type assert shared(7).type == theano.tensor.iscalar, shared(7).type
else: else:
assert shared(7).type == theano.tensor.lscalar, shared(7).type assert shared(7).type == theano.tensor.lscalar, shared(7).type
......
...@@ -38,7 +38,9 @@ e-mail thread "What is gof?" ...@@ -38,7 +38,9 @@ e-mail thread "What is gof?"
from theano.gof.cc import \ from theano.gof.cc import \
CLinker, OpWiseCLinker, DualLinker CLinker, OpWiseCLinker, DualLinker
import theano.gof.compiledir # adds config vars # Also adds config vars
from theano.gof.compiledir import \
local_bitwidth, python_int_bitwidth
from theano.gof.fg import \ from theano.gof.fg import \
InconsistencyError, MissingInputError, FunctionGraph InconsistencyError, MissingInputError, FunctionGraph
...@@ -77,4 +79,3 @@ from theano.gof.type import \ ...@@ -77,4 +79,3 @@ from theano.gof.type import \
from theano.gof.utils import \ from theano.gof.utils import \
object2, MethodNotDefined object2, MethodNotDefined
...@@ -8,7 +8,6 @@ import os ...@@ -8,7 +8,6 @@ import os
import shutil import shutil
import stat import stat
import StringIO import StringIO
import struct
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
...@@ -27,7 +26,7 @@ from theano.misc.windows import call_subprocess_Popen ...@@ -27,7 +26,7 @@ from theano.misc.windows import call_subprocess_Popen
# we will abuse the lockfile mechanism when reading and writing the registry # we will abuse the lockfile mechanism when reading and writing the registry
from theano.gof import compilelock from theano.gof import compilelock
from theano.gof.compiledir import gcc_version_str from theano.gof.compiledir import gcc_version_str, local_bitwidth
from theano.configparser import AddConfigVar, BoolParam from theano.configparser import AddConfigVar, BoolParam
...@@ -55,29 +54,6 @@ AddConfigVar('cmodule.compilation_warning', ...@@ -55,29 +54,6 @@ AddConfigVar('cmodule.compilation_warning',
BoolParam(False)) BoolParam(False))
def local_bitwidth():
"""
Return 32 for 32bit arch, 64 for 64bit arch
By "architecture", we mean the size of memory pointers (size_t in C),
*not* the size of long int, as it can be different.
"""
# Note that according to Python documentation, `platform.architecture()` is
# not reliable on OS X with universal binaries.
# Also, sys.maxsize does not exist in Python < 2.6.
# 'P' denotes a void*, and the size is expressed in bytes.
return struct.calcsize('P') * 8
def python_int_bitwidth():
"""
Return the bit width of Python int (C long int).
Note that it can be different from the size of a memory pointer.
"""
# 'l' denotes a C long int, and the size is expressed in bytes.
return struct.calcsize('l') * 8
_logger = logging.getLogger("theano.gof.cmodule") _logger = logging.getLogger("theano.gof.cmodule")
_logger.setLevel(logging.WARNING) _logger.setLevel(logging.WARNING)
...@@ -176,14 +152,14 @@ static struct PyModuleDef moduledef = {{ ...@@ -176,14 +152,14 @@ static struct PyModuleDef moduledef = {{
}}; }};
""".format(name=self.name) """.format(name=self.name)
print >> stream, "PyMODINIT_FUNC PyInit_%s(void) {" % self.name print >> stream, "PyMODINIT_FUNC PyInit_%s(void) {" % self.name
for b in self.init_blocks: for block in self.init_blocks:
print >> stream, ' ', b print >> stream, ' ', block
print >> stream, " PyObject *m = PyModule_Create(&moduledef);" print >> stream, " PyObject *m = PyModule_Create(&moduledef);"
print >> stream, " return m;" print >> stream, " return m;"
else: else:
print >> stream, "PyMODINIT_FUNC init%s(void){" % self.name print >> stream, "PyMODINIT_FUNC init%s(void){" % self.name
for b in self.init_blocks: for block in self.init_blocks:
print >> stream, ' ', b print >> stream, ' ', block
print >> stream, ' ', ('(void) Py_InitModule("%s", MyMethods);' print >> stream, ' ', ('(void) Py_InitModule("%s", MyMethods);'
% self.name) % self.name)
print >> stream, "}" print >> stream, "}"
...@@ -1564,7 +1540,8 @@ class GCC_compiler(object): ...@@ -1564,7 +1540,8 @@ class GCC_compiler(object):
lines = stdout + stderr lines = stdout + stderr
return lines return lines
# The '-' at the end is needed. Otherwise, g++ does not output enough information. # The '-' at the end is needed. Otherwise, g++ does not output
# enough information.
native_lines = get_lines("g++ -march=native -E -v -") native_lines = get_lines("g++ -march=native -E -v -")
_logger.info("g++ -march=native selected lines: %s", native_lines) _logger.info("g++ -march=native selected lines: %s", native_lines)
if len(native_lines) != 1: if len(native_lines) != 1:
...@@ -1619,6 +1596,39 @@ class GCC_compiler(object): ...@@ -1619,6 +1596,39 @@ class GCC_compiler(object):
cxxflags.append("-D NPY_ARRAY_UPDATE_ALL=NPY_UPDATE_ALL") cxxflags.append("-D NPY_ARRAY_UPDATE_ALL=NPY_UPDATE_ALL")
cxxflags.append("-D NPY_ARRAY_C_CONTIGUOUS=NPY_C_CONTIGUOUS") cxxflags.append("-D NPY_ARRAY_C_CONTIGUOUS=NPY_C_CONTIGUOUS")
cxxflags.append("-D NPY_ARRAY_F_CONTIGUOUS=NPY_F_CONTIGUOUS") cxxflags.append("-D NPY_ARRAY_F_CONTIGUOUS=NPY_F_CONTIGUOUS")
# Platform-specific flags.
# We put them here, rather than in compile_str(), so they end up
# in the key of the compiled module, avoiding potential conflicts.
# Figure out whether the current Python executable is 32
# or 64 bit and compile accordingly.
n_bits = local_bitwidth()
cxxflags.append('-m%d' % n_bits)
_logger.debug("Compiling for %s bit architecture", n_bits)
if sys.platform != 'win32':
# Under Windows it looks like fPIC is useless. Compiler warning:
# '-fPIC ignored for target (all code is position independent)'
cxxflags.append('-fPIC')
if sys.platform == 'win32' and local_bitwidth() == 64:
# Under 64-bit Windows installation, sys.platform is 'win32'.
# We need to define MS_WIN64 for the preprocessor to be able to
# link with libpython.
cxxflags.append('-DMS_WIN64')
#DSE Patch 1 for supporting OSX frameworks; add -framework Python
if sys.platform == 'darwin':
cxxflags.extend(['-undefined', 'dynamic_lookup'])
python_inc = distutils.sysconfig.get_python_inc()
# link with the framework library *if specifically requested*
# config.mac_framework_link is by default False, since on some mac
# installs linking with -framework causes a Bus Error
if (python_inc.count('Python.framework') > 0 and
config.cmodule.mac_framework_link):
cxxflags.extend(['-framework', 'Python'])
return cxxflags return cxxflags
@staticmethod @staticmethod
...@@ -1744,40 +1754,10 @@ class GCC_compiler(object): ...@@ -1744,40 +1754,10 @@ class GCC_compiler(object):
else: else:
preargs = list(preargs) preargs = list(preargs)
if sys.platform != 'win32':
# Under Windows it looks like fPIC is useless. Compiler warning:
# '-fPIC ignored for target (all code is position independent)'
preargs.append('-fPIC')
if sys.platform == 'win32' and local_bitwidth() == 64:
# Under 64-bit Windows installation, sys.platform is 'win32'.
# We need to define MS_WIN64 for the preprocessor to be able to
# link with libpython.
preargs.append('-DMS_WIN64')
# We also add "-m64", in case the installed gcc is 32-bit
preargs.append('-m64')
include_dirs = include_dirs + std_include_dirs() include_dirs = include_dirs + std_include_dirs()
libs = std_libs() + libs libs = std_libs() + libs
lib_dirs = std_lib_dirs() + lib_dirs lib_dirs = std_lib_dirs() + lib_dirs
#DSE Patch 1 for supporting OSX frameworks; add -framework Python
if sys.platform == 'darwin':
preargs.extend(['-undefined', 'dynamic_lookup'])
python_inc = distutils.sysconfig.get_python_inc()
# link with the framework library *if specifically requested*
# config.mac_framework_link is by default False, since on some mac
# installs linking with -framework causes a Bus Error
if (python_inc.count('Python.framework') > 0 and
config.cmodule.mac_framework_link):
preargs.extend(['-framework', 'Python'])
# Figure out whether the current Python executable is 32
# or 64 bit and compile accordingly.
n_bits = local_bitwidth()
preargs.extend(['-m%s' % n_bits])
_logger.debug("OS X: compiling for %s bit architecture", n_bits)
# sometimes, the linker cannot find -lpython so we need to tell it # sometimes, the linker cannot find -lpython so we need to tell it
# explicitly where it is located # explicitly where it is located
# this returns somepath/lib/python2.x # this returns somepath/lib/python2.x
......
...@@ -4,6 +4,7 @@ import os ...@@ -4,6 +4,7 @@ import os
import platform import platform
import re import re
import shutil import shutil
import struct
import subprocess import subprocess
import sys import sys
import textwrap import textwrap
...@@ -32,16 +33,44 @@ except OSError: ...@@ -32,16 +33,44 @@ except OSError:
del p del p
del dummy_err del dummy_err
compiledir_format_dict = {"platform": platform.platform(),
def local_bitwidth():
"""
Return 32 for 32bit arch, 64 for 64bit arch
By "architecture", we mean the size of memory pointers (size_t in C),
*not* the size of long int, as it can be different.
"""
# Note that according to Python documentation, `platform.architecture()` is
# not reliable on OS X with universal binaries.
# Also, sys.maxsize does not exist in Python < 2.6.
# 'P' denotes a void*, and the size is expressed in bytes.
return struct.calcsize('P') * 8
def python_int_bitwidth():
"""
Return the bit width of Python int (C long int).
Note that it can be different from the size of a memory pointer.
"""
# 'l' denotes a C long int, and the size is expressed in bytes.
return struct.calcsize('l') * 8
compiledir_format_dict = {
"platform": platform.platform(),
"processor": platform.processor(), "processor": platform.processor(),
"python_version": platform.python_version(), "python_version": platform.python_version(),
"python_bitwidth": local_bitwidth(),
"python_int_bitwidth": python_int_bitwidth(),
"theano_version": theano.__version__, "theano_version": theano.__version__,
"numpy_version": numpy.__version__, "numpy_version": numpy.__version__,
"gxx_version": gcc_version_str.replace(" ", "_"), "gxx_version": gcc_version_str.replace(" ", "_"),
} }
compiledir_format_keys = ", ".join(sorted(compiledir_format_dict.keys())) compiledir_format_keys = ", ".join(sorted(compiledir_format_dict.keys()))
default_compiledir_format =\ default_compiledir_format = ("compiledir_%(platform)s-%(processor)s-"
"compiledir_%(platform)s-%(processor)s-%(python_version)s" "%(python_version)s-%(python_bitwidth)s")
AddConfigVar("compiledir_format", AddConfigVar("compiledir_format",
textwrap.fill(textwrap.dedent("""\ textwrap.fill(textwrap.dedent("""\
......
import re
# import op import traceback
# import variable
from theano import config from theano import config
import re, traceback
def add_tag_trace(thing): def add_tag_trace(thing):
"""Add tag.trace to an node or variable. """Add tag.trace to an node or variable.
...@@ -11,15 +10,18 @@ def add_tag_trace(thing): ...@@ -11,15 +10,18 @@ def add_tag_trace(thing):
The argument is returned after being affected (inplace). The argument is returned after being affected (inplace).
""" """
limit = config.traceback.limit limit = config.traceback.limit
if limit == -1: limit = None if limit == -1:
limit = None
thing.tag.trace = traceback.extract_stack(limit=limit)[:-1] thing.tag.trace = traceback.extract_stack(limit=limit)[:-1]
return thing return thing
def hashgen(): def hashgen():
hashgen.next += 1 hashgen.next += 1
return hashgen.next return hashgen.next
hashgen.next = 0 hashgen.next = 0
class MethodNotDefined(Exception): class MethodNotDefined(Exception):
""" """
To be raised by functions defined as part of an interface. To be raised by functions defined as part of an interface.
...@@ -28,6 +30,7 @@ class MethodNotDefined(Exception): ...@@ -28,6 +30,7 @@ class MethodNotDefined(Exception):
function has been left out of an implementation class. function has been left out of an implementation class.
""" """
class object2(object): class object2(object):
__slots__ = [] __slots__ = []
if 0: if 0:
...@@ -36,23 +39,30 @@ class object2(object): ...@@ -36,23 +39,30 @@ class object2(object):
if hasattr(self, '__eq__') or hasattr(self, '__cmp__'): if hasattr(self, '__eq__') or hasattr(self, '__cmp__'):
raise TypeError("unhashable object: %s" % self) raise TypeError("unhashable object: %s" % self)
return id(self) return id(self)
def __ne__(self, other): def __ne__(self, other):
return not self == other return not self == other
class scratchpad: class scratchpad:
def clear(self): def clear(self):
self.__dict__.clear() self.__dict__.clear()
def __update__(self, other): def __update__(self, other):
self.__dict__.update(other.__dict__) self.__dict__.update(other.__dict__)
return self return self
def __str__(self): def __str__(self):
return "scratchpad" + str(self.__dict__) return "scratchpad" + str(self.__dict__)
def __repr__(self): def __repr__(self):
return "scratchpad" + str(self.__dict__) return "scratchpad" + str(self.__dict__)
def info(self): def info(self):
print "<theano.gof.utils.scratchpad instance at %i>"%id(self) print "<theano.gof.utils.scratchpad instance at %i>" % id(self)
for k,v in self.__dict__.items(): for k, v in self.__dict__.items():
print " %s: %s" % (k,v) print " %s: %s" % (k, v)
class D: class D:
def __init__(self, **d): def __init__(self, **d):
...@@ -63,6 +73,7 @@ def memoize(f): ...@@ -63,6 +73,7 @@ def memoize(f):
"""Cache the return value for each tuple of arguments """Cache the return value for each tuple of arguments
(which must be hashable) """ (which must be hashable) """
cache = {} cache = {}
def rval(*args, **kwargs): def rval(*args, **kwargs):
kwtup = tuple(kwargs.items()) kwtup = tuple(kwargs.items())
key = (args, kwtup) key = (args, kwtup)
...@@ -72,8 +83,8 @@ def memoize(f): ...@@ -72,8 +83,8 @@ def memoize(f):
else: else:
val = cache[key] val = cache[key]
return val return val
return rval
return rval
def deprecated(filename, msg=''): def deprecated(filename, msg=''):
...@@ -92,6 +103,7 @@ def deprecated(filename, msg=''): ...@@ -92,6 +103,7 @@ def deprecated(filename, msg=''):
""" """
def _deprecated(f): def _deprecated(f):
printme = [True] printme = [True]
def g(*args, **kwargs): def g(*args, **kwargs):
if printme[0]: if printme[0]:
print 'WARNING: %s.%s deprecated. %s'\ print 'WARNING: %s.%s deprecated. %s'\
...@@ -99,12 +111,16 @@ def deprecated(filename, msg=''): ...@@ -99,12 +111,16 @@ def deprecated(filename, msg=''):
printme[0] = False printme[0] = False
return f(*args, **kwargs) return f(*args, **kwargs)
return g return g
return _deprecated return _deprecated
def uniq(seq): def uniq(seq):
#TODO: consider building a set out of seq so that the if condition is constant time -JB #TODO: consider building a set out of seq so that the if condition
#is constant time -JB
return [x for i, x in enumerate(seq) if seq.index(x) == i] return [x for i, x in enumerate(seq) if seq.index(x) == i]
def difference(seq1, seq2): def difference(seq1, seq2):
""" """
Returns all elements in seq1 which are not in seq2: i.e seq1\seq2 Returns all elements in seq1 which are not in seq2: i.e seq1\seq2
...@@ -132,13 +148,16 @@ def partition(f, seq): ...@@ -132,13 +148,16 @@ def partition(f, seq):
seqf.append(elem) seqf.append(elem)
return seqt, seqf return seqt, seqf
def attr_checker(*attrs): def attr_checker(*attrs):
def f(candidate): def f(candidate):
for attr in attrs: for attr in attrs:
if not hasattr(candidate, attr): if not hasattr(candidate, attr):
return False return False
return True return True
f.__doc__ = "Checks that the candidate has the following attributes: %s" % ", ".join(["'%s'"%attr for attr in attrs])
f.__doc__ = ("Checks that the candidate has the following attributes: %s"
% ", ".join(["'%s'" % attr for attr in attrs]))
return f return f
...@@ -149,7 +168,6 @@ def all_bases(cls, accept): ...@@ -149,7 +168,6 @@ def all_bases(cls, accept):
return [cls for cls in rval if accept(cls)] return [cls for cls in rval if accept(cls)]
def all_bases_collect(cls, raw_name): def all_bases_collect(cls, raw_name):
rval = set() rval = set()
name = "__%s__" % raw_name name = "__%s__" % raw_name
...@@ -162,7 +180,7 @@ def all_bases_collect(cls, raw_name): ...@@ -162,7 +180,7 @@ def all_bases_collect(cls, raw_name):
return rval return rval
def camelcase_to_separated(string, sep = "_"): def camelcase_to_separated(string, sep="_"):
return re.sub('(.)([A-Z])', '\\1%s\\2' % sep, string).lower() return re.sub('(.)([A-Z])', '\\1%s\\2' % sep, string).lower()
...@@ -172,6 +190,7 @@ def to_return_values(values): ...@@ -172,6 +190,7 @@ def to_return_values(values):
else: else:
return values return values
def from_return_values(values): def from_return_values(values):
if isinstance(values, (list, tuple)): if isinstance(values, (list, tuple)):
return values return values
...@@ -186,7 +205,8 @@ class ClsInit(type): ...@@ -186,7 +205,8 @@ class ClsInit(type):
Validate and initialize the L{Op} subclass 'cls' Validate and initialize the L{Op} subclass 'cls'
This function: This function:
- changes class attributes input_names and output_names to be lists if they are single strings. - changes class attributes input_names and output_names to be lists
if they are single strings.
""" """
type.__init__(cls, name, bases, dct) type.__init__(cls, name, bases, dct)
...@@ -195,8 +215,10 @@ class ClsInit(type): ...@@ -195,8 +215,10 @@ class ClsInit(type):
def toposort(prereqs_d): def toposort(prereqs_d):
""" """
Sorts prereqs_d.keys() topologically. prereqs_d[x] contains all the elements Sorts prereqs_d.keys() topologically.
that must come before x in the ordering.
prereqs_d[x] contains all the elements that must come before x
in the ordering.
""" """
# all1 = set(prereqs_d.keys()) # all1 = set(prereqs_d.keys())
...@@ -223,19 +245,26 @@ def toposort(prereqs_d): ...@@ -223,19 +245,26 @@ def toposort(prereqs_d):
if not prereqs_d[postreq].difference(done): if not prereqs_d[postreq].difference(done):
next.add(postreq) next.add(postreq)
if len(prereqs_d) != len(seq): if len(prereqs_d) != len(seq):
raise Exception("Cannot sort topologically: there might be cycles, " + \ raise Exception("Cannot sort topologically: there might be cycles, "
"prereqs_d does not have a key for each element or " + \ "prereqs_d does not have a key for each element or "
"some orderings contain invalid elements.") "some orderings contain invalid elements.")
return seq return seq
def print_for_dot(self): def print_for_dot(self):
#TODO: popen2("dot -Tpng | display") and actually make the graph window pop up #TODO: popen2("dot -Tpng | display") and actually make the graph window
print "digraph unix { size = '6,6'; node [color = lightblue2; style = filled];" #pop up
print ("digraph unix { size = '6,6'; node [color = lightblue2;"
"style = filled];")
for op in self.order: for op in self.order:
for input in op.inputs: for input in op.inputs:
if input.owner: if input.owner:
print input.owner.__class__.__name__ + str(abs(id(input.owner))), " -> ", op.__class__.__name__ + str(abs(id(op))), ";" print ' '.join((
input.owner.__class__.__name__ + str(abs(id(input.owner))),
" -> ",
op.__class__.__name__ + str(abs(id(op))),
";"))
class Keyword: class Keyword:
...@@ -263,9 +292,11 @@ simple_types = (int, float, str, bool, None.__class__, Keyword) ...@@ -263,9 +292,11 @@ simple_types = (int, float, str, bool, None.__class__, Keyword)
ANY_TYPE = Keyword("ANY_TYPE") ANY_TYPE = Keyword("ANY_TYPE")
FALL_THROUGH = Keyword("FALL_THROUGH") FALL_THROUGH = Keyword("FALL_THROUGH")
def comm_guard(type1, type2): def comm_guard(type1, type2):
def wrap(f): def wrap(f):
old_f = f.func_globals[f.__name__] old_f = f.func_globals[f.__name__]
def new_f(arg1, arg2, *rest): def new_f(arg1, arg2, *rest):
if (type1 is ANY_TYPE or isinstance(arg1, type1)) \ if (type1 is ANY_TYPE or isinstance(arg1, type1)) \
and (type2 is ANY_TYPE or isinstance(arg2, type2)): and (type2 is ANY_TYPE or isinstance(arg2, type2)):
...@@ -283,6 +314,7 @@ def comm_guard(type1, type2): ...@@ -283,6 +314,7 @@ def comm_guard(type1, type2):
return variable return variable
new_f.__name__ = f.__name__ new_f.__name__ = f.__name__
def typename(type): def typename(type):
if isinstance(type, Keyword): if isinstance(type, Keyword):
return str(type) return str(type)
...@@ -290,14 +322,19 @@ def comm_guard(type1, type2): ...@@ -290,14 +322,19 @@ def comm_guard(type1, type2):
return "(" + ", ".join([x.__name__ for x in type]) + ")" return "(" + ", ".join([x.__name__ for x in type]) + ")"
else: else:
return type.__name__ return type.__name__
new_f.__doc__ = str(old_f.__doc__) + "\n" + ", ".join([typename(type) for type in (type1, type2)]) + "\n" + str(f.__doc__ or "")
new_f.__doc__ = (str(old_f.__doc__) + "\n" +
", ".join([typename(type) for type in (type1, type2)]) +
"\n" + str(f.__doc__ or ""))
return new_f return new_f
return wrap return wrap
def type_guard(type1): def type_guard(type1):
def wrap(f): def wrap(f):
old_f = f.func_globals[f.__name__] old_f = f.func_globals[f.__name__]
def new_f(arg1, *rest): def new_f(arg1, *rest):
if (type1 is ANY_TYPE or isinstance(arg1, type1)): if (type1 is ANY_TYPE or isinstance(arg1, type1)):
variable = f(arg1, *rest) variable = f(arg1, *rest)
...@@ -308,8 +345,8 @@ def type_guard(type1): ...@@ -308,8 +345,8 @@ def type_guard(type1):
else: else:
return old_f(arg1, *rest) return old_f(arg1, *rest)
new_f.__name__ = f.__name__ new_f.__name__ = f.__name__
def typename(type): def typename(type):
if isinstance(type, Keyword): if isinstance(type, Keyword):
return str(type) return str(type)
...@@ -317,8 +354,12 @@ def type_guard(type1): ...@@ -317,8 +354,12 @@ def type_guard(type1):
return "(" + ", ".join([x.__name__ for x in type]) + ")" return "(" + ", ".join([x.__name__ for x in type]) + ")"
else: else:
return type.__name__ return type.__name__
new_f.__doc__ = str(old_f.__doc__) + "\n" + ", ".join([typename(type) for type in (type1,)]) + "\n" + str(f.__doc__ or "")
new_f.__doc__ = (str(old_f.__doc__) + "\n" +
", ".join([typename(type) for type in (type1,)]) +
"\n" + str(f.__doc__ or ""))
return new_f return new_f
return wrap return wrap
...@@ -331,15 +372,18 @@ def flatten(a): ...@@ -331,15 +372,18 @@ def flatten(a):
else: else:
return [a] return [a]
def unique(x): def unique(x):
return len(set(x)) == len(x) return len(set(x)) == len(x)
def hist(coll): def hist(coll):
counts = {} counts = {}
for elem in coll: for elem in coll:
counts[elem] = counts.get(elem, 0) + 1 counts[elem] = counts.get(elem, 0) + 1
return counts return counts
def give_variables_names(variables): def give_variables_names(variables):
""" Gives unique names to an iterable of variables. Modifies input. """ Gives unique names to an iterable of variables. Modifies input.
...@@ -349,7 +393,7 @@ def give_variables_names(variables): ...@@ -349,7 +393,7 @@ def give_variables_names(variables):
bad_var = lambda var: not var.name or h[var.name] > 1 bad_var = lambda var: not var.name or h[var.name] > 1
for i, var in enumerate(filter(bad_var, variables)): for i, var in enumerate(filter(bad_var, variables)):
var.name = (var.name or "") + "_%d"%i var.name = (var.name or "") + "_%d" % i
if not unique(map(str, variables)): if not unique(map(str, variables)):
raise ValueError("Not all variables have unique names." raise ValueError("Not all variables have unique names."
......
...@@ -9,11 +9,11 @@ import warnings ...@@ -9,11 +9,11 @@ import warnings
import numpy import numpy
import theano from theano.gof import local_bitwidth
from theano.gof.cc import hash_from_file from theano.gof.cc import hash_from_file
from theano.gof.cmodule import (std_libs, std_lib_dirs, from theano.gof.cmodule import (std_libs, std_lib_dirs,
std_include_dirs, dlimport, std_include_dirs, dlimport,
get_lib_extension, local_bitwidth) get_lib_extension)
from theano.gof.python25 import any from theano.gof.python25 import any
from theano.misc.windows import call_subprocess_Popen from theano.misc.windows import call_subprocess_Popen
...@@ -245,8 +245,6 @@ class NVCC_compiler(object): ...@@ -245,8 +245,6 @@ class NVCC_compiler(object):
cppfile = file(cppfilename, 'w') cppfile = file(cppfilename, 'w')
_logger.debug('Writing module C++ code to %s', cppfilename) _logger.debug('Writing module C++ code to %s', cppfilename)
ofiles = []
rval = None
cppfile.write(src_code) cppfile.write(src_code)
cppfile.close() cppfile.close()
......
...@@ -114,12 +114,12 @@ class BinCountOp(theano.Op): ...@@ -114,12 +114,12 @@ class BinCountOp(theano.Op):
# Some dtypes are not supported by numpy's implementation of bincount. # Some dtypes are not supported by numpy's implementation of bincount.
# Until another one is available, we should fail at graph construction # Until another one is available, we should fail at graph construction
# time, not wait for execution. # time, not wait for execution.
int_bitwidth = theano.gof.cmodule.python_int_bitwidth() int_bitwidth = theano.gof.python_int_bitwidth()
if int_bitwidth == 64: if int_bitwidth == 64:
numpy_unsupported_dtypes = ('uint64',) numpy_unsupported_dtypes = ('uint64',)
if int_bitwidth == 32: if int_bitwidth == 32:
numpy_unsupported_dtypes = ('uint32', 'int64', 'uint64') numpy_unsupported_dtypes = ('uint32', 'int64', 'uint64')
intp_bitwidth = theano.gof.cmodule.local_bitwidth() intp_bitwidth = theano.gof.local_bitwidth()
if intp_bitwidth == 32: if intp_bitwidth == 32:
out_type = basic.ivector() out_type = basic.ivector()
elif intp_bitwidth == 64: elif intp_bitwidth == 64:
...@@ -246,7 +246,7 @@ class RepeatOp(theano.Op): ...@@ -246,7 +246,7 @@ class RepeatOp(theano.Op):
# Some dtypes are not supported by numpy's implementation of repeat. # Some dtypes are not supported by numpy's implementation of repeat.
# Until another one is available, we should fail at graph construction # Until another one is available, we should fail at graph construction
# time, not wait for execution. # time, not wait for execution.
int_bitwidth = theano.gof.cmodule.python_int_bitwidth() int_bitwidth = theano.gof.python_int_bitwidth()
if int_bitwidth == 64: if int_bitwidth == 64:
numpy_unsupported_dtypes = ('uint64',) numpy_unsupported_dtypes = ('uint64',)
if int_bitwidth == 32: if int_bitwidth == 32:
...@@ -259,7 +259,7 @@ class RepeatOp(theano.Op): ...@@ -259,7 +259,7 @@ class RepeatOp(theano.Op):
% numpy_unsupported_dtypes), repeats.dtype) % numpy_unsupported_dtypes), repeats.dtype)
if self.axis is None: if self.axis is None:
broadcastable=[False] broadcastable = [False]
else: else:
try: try:
const_reps = basic.get_scalar_constant_value(repeats) const_reps = basic.get_scalar_constant_value(repeats)
......
...@@ -13,6 +13,7 @@ from theano import config, tensor, function ...@@ -13,6 +13,7 @@ from theano import config, tensor, function
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]] numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
numpy_16 = bool(numpy_ver >= [1, 6]) numpy_16 = bool(numpy_ver >= [1, 6])
class TestBinCountOp(utt.InferShapeTester): class TestBinCountOp(utt.InferShapeTester):
def setUp(self): def setUp(self):
super(TestBinCountOp, self).setUp() super(TestBinCountOp, self).setUp()
...@@ -25,7 +26,7 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -25,7 +26,7 @@ class TestBinCountOp(utt.InferShapeTester):
'uint8', 'uint16', 'uint32', 'uint64'): 'uint8', 'uint16', 'uint32', 'uint64'):
# uint64 always fails # uint64 always fails
# int64 and uint32 also fail if python int are 32-bit # int64 and uint32 also fail if python int are 32-bit
int_bitwidth = theano.gof.cmodule.python_int_bitwidth() int_bitwidth = theano.gof.python_int_bitwidth()
if int_bitwidth == 64: if int_bitwidth == 64:
numpy_unsupported_dtypes = ('uint64',) numpy_unsupported_dtypes = ('uint64',)
if int_bitwidth == 32: if int_bitwidth == 32:
...@@ -57,7 +58,7 @@ class TestBinCountOp(utt.InferShapeTester): ...@@ -57,7 +58,7 @@ class TestBinCountOp(utt.InferShapeTester):
for dtype in tensor.discrete_dtypes: for dtype in tensor.discrete_dtypes:
# uint64 always fails # uint64 always fails
# int64 and uint32 also fail if python int are 32-bit # int64 and uint32 also fail if python int are 32-bit
int_bitwidth = theano.gof.cmodule.python_int_bitwidth() int_bitwidth = theano.gof.python_int_bitwidth()
if int_bitwidth == 64: if int_bitwidth == 64:
numpy_unsupported_dtypes = ('uint64',) numpy_unsupported_dtypes = ('uint64',)
if int_bitwidth == 32: if int_bitwidth == 32:
...@@ -188,7 +189,6 @@ class SqueezeTester(utt.InferShapeTester): ...@@ -188,7 +189,6 @@ class SqueezeTester(utt.InferShapeTester):
def test_grad(self): def test_grad(self):
for shape, broadcast in zip(self.shape_list, self.broadcast_list): for shape, broadcast in zip(self.shape_list, self.broadcast_list):
data = numpy.random.random(size=shape).astype(theano.config.floatX) data = numpy.random.random(size=shape).astype(theano.config.floatX)
variable = tensor.TensorType(theano.config.floatX, broadcast)()
utt.verify_grad(self.op, [data]) utt.verify_grad(self.op, [data])
...@@ -203,7 +203,7 @@ class TestRepeatOp(utt.InferShapeTester): ...@@ -203,7 +203,7 @@ class TestRepeatOp(utt.InferShapeTester):
self.op = RepeatOp() self.op = RepeatOp()
# uint64 always fails # uint64 always fails
# int64 and uint32 also fail if python int are 32-bit # int64 and uint32 also fail if python int are 32-bit
int_bitwidth = theano.gof.cmodule.python_int_bitwidth() int_bitwidth = theano.gof.python_int_bitwidth()
if int_bitwidth == 64: if int_bitwidth == 64:
self.numpy_unsupported_dtypes = ('uint64',) self.numpy_unsupported_dtypes = ('uint64',)
if int_bitwidth == 32: if int_bitwidth == 32:
...@@ -292,6 +292,7 @@ class TestRepeatOp(utt.InferShapeTester): ...@@ -292,6 +292,7 @@ class TestRepeatOp(utt.InferShapeTester):
r = RepeatOp(axis=0)(x, 2) r = RepeatOp(axis=0)(x, 2)
self.assertEqual(r.broadcastable, (False, True, False)) self.assertEqual(r.broadcastable, (False, True, False))
class TestBartlett(utt.InferShapeTester): class TestBartlett(utt.InferShapeTester):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论