提交 aa048697 authored 作者: James Bergstra's avatar James Bergstra

merge

......@@ -199,20 +199,21 @@ class ModuleCache(object):
If the ``version`` is either 0 or (), then the corresponding module is unversioned, and
will be deleted in an atexit() handler.
If the ``version`` is neither 0 nor (), then the module will be kept in the cache between
processes, but it may be deleted if another key comes
along that has the same ``rest``, and a ``version`` that is considered higher than the
first one.
processes.
:todo: Versioning functionality is planned for implementation later, it is not implemented
yet.
An unversioned module is not deleted by the process that creates it. Deleting such modules
does not work on NFS filesystems because the tmpdir in which the library resides is in use
until the end of the process' lifetime. Instead, unversioned modules are left in their
tmpdirs without corresponding .pkl files. These modules and their directories are erased
by subsequent processes' refresh() functions.
"""
dirname = ""
"""The working directory that is managed by this interface"""
module_from_name = {}
"""maps module names to loaded module objects"""
"""maps a module filename to the loaded module object"""
entry_from_key = {}
"""Maps keys to the filename of a .so/.pyd.
......@@ -272,7 +273,17 @@ class ModuleCache(object):
for root, dirs, files in os.walk(self.dirname):
if os.path.join(root, 'key.pkl') in self.loaded_key_pkl:
continue
if 'key.pkl' in files:
elif 'delete.me' in files or len(files)==0:
# On NFS filesystems, it is impossible to delete a directory with open
# files in it. So instead, some commands in this file will respond to a
# failed rmtree() by touching a 'delete.me' file. This file is a message
# for a future process to try deleting the directory.
try:
shutil.rmtree(root)
except:
# the directory is still in use?? We just leave it for future removal.
pass
elif 'key.pkl' in files:
key_pkl = os.path.join(root, 'key.pkl')
debug('refresh adding', key_pkl)
try:
......@@ -318,14 +329,19 @@ class ModuleCache(object):
# If so, it should not have been deleted. This should be considered a
# failure of the OTHER process, that deleted it.
if entry in self.module_from_name:
error("The module %s that was loaded by this ModuleCache can no longer be read from file... this could lead to problems." % name)
error("The module %s that was loaded by this ModuleCache can no longer be read from file %s ... this could lead to problems." % (key,entry))
del self.module_from_name[entry]
info("deleting ModuleCache entry", entry)
del self.entry_from_key[key]
if key[0]:
#this is a versioned entry, so should have been on disk
self.loaded_key_pkl.remove(os.path.join(os.path.dirname(entry), 'key.pkl'))
# this is a versioned entry, so should have been on disk
# Something weird happened to cause this, so we are responding by
# printing a warning, removing evidence that we ever saw this mystery
# key.
pkl_file_to_remove = os.path.join(os.path.dirname(entry), 'key.pkl')
warn('Removing key file %s because the corresponding module is gone from the file system.' % pkl_file_to_remove)
self.loaded_key_pkl.remove(pkl_file_to_remove)
finally:
compilelock.release_lock()
......@@ -432,8 +448,8 @@ class ModuleCache(object):
del self.entry_from_key[key]
parent = os.path.dirname(entry)
assert parent.startswith(os.path.join(self.dirname, 'tmp'))
debug("Removing cache dir", parent)
shutil.rmtree(parent)
info("clear_old removing cache dir", parent)
_rmtree(parent)
finally:
compilelock.release_lock()
......@@ -462,13 +478,24 @@ class ModuleCache(object):
parent = os.path.dirname(entry)
assert parent.startswith(os.path.join(self.dirname, 'tmp'))
debug("Removing unversioned dir", parent)
shutil.rmtree(parent)
info("clear_unversioned removing cache dir", parent)
_rmtree(parent)
def _on_atexit(self):
self.refresh()
self.clear_old()
self.clear_unversioned()
def _rmtree(parent):
    """Best-effort recursive removal of a compile-cache directory.

    On NFS filesystems, ``shutil.rmtree`` can fail while files inside the
    directory are still open (e.g. a loaded .so).  When removal fails, an
    empty ``delete.me`` marker file is created inside the directory so that
    a future process's ``refresh()`` pass (which looks for ``delete.me``)
    will retry the deletion.
    """
    try:
        shutil.rmtree(parent)
    except Exception, e:
        try:
            # mark this directory for deletion by a future refresh()
            open(os.path.join(parent,'delete.me'), 'w').close()
        except Exception, ee:
            # Both removal and marking failed; log and move on (best-effort).
            warning('Failed to remove or mark cache directory %s for removal' % parent, ee)
