提交 eb6569b5 authored 作者: bergstra@tikuanyin's avatar bergstra@tikuanyin

ModuleCache works without the pkl file now, more robust to various errors

上级 ada92aea
......@@ -147,3 +147,13 @@ def dot(l, r):
raise NotImplementedError("Dot failed for the following reaons:", (e0, e1))
return rval
###
# Set a default logger
#
import logging
logging_default_handler = logging.StreamHandler()
logging.getLogger("theano").addHandler(logging_default_handler)
logging.getLogger("theano").setLevel(logging.WARNING)
......@@ -81,6 +81,11 @@ class TanhRnn(Op):
in which z[0] = z0.
"""
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, z0, A):
"""
......@@ -121,7 +126,7 @@ class TanhRnnGrad(Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self, other):
def __hash__(self):
return hash(type(self))
def make_node(self, A, z, gz):
......
......@@ -26,10 +26,10 @@ import cmodule
import logging
_logger=logging.getLogger("theano.gof.cc")
def info(*args):
sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
#sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
_logger.info(' '.join(str(a) for a in args))
def debug(*args):
sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
#sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
_logger.debug(' '.join(str(a) for a in args))
def warning(*args):
sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n')
......@@ -367,6 +367,7 @@ class CLinker(link.Linker):
# The orphans field is listified to ensure a consistent order.
self.orphans = list(r for r in self.variables if isinstance(r, graph.Value) and r not in self.inputs) #list(env.orphans.difference(self.outputs))
self.temps = list(set(self.variables).difference(self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = []
self.node_order = env.toposort()
def code_gen(self):
......@@ -390,7 +391,7 @@ class CLinker(link.Linker):
env = self.env
consts = []
self.consts = []
symbol = {}
......@@ -428,7 +429,7 @@ class CLinker(link.Linker):
if isinstance(variable, graph.Constant):
try:
symbol[variable] = "(" + variable.type.c_literal(variable.data) + ")"
consts.append(variable)
self.consts.append(variable)
self.orphans.remove(variable)
continue
except (utils.MethodNotDefined, NotImplementedError):
......@@ -530,7 +531,12 @@ class CLinker(link.Linker):
self.tasks = tasks
all = self.inputs + self.outputs + self.orphans
assert (self.init_tasks, self.tasks) == self.get_init_tasks()
if (self.init_tasks, self.tasks) != self.get_init_tasks():
print >> sys.stderr, "init_tasks\n", self.init_tasks
print >> sys.stderr, self.get_init_tasks()[0]
print >> sys.stderr, "tasks\n", self.tasks
print >> sys.stderr, self.get_init_tasks()[1]
assert (self.init_tasks, self.tasks) == self.get_init_tasks()
# List of indices that should be ignored when passing the arguments
# (basically, everything that the previous call to uniq eliminated)
......@@ -646,6 +652,14 @@ class CLinker(link.Linker):
tasks = []
id=1
for v in self.variables:
if v in self.consts:
continue
if v in self.orphans and isinstance(v, graph.Constant):
try:
v.type.c_literal(v.data) #constant will be inlined, no need to get
continue
except (utils.MethodNotDefined, NotImplementedError):
pass
init_tasks.append((v, 'init', id))
tasks.append((v, 'get', id+1))
id += 2
......@@ -687,7 +701,7 @@ class CLinker(link.Linker):
The signature has the following form:
{{{
'CLinker.cmodule_key',
'CLinker.cmodule_key', compilation args, libraries,
op0, (input0.type, input1.type, input0 pos, input1 pos)
op1, (...)
...
......@@ -717,6 +731,9 @@ class CLinker(link.Linker):
env_computed_set = set()
op_pos = {} # Apply -> topological position
rval = ['CLinker.cmodule_key'] # will be cast to tuple on return
rval.append(tuple(self.compile_args()))
rval.append(tuple(self.libraries()))
version = []
# assert that every input to every node is one of'
# - an env input
......@@ -735,12 +752,19 @@ class CLinker(link.Linker):
return (op_pos[i.owner], i.owner.outputs.index(i))
for opos, o in enumerate(order):
version.append(o.op.c_code_cache_version())
for i in o.inputs:
version.append(i.type.c_code_cache_version())
for i in o.outputs:
version.append(i.type.c_code_cache_version())
rval.append((o.op, tuple((i.type, graphpos(i)) for i in o.inputs)))
op_pos[o] = opos
env_computed_set.update(o.outputs)
rval = tuple(rval)
return rval
for v in version:
if not v: #one of the ops or types here is unversioned
return ((), tuple(rval))
return tuple(version), tuple(rval)
def compile_cmodule(self, location=None):
"""
......
差异被折叠。
......@@ -162,6 +162,16 @@ class CLinkerOp(object):
raise utils.MethodNotDefined('%s.c_support_code' \
% self.__class__.__name__)
def c_code_cache_version(self):
"""Return a tuple of integers indicating the version of this Op.
An empty tuple indicates an 'unversioned' Op that will not be cached between processes.
The cache mechanism may erase cached modules that have been superceded by newer
versions. See `ModuleCache` for details.
"""
return (1,)
class PureOp(object):
"""
An :term:`Op` is a type of operation.
......
......@@ -57,6 +57,9 @@ class TDouble(Type):
free(%(name)s_bad_thing);
""" % locals()
def c_code_cache_version(self):
return ()
tdouble = TDouble()
def double(name):
......@@ -83,6 +86,8 @@ class MyOp(Op):
def perform(self, node, inputs, (out, )):
out[0] = self.impl(*inputs)
def c_code_cache_version(self):
return ()
class Unary(MyOp):
......
......@@ -210,6 +210,16 @@ class CLinkerType(object):
"""
raise MethodNotDefined("c_support_code", type(self), self.__class__.__name__)
def c_code_cache_version(self):
"""Return a tuple of integers indicating the version of this Op.
An empty tuple indicates an 'unversioned' Op that will not be cached between processes.
The cache mechanism may erase cached modules that have been superceded by newer
versions. See `ModuleCache` for details.
"""
return (1,)
class PureType(object):
"""Interface specification for variable type instances.
......
......@@ -444,6 +444,10 @@ class DenseFromSparse(gof.op.Op):
"""
sparse_grad = True
"""WRITEME"""
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x):
x = as_sparse_variable(x)
......@@ -495,6 +499,10 @@ csc_from_dense = SparseFromDense('csc')
class Transpose(gof.op.Op):
format_map = {'csr' : 'csc',
'csc' : 'csr'}
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x):
x = as_sparse_variable(x)
return gof.Apply(self,
......@@ -510,6 +518,10 @@ class Transpose(gof.op.Op):
transpose = Transpose()
class Neg(gof.op.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x):
x = as_sparse_variable(x)
return gof.Apply(self, [x], [x.type()])
......@@ -523,6 +535,10 @@ neg = Neg()
class AddSS(gof.op.Op):
'''Add two sparse matrices '''
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x, y = map(as_sparse_variable, [x, y])
if x.type.dtype != y.type.dtype:
......@@ -545,6 +561,10 @@ class AddSS(gof.op.Op):
add_s_s = AddSS()
class AddSD(gof.op.Op):
''' Add a sparse and a dense matrix '''
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x, y = as_sparse_variable(x), tensor.as_tensor_variable(y)
if x.type.dtype != y.type.dtype:
......@@ -586,6 +606,10 @@ def sub(x,y):
class MulSS(gof.op.Op):
''' Elementwise multiply a sparse and a ndarray '''
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x, y = as_sparse_variable(x), as_sparse_variable(y)
if x.type != y.type:
......@@ -605,6 +629,10 @@ class MulSS(gof.op.Op):
mul_s_s = MulSS()
class MulSD(gof.op.Op):
''' Elementwise multiply a sparse and a ndarray '''
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x, y = as_sparse_variable(x), tensor.as_tensor_variable(y)
if x.type.dtype != y.type.dtype:
......@@ -686,6 +714,10 @@ class StructuredDot(gof.Op):
The output is presumed to be a dense matrix, and is represented by a TensorType instance.
"""
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a, b):
if type(a) is not SparseVariable and type(a) is not SparseConstant:
raise TypeError('First argument must be of type SparseVariable or SparseConstant');
......@@ -750,6 +782,10 @@ def structured_dot(x, y):
return _structured_dot(y.T, x.T).T
class StructuredDotCSC(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a_val, a_ind, a_ptr, a_nrows, b):
dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype)
r = gof.Apply(self, [a_val, a_ind, a_ptr, a_nrows, b],
......@@ -900,6 +936,10 @@ class StructuredDotCSC(gof.Op):
sd_csc = StructuredDotCSC()
class StructuredDotCSR(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a_val, a_ind, a_ptr, b):
self.dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype)
r = gof.Apply(self, [a_val, a_ind, a_ptr, b],
......@@ -1055,6 +1095,10 @@ def structured_dot_grad(sparse_A, dense_B, ga):
class StructuredDotGradCSC(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a_indices, a_indptr, b, g_ab):
return gof.Apply(self, [a_indices, a_indptr, b, g_ab],
[tensor.tensor(g_ab.dtype, (False,))])
......@@ -1155,6 +1199,10 @@ sdg_csc = StructuredDotGradCSC()
class StructuredDotGradCSR(gof.Op):
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, a_indices, a_indptr, b, g_ab):
return gof.Apply(self, [a_indices, a_indptr, b, g_ab], [tensor.tensor(b.dtype, (False,))])
......@@ -1256,3 +1304,4 @@ class StructuredDotGradCSR(gof.Op):
"""% dict(locals(), **sub)
sdg_csr = StructuredDotGradCSR()
......@@ -49,6 +49,10 @@ class GemmRelated(Op):
This class provides a kind of templated gemm Op.
"""
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def c_support_code(self):
#return cblas_header_text()
mod_str = """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论