提交 bfef7ff0 authored 作者: James Bergstra's avatar James Bergstra

merge

...@@ -352,17 +352,17 @@ If a list of ``Variable`` or ``Out`` instances is given as argument, then the co ...@@ -352,17 +352,17 @@ If a list of ``Variable`` or ``Out`` instances is given as argument, then the co
x, y, s = T.matrices('xys') x, y, s = T.matrices('xys')
# print a list of 2 ndarrays # print a list of 2 ndarrays
fn1 = theano.function([x], [x+x, Out((x+x).T, borrow=True]) fn1 = theano.function([x], [x+x, Out((x+x).T, borrow=True)])
print fn1(ndarray([[1,0],[0,1]])) print fn1(numpy.asarray([[1,0],[0,1]]))
# print a list of 1 ndarray # print a list of 1 ndarray
fn2 = theano.function([x], [x+x]) fn2 = theano.function([x], [x+x])
print fn2(ndarray([[1,0],[0,1]])) print fn2(numpy.asarray([[1,0],[0,1]]))
# print an ndarray # print an ndarray
fn3 = theano.function([x], outputs=x+x) fn3 = theano.function([x], outputs=x+x)
print fn3(ndarray([[1,0],[0,1]])) print fn3(numpy.asarray([[1,0],[0,1]]))
.. _function_mode: .. _function_mode:
......
...@@ -275,6 +275,12 @@ class T_function(unittest.TestCase): ...@@ -275,6 +275,12 @@ class T_function(unittest.TestCase):
self.failUnless(dec[s] == -1) self.failUnless(dec[s] == -1)
def test_borrow_output(self):
a = T.dmatrix()
f = function([a], Out(a, borrow=False))
o = N.ones((3,3))
assert o is not f(o)
class T_picklefunction(unittest.TestCase): class T_picklefunction(unittest.TestCase):
def test_deepcopy(self): def test_deepcopy(self):
......
...@@ -7,7 +7,7 @@ import numpy.distutils #TODO: TensorType should handle this ...@@ -7,7 +7,7 @@ import numpy.distutils #TODO: TensorType should handle this
import compilelock # we will abuse the lockfile mechanism when reading and writing the registry import compilelock # we will abuse the lockfile mechanism when reading and writing the registry
_logger=logging.getLogger("theano.gof.cmodule") _logger=logging.getLogger("theano.gof.cmodule")
_logger.setLevel(logging.INFO) _logger.setLevel(logging.WARN)
def error(*args): def error(*args):
#sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n') #sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n')
...@@ -124,7 +124,7 @@ class DynamicModule(object): ...@@ -124,7 +124,7 @@ class DynamicModule(object):
#TODO: add_type #TODO: add_type
def dlimport(fullpath, suffix=None): def dlimport(fullpath, suffix=None):
"""Dynamically load a .so, .dll, or .py file """Dynamically load a .so, .pyd, .dll, or .py file
:type fullpath: string :type fullpath: string
:param fullpath: a fully-qualified path do a compiled python module :param fullpath: a fully-qualified path do a compiled python module
...@@ -137,6 +137,8 @@ def dlimport(fullpath, suffix=None): ...@@ -137,6 +137,8 @@ def dlimport(fullpath, suffix=None):
if suffix is None: if suffix is None:
if fullpath.endswith('.so'): if fullpath.endswith('.so'):
suffix = '.so' suffix = '.so'
elif fullpath.endswith('.pyd'):
suffix = '.pyd'
elif fullpath.endswith('.dll'): elif fullpath.endswith('.dll'):
suffix = '.dll' suffix = '.dll'
elif fullpath.endswith('.py'): elif fullpath.endswith('.py'):
...@@ -172,9 +174,9 @@ def module_name_from_dir(dirname): ...@@ -172,9 +174,9 @@ def module_name_from_dir(dirname):
"""Scan the contents of a cache directory and return full path of the dynamic lib in it. """Scan the contents of a cache directory and return full path of the dynamic lib in it.
""" """
files = os.listdir(dirname) files = os.listdir(dirname)
names = [file for file in files if file.endswith('.so')] names = [file for file in files if file.endswith('.so') or file.endswith('.pyd')]
if len(names) != 1: if len(names) != 1:
raise Exception('Failed to load .so from dir', dirname) raise Exception('Failed to load dynamic libraries from dir', dirname)
return os.path.join(dirname, names[0]) return os.path.join(dirname, names[0])
class ModuleCache(object): class ModuleCache(object):
...@@ -210,7 +212,7 @@ class ModuleCache(object): ...@@ -210,7 +212,7 @@ class ModuleCache(object):
"""maps module names to loaded module objects""" """maps module names to loaded module objects"""
entry_from_key = {} entry_from_key = {}
"""Maps keys to the filename of a .so """Maps keys to the filename of a .so/.pyd.
""" """
stats = [] stats = []
...@@ -353,7 +355,13 @@ class ModuleCache(object): ...@@ -353,7 +355,13 @@ class ModuleCache(object):
key_pkl = os.path.join(location, 'key.pkl') key_pkl = os.path.join(location, 'key.pkl')
key_file = file(key_pkl, 'w') key_file = file(key_pkl, 'w')
try: try:
cPickle.dump(key, key_file, cPickle.HIGHEST_PROTOCOL) if sys.platform == 'win32':
# Looks like there is a bug under Windows, where using the
# highest protocol will result in a pickle file that cannot
# be loaded afterwards.
cPickle.dump(key, key_file)
else:
cPickle.dump(key, key_file, cPickle.HIGHEST_PROTOCOL)
key_file.close() key_file.close()
key_broken = False key_broken = False
except cPickle.PicklingError: except cPickle.PicklingError:
...@@ -449,6 +457,13 @@ def get_module_cache(dirname, force_fresh=None): ...@@ -449,6 +457,13 @@ def get_module_cache(dirname, force_fresh=None):
atexit.register(_module_cache._on_atexit) atexit.register(_module_cache._on_atexit)
return _module_cache return _module_cache
def get_lib_extension():
"""Return the platform-dependent extension for compiled modules."""
if sys.platform == 'win32':
return 'pyd'
else:
return 'so'
def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[], def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[],
preargs=[], tmpdir=None): preargs=[], tmpdir=None):
#TODO: don't to the dlimport in this function #TODO: don't to the dlimport in this function
...@@ -456,11 +471,21 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -456,11 +471,21 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
preargs.append('-fPIC') preargs.append('-fPIC')
no_opt = False no_opt = False
include_dirs = [distutils.sysconfig.get_python_inc()] + \ include_dirs = [distutils.sysconfig.get_python_inc()] + \
numpy.distutils.misc_util.get_numpy_include_dirs()\ numpy.distutils.misc_util.get_numpy_include_dirs()\
+ include_dirs + include_dirs
libs = [os.path.split(distutils.sysconfig.get_python_inc())[-1]] + libs python_inc = distutils.sysconfig.get_python_inc()
if sys.platform == 'win32':
# Typical include directory: C:\Python26\include
libname = os.path.basename(os.path.dirname(python_inc)).lower()
# Also add directory containing the Python library to the library
# directories.
python_lib_dir = os.path.join(os.path.dirname(python_inc), 'libs')
lib_dirs = [python_lib_dir] + lib_dirs
else:
# Typical include directory: /usr/include/python2.6
libname = os.path.basename(python_inc)
libs = [libname] + libs
workdir = location workdir = location
...@@ -474,7 +499,8 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -474,7 +499,8 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
cppfile.write(src_code) cppfile.write(src_code)
cppfile.close() cppfile.close()
lib_filename = os.path.join(workdir, '%s.so'% module_name) lib_filename = os.path.join(workdir, '%s.%s' %
(module_name, get_lib_extension()))
debug('Generating shared lib', lib_filename) debug('Generating shared lib', lib_filename)
cmd = ['g++', '-shared', '-g'] cmd = ['g++', '-shared', '-g']
...@@ -499,7 +525,6 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -499,7 +525,6 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
file(os.path.join(workdir, "__init__.py"),'w').close() file(os.path.join(workdir, "__init__.py"),'w').close()
rval = dlimport(lib_filename) rval = dlimport(lib_filename)
return rval return rval
...@@ -531,7 +556,8 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[ ...@@ -531,7 +556,8 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[
try: try:
cppfile.write(src_code) cppfile.write(src_code)
cppfile.close() cppfile.close()
lib_filename = os.path.join(workdir, '%s.so'% module_name) lib_filename = os.path.join(workdir, '%s.%s' %
(module_name, get_lib_extension()))
debug('Generating shared lib', lib_filename) debug('Generating shared lib', lib_filename)
cmd = ['nvcc', '-shared', '-g'] cmd = ['nvcc', '-shared', '-g']
......
...@@ -13,3 +13,21 @@ def test_bug_2009_06_02_trac_387(): ...@@ -13,3 +13,21 @@ def test_bug_2009_06_02_trac_387():
#z = tensor.lscalar('z') #z = tensor.lscalar('z')
#f = theano.function([z], tensor.DimShuffle([], ['x'])(z) / 2) #f = theano.function([z], tensor.DimShuffle([], ['x'])(z) / 2)
def test_bug_2009_07_17_borrowed_output():
"""Regression test for a bug where output was borrowed by mistake."""
a = theano.tensor.dmatrix()
b = theano.tensor.dmatrix()
# The output should *NOT* be borrowed.
g = theano.function([a, b],
theano.Out(theano.tensor.dot(a, b), borrow=False))
x = numpy.zeros((1, 2))
y = numpy.ones((2, 5))
z = g(x, y)
print z # Should be zero.
x.fill(1)
print g(x, y) # Should be non-zero.
print z # Should still be zero.
assert numpy.linalg.norm(z) == 0
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论