_module_cache = None
def get_module_cache(dirname, force_fresh=None):
global _module_cache
......
......@@ -342,9 +342,14 @@ class Generic(SingletonType):
PyObject* %(name)s;
""" % locals()
def c_init(self, name, sub):
return """
%(name)s = NULL;
""" % locals()
def c_extract(self, name, sub):
return """
Py_XINCREF(py_%(name)s);
Py_INCREF(py_%(name)s);
%(name)s = py_%(name)s;
""" % locals()
......@@ -355,9 +360,10 @@ class Generic(SingletonType):
def c_sync(self, name, sub):
return """
Py_XDECREF(py_%(name)s);
py_%(name)s = %(name)s;
Py_XINCREF(py_%(name)s);
assert(py_%(name)s->ob_refcnt > 1);
Py_DECREF(py_%(name)s);
py_%(name)s = %(name)s ? %(name)s : Py_None;
Py_INCREF(py_%(name)s);
""" % locals()
......
......@@ -1814,6 +1814,12 @@ class Split(Op):
if axis.type not in int_types:
raise TypeError('axis must have type lscalar', axis.type)
# # The following lines are necessary if we allow splits of zero
# if isinstance(axis, gof.Constant):
# x = unbroadcast(x, int(axis.data))
# else:
# x = unbroadcast(x, *range(x.type.ndim))
inputs = [x, axis, splits]
outputs = [x.type() for i in xrange(self.len_splits)]
......@@ -1830,6 +1836,11 @@ class Split(Op):
if len(splits) != self.len_splits:
raise ValueError('In Split.perform(), len(splits) != len_splits.',
(len(splits), self.len_splits))
if numpy.sum(splits) != len_along_axis:
raise ValueError('The splits sum to %s, expected %s' % (numpy.sum(splits), len_along_axis))
if not all(splits):
raise ValueError('Cannot have a split of zero.')
# Checking is done, let's roll the splitting algorithm!
# Basically we step along the given axis of x, extracting subtensors of size splits[i]
......@@ -1847,6 +1858,47 @@ class Split(Op):
"""Join the gradients along the axis that was used to split x."""
return [join(axis, *g_outputs), None, None]
class Rebroadcast(Op):
    """
    Change the input's broadcastable fields in
    some predetermined way.

    e.g.: Rebroadcast((0, True), (1, False))(x)
    would make x broadcastable in axis 0
    and not broadcastable in axis 1

    See also the unbroadcast function.
    """
    # The output aliases the input: perform() stores x itself in out[0].
    view_map = {0: [0]}
    def __init__(self, *axis):
        # `axis` is a sequence of (axis_index, broadcastable_flag) pairs,
        # stored as a dict mapping axis index -> desired broadcastable flag.
        self.axis = dict(axis)
    def make_node(self, x):
        # Output keeps x's dtype; each axis's broadcastable flag is taken
        # from self.axis when present, otherwise copied from the input type.
        t = TensorType(dtype = x.type.dtype,
                       broadcastable = [self.axis.get(i, b)
                                        for i, b in enumerate(x.type.broadcastable)])
        return Apply(self, [x], [t()])
    def perform(self, node, (x, ), (out, )):
        # A dimension declared broadcastable must actually have length 1;
        # otherwise the declared type would lie about the runtime value.
        for axis, value in self.axis.iteritems():
            if value and x.shape[axis] != 1:
                raise ValueError('Dimension %s in Rebroadcast\'s input was supposed to be 1 (got %s instead)' % (axis, x.shape[axis]))
        out[0] = x
    def grad(self, (x, ), (gz,)):
        # restore the broadcasting pattern of the input
        return Rebroadcast(*[(axis, x.type.broadcastable[axis]) for axis, value in self.axis.iteritems()])(gz),
def addbroadcast(x, *axes):
    """Return `x` retyped as broadcastable along each axis in `axes`.

    Thin wrapper around `Rebroadcast` that sets the broadcastable flag
    to True for every requested axis.
    """
    axis_flags = []
    for a in axes:
        axis_flags.append((a, True))
    return Rebroadcast(*axis_flags)(x)
def unbroadcast(x, *axes):
    """Return `x` retyped as NOT broadcastable along each axis in `axes`.

    Thin wrapper around `Rebroadcast` that clears the broadcastable flag
    for every requested axis.
    """
    axis_flags = [(a, False) for a in axes]
    rebroadcast_op = Rebroadcast(*axis_flags)
    return rebroadcast_op(x)
class Join(Op):
"""
Concatenate multiple `TensorVariable`s along some axis.
......@@ -1909,6 +1961,9 @@ class Join(Op):
bcastable[axis] = False
except IndexError, e:
raise ValueError('Join argument "axis" is out of range (given input dimensions)')
as_tensor_variable_args = [unbroadcast(x, axis) for x in as_tensor_variable_args]
else:
as_tensor_variable_args = [unbroadcast(x, *range(x.type.ndim)) for x in as_tensor_variable_args]
inputs = [as_tensor_variable(axis)] + as_tensor_variable_args
if inputs[0].type not in int_types:
......
......@@ -67,6 +67,9 @@ class DimShuffle(Op):
DimShuffle((False, False, False), [2, 0, 1]) -> AxBxC to CxAxB
DimShuffle((False, False), [0, 'x', 1]) -> AxB to Ax1xB
DimShuffle((False, False), [1, 'x', 0]) -> AxB to Bx1xA
The reordering of the dimensions can be done in numpy with the transpose function.
Adding, subtracting dimensions can be done with reshape.
"""
def __init__(self, input_broadcastable, new_order, inplace = False):
......
......@@ -1944,6 +1944,20 @@ def test_convert_to_complex():
b = value(numpy.ones(3, dtype='complex128'))
f = function([a],basic.convert_to_complex128(a))
assert a.type.values_eq_approx(b.data, f(a.data))
for t in ['int8','int16','int32','int64','float32']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex64'))
f = function([a],basic.convert_to_complex64(a))
assert a.type.values_eq_approx(b.data, f(a.data))
#this work, but should we allow it? How well it is implemented?
for t in ['float64']:
a = value(numpy.ones(3, dtype=t))
b = value(numpy.ones(3, dtype='complex64'))
f = function([a],basic.convert_to_complex64(a))
assert a.type.values_eq_approx(b.data, f(a.data))
def test_bug_complext_10_august_09():
v0 = dmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论