提交 eb6569b5 authored 作者: bergstra@tikuanyin's avatar bergstra@tikuanyin

ModuleCache works without the pkl file now, more robust to various errors

上级 ada92aea
...@@ -147,3 +147,13 @@ def dot(l, r): ...@@ -147,3 +147,13 @@ def dot(l, r):
raise NotImplementedError("Dot failed for the following reaons:", (e0, e1)) raise NotImplementedError("Dot failed for the following reaons:", (e0, e1))
return rval return rval
###
# Set a default logger
#
import logging
logging_default_handler = logging.StreamHandler()
logging.getLogger("theano").addHandler(logging_default_handler)
logging.getLogger("theano").setLevel(logging.WARNING)
...@@ -81,6 +81,11 @@ class TanhRnn(Op): ...@@ -81,6 +81,11 @@ class TanhRnn(Op):
in which z[0] = z0. in which z[0] = z0.
""" """
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, z0, A): def make_node(self, x, z0, A):
""" """
...@@ -121,7 +126,7 @@ class TanhRnnGrad(Op): ...@@ -121,7 +126,7 @@ class TanhRnnGrad(Op):
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
def __hash__(self, other): def __hash__(self):
return hash(type(self)) return hash(type(self))
def make_node(self, A, z, gz): def make_node(self, A, z, gz):
......
...@@ -26,10 +26,10 @@ import cmodule ...@@ -26,10 +26,10 @@ import cmodule
import logging import logging
_logger=logging.getLogger("theano.gof.cc") _logger=logging.getLogger("theano.gof.cc")
def info(*args): def info(*args):
sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n') #sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
_logger.info(' '.join(str(a) for a in args)) _logger.info(' '.join(str(a) for a in args))
def debug(*args): def debug(*args):
sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n') #sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
_logger.debug(' '.join(str(a) for a in args)) _logger.debug(' '.join(str(a) for a in args))
def warning(*args): def warning(*args):
sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n') sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n')
...@@ -367,6 +367,7 @@ class CLinker(link.Linker): ...@@ -367,6 +367,7 @@ class CLinker(link.Linker):
# The orphans field is listified to ensure a consistent order. # The orphans field is listified to ensure a consistent order.
self.orphans = list(r for r in self.variables if isinstance(r, graph.Value) and r not in self.inputs) #list(env.orphans.difference(self.outputs)) self.orphans = list(r for r in self.variables if isinstance(r, graph.Value) and r not in self.inputs) #list(env.orphans.difference(self.outputs))
self.temps = list(set(self.variables).difference(self.inputs).difference(self.outputs).difference(self.orphans)) self.temps = list(set(self.variables).difference(self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = []
self.node_order = env.toposort() self.node_order = env.toposort()
def code_gen(self): def code_gen(self):
...@@ -390,7 +391,7 @@ class CLinker(link.Linker): ...@@ -390,7 +391,7 @@ class CLinker(link.Linker):
env = self.env env = self.env
consts = [] self.consts = []
symbol = {} symbol = {}
...@@ -428,7 +429,7 @@ class CLinker(link.Linker): ...@@ -428,7 +429,7 @@ class CLinker(link.Linker):
if isinstance(variable, graph.Constant): if isinstance(variable, graph.Constant):
try: try:
symbol[variable] = "(" + variable.type.c_literal(variable.data) + ")" symbol[variable] = "(" + variable.type.c_literal(variable.data) + ")"
consts.append(variable) self.consts.append(variable)
self.orphans.remove(variable) self.orphans.remove(variable)
continue continue
except (utils.MethodNotDefined, NotImplementedError): except (utils.MethodNotDefined, NotImplementedError):
...@@ -530,7 +531,12 @@ class CLinker(link.Linker): ...@@ -530,7 +531,12 @@ class CLinker(link.Linker):
self.tasks = tasks self.tasks = tasks
all = self.inputs + self.outputs + self.orphans all = self.inputs + self.outputs + self.orphans
assert (self.init_tasks, self.tasks) == self.get_init_tasks() if (self.init_tasks, self.tasks) != self.get_init_tasks():
print >> sys.stderr, "init_tasks\n", self.init_tasks
print >> sys.stderr, self.get_init_tasks()[0]
print >> sys.stderr, "tasks\n", self.tasks
print >> sys.stderr, self.get_init_tasks()[1]
assert (self.init_tasks, self.tasks) == self.get_init_tasks()
# List of indices that should be ignored when passing the arguments # List of indices that should be ignored when passing the arguments
# (basically, everything that the previous call to uniq eliminated) # (basically, everything that the previous call to uniq eliminated)
...@@ -646,6 +652,14 @@ class CLinker(link.Linker): ...@@ -646,6 +652,14 @@ class CLinker(link.Linker):
tasks = [] tasks = []
id=1 id=1
for v in self.variables: for v in self.variables:
if v in self.consts:
continue
if v in self.orphans and isinstance(v, graph.Constant):
try:
v.type.c_literal(v.data) #constant will be inlined, no need to get
continue
except (utils.MethodNotDefined, NotImplementedError):
pass
init_tasks.append((v, 'init', id)) init_tasks.append((v, 'init', id))
tasks.append((v, 'get', id+1)) tasks.append((v, 'get', id+1))
id += 2 id += 2
...@@ -687,7 +701,7 @@ class CLinker(link.Linker): ...@@ -687,7 +701,7 @@ class CLinker(link.Linker):
The signature has the following form: The signature has the following form:
{{{ {{{
'CLinker.cmodule_key', 'CLinker.cmodule_key', compilation args, libraries,
op0, (input0.type, input1.type, input0 pos, input1 pos) op0, (input0.type, input1.type, input0 pos, input1 pos)
op1, (...) op1, (...)
... ...
...@@ -717,6 +731,9 @@ class CLinker(link.Linker): ...@@ -717,6 +731,9 @@ class CLinker(link.Linker):
env_computed_set = set() env_computed_set = set()
op_pos = {} # Apply -> topological position op_pos = {} # Apply -> topological position
rval = ['CLinker.cmodule_key'] # will be cast to tuple on return rval = ['CLinker.cmodule_key'] # will be cast to tuple on return
rval.append(tuple(self.compile_args()))
rval.append(tuple(self.libraries()))
version = []
# assert that every input to every node is one of' # assert that every input to every node is one of'
# - an env input # - an env input
...@@ -735,12 +752,19 @@ class CLinker(link.Linker): ...@@ -735,12 +752,19 @@ class CLinker(link.Linker):
return (op_pos[i.owner], i.owner.outputs.index(i)) return (op_pos[i.owner], i.owner.outputs.index(i))
for opos, o in enumerate(order): for opos, o in enumerate(order):
version.append(o.op.c_code_cache_version())
for i in o.inputs:
version.append(i.type.c_code_cache_version())
for i in o.outputs:
version.append(i.type.c_code_cache_version())
rval.append((o.op, tuple((i.type, graphpos(i)) for i in o.inputs))) rval.append((o.op, tuple((i.type, graphpos(i)) for i in o.inputs)))
op_pos[o] = opos op_pos[o] = opos
env_computed_set.update(o.outputs) env_computed_set.update(o.outputs)
rval = tuple(rval) for v in version:
return rval if not v: #one of the ops or types here is unversioned
return ((), tuple(rval))
return tuple(version), tuple(rval)
def compile_cmodule(self, location=None): def compile_cmodule(self, location=None):
""" """
......
差异被折叠。
...@@ -162,6 +162,16 @@ class CLinkerOp(object): ...@@ -162,6 +162,16 @@ class CLinkerOp(object):
raise utils.MethodNotDefined('%s.c_support_code' \ raise utils.MethodNotDefined('%s.c_support_code' \
% self.__class__.__name__) % self.__class__.__name__)
def c_code_cache_version(self):
"""Return a tuple of integers indicating the version of this Op.
An empty tuple indicates an 'unversioned' Op that will not be cached between processes.
The cache mechanism may erase cached modules that have been superseded by newer
versions. See `ModuleCache` for details.
"""
return (1,)
class PureOp(object): class PureOp(object):
""" """
An :term:`Op` is a type of operation. An :term:`Op` is a type of operation.
......
...@@ -57,6 +57,9 @@ class TDouble(Type): ...@@ -57,6 +57,9 @@ class TDouble(Type):
free(%(name)s_bad_thing); free(%(name)s_bad_thing);
""" % locals() """ % locals()
def c_code_cache_version(self):
return ()
tdouble = TDouble() tdouble = TDouble()
def double(name): def double(name):
...@@ -83,6 +86,8 @@ class MyOp(Op): ...@@ -83,6 +86,8 @@ class MyOp(Op):
def perform(self, node, inputs, (out, )): def perform(self, node, inputs, (out, )):
out[0] = self.impl(*inputs) out[0] = self.impl(*inputs)
def c_code_cache_version(self):
return ()
class Unary(MyOp): class Unary(MyOp):
......
...@@ -210,6 +210,16 @@ class CLinkerType(object): ...@@ -210,6 +210,16 @@ class CLinkerType(object):
""" """
raise MethodNotDefined("c_support_code", type(self), self.__class__.__name__) raise MethodNotDefined("c_support_code", type(self), self.__class__.__name__)
def c_code_cache_version(self):
"""Return a tuple of integers indicating the version of this Type.
An empty tuple indicates an 'unversioned' Type that will not be cached between processes.
The cache mechanism may erase cached modules that have been superseded by newer
versions. See `ModuleCache` for details.
"""
return (1,)
class PureType(object): class PureType(object):
"""Interface specification for variable type instances. """Interface specification for variable type instances.
......
...@@ -444,6 +444,10 @@ class DenseFromSparse(gof.op.Op): ...@@ -444,6 +444,10 @@ class DenseFromSparse(gof.op.Op):
""" """
sparse_grad = True sparse_grad = True
"""WRITEME""" """WRITEME"""
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x): def make_node(self, x):
x = as_sparse_variable(x) x = as_sparse_variable(x)
...@@ -495,6 +499,10 @@ csc_from_dense = SparseFromDense('csc') ...@@ -495,6 +499,10 @@ csc_from_dense = SparseFromDense('csc')
class Transpose(gof.op.Op): class Transpose(gof.op.Op):
format_map = {'csr' : 'csc', format_map = {'csr' : 'csc',
'csc' : 'csr'} 'csc' : 'csr'}
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x): def make_node(self, x):
x = as_sparse_variable(x) x = as_sparse_variable(x)
return gof.Apply(self, return gof.Apply(self,
...@@ -510,6 +518,10 @@ class Transpose(gof.op.Op): ...@@ -510,6 +518,10 @@ class Transpose(gof.op.Op):
transpose = Transpose() transpose = Transpose()
class Neg(gof.op.Op): class Neg(gof.op.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x): def make_node(self, x):
x = as_sparse_variable(x) x = as_sparse_variable(x)
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
...@@ -523,6 +535,10 @@ neg = Neg() ...@@ -523,6 +535,10 @@ neg = Neg()
class AddSS(gof.op.Op): class AddSS(gof.op.Op):
'''Add two sparse matrices ''' '''Add two sparse matrices '''
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y): def make_node(self, x, y):
x, y = map(as_sparse_variable, [x, y]) x, y = map(as_sparse_variable, [x, y])
if x.type.dtype != y.type.dtype: if x.type.dtype != y.type.dtype:
...@@ -545,6 +561,10 @@ class AddSS(gof.op.Op): ...@@ -545,6 +561,10 @@ class AddSS(gof.op.Op):
add_s_s = AddSS() add_s_s = AddSS()
class AddSD(gof.op.Op): class AddSD(gof.op.Op):
''' Add a sparse and a dense matrix ''' ''' Add a sparse and a dense matrix '''
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y): def make_node(self, x, y):
x, y = as_sparse_variable(x), tensor.as_tensor_variable(y) x, y = as_sparse_variable(x), tensor.as_tensor_variable(y)
if x.type.dtype != y.type.dtype: if x.type.dtype != y.type.dtype:
...@@ -586,6 +606,10 @@ def sub(x,y): ...@@ -586,6 +606,10 @@ def sub(x,y):
class MulSS(gof.op.Op): class MulSS(gof.op.Op):
''' Elementwise multiply a sparse and a ndarray ''' ''' Elementwise multiply a sparse and a ndarray '''
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y): def make_node(self, x, y):
x, y = as_sparse_variable(x), as_sparse_variable(y) x, y = as_sparse_variable(x), as_sparse_variable(y)
if x.type != y.type: if x.type != y.type:
...@@ -605,6 +629,10 @@ class MulSS(gof.op.Op): ...@@ -605,6 +629,10 @@ class MulSS(gof.op.Op):
mul_s_s = MulSS() mul_s_s = MulSS()
class MulSD(gof.op.Op): class MulSD(gof.op.Op):
''' Elementwise multiply a sparse and a ndarray ''' ''' Elementwise multiply a sparse and a ndarray '''
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y): def make_node(self, x, y):
x, y = as_sparse_variable(x), tensor.as_tensor_variable(y) x, y = as_sparse_variable(x), tensor.as_tensor_variable(y)
if x.type.dtype != y.type.dtype: if x.type.dtype != y.type.dtype:
...@@ -686,6 +714,10 @@ class StructuredDot(gof.Op): ...@@ -686,6 +714,10 @@ class StructuredDot(gof.Op):
The output is presumed to be a dense matrix, and is represented by a TensorType instance. The output is presumed to be a dense matrix, and is represented by a TensorType instance.
""" """
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a, b): def make_node(self, a, b):
if type(a) is not SparseVariable and type(a) is not SparseConstant: if type(a) is not SparseVariable and type(a) is not SparseConstant:
raise TypeError('First argument must be of type SparseVariable or SparseConstant'); raise TypeError('First argument must be of type SparseVariable or SparseConstant');
...@@ -750,6 +782,10 @@ def structured_dot(x, y): ...@@ -750,6 +782,10 @@ def structured_dot(x, y):
return _structured_dot(y.T, x.T).T return _structured_dot(y.T, x.T).T
class StructuredDotCSC(gof.Op): class StructuredDotCSC(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a_val, a_ind, a_ptr, a_nrows, b): def make_node(self, a_val, a_ind, a_ptr, a_nrows, b):
dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype) dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype)
r = gof.Apply(self, [a_val, a_ind, a_ptr, a_nrows, b], r = gof.Apply(self, [a_val, a_ind, a_ptr, a_nrows, b],
...@@ -900,6 +936,10 @@ class StructuredDotCSC(gof.Op): ...@@ -900,6 +936,10 @@ class StructuredDotCSC(gof.Op):
sd_csc = StructuredDotCSC() sd_csc = StructuredDotCSC()
class StructuredDotCSR(gof.Op): class StructuredDotCSR(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a_val, a_ind, a_ptr, b): def make_node(self, a_val, a_ind, a_ptr, b):
self.dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype) self.dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype)
r = gof.Apply(self, [a_val, a_ind, a_ptr, b], r = gof.Apply(self, [a_val, a_ind, a_ptr, b],
...@@ -1055,6 +1095,10 @@ def structured_dot_grad(sparse_A, dense_B, ga): ...@@ -1055,6 +1095,10 @@ def structured_dot_grad(sparse_A, dense_B, ga):
class StructuredDotGradCSC(gof.Op): class StructuredDotGradCSC(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a_indices, a_indptr, b, g_ab): def make_node(self, a_indices, a_indptr, b, g_ab):
return gof.Apply(self, [a_indices, a_indptr, b, g_ab], return gof.Apply(self, [a_indices, a_indptr, b, g_ab],
[tensor.tensor(g_ab.dtype, (False,))]) [tensor.tensor(g_ab.dtype, (False,))])
...@@ -1155,6 +1199,10 @@ sdg_csc = StructuredDotGradCSC() ...@@ -1155,6 +1199,10 @@ sdg_csc = StructuredDotGradCSC()
class StructuredDotGradCSR(gof.Op): class StructuredDotGradCSR(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a_indices, a_indptr, b, g_ab): def make_node(self, a_indices, a_indptr, b, g_ab):
return gof.Apply(self, [a_indices, a_indptr, b, g_ab], [tensor.tensor(b.dtype, (False,))]) return gof.Apply(self, [a_indices, a_indptr, b, g_ab], [tensor.tensor(b.dtype, (False,))])
...@@ -1256,3 +1304,4 @@ class StructuredDotGradCSR(gof.Op): ...@@ -1256,3 +1304,4 @@ class StructuredDotGradCSR(gof.Op):
"""% dict(locals(), **sub) """% dict(locals(), **sub)
sdg_csr = StructuredDotGradCSR() sdg_csr = StructuredDotGradCSR()
...@@ -49,6 +49,10 @@ class GemmRelated(Op): ...@@ -49,6 +49,10 @@ class GemmRelated(Op):
This class provides a kind of templated gemm Op. This class provides a kind of templated gemm Op.
""" """
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def c_support_code(self): def c_support_code(self):
#return cblas_header_text() #return cblas_header_text()
mod_str = """ mod_str = """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论