提交 eb6569b5 authored 作者: bergstra@tikuanyin's avatar bergstra@tikuanyin

ModuleCache works without the pkl file now, more robust to various errors

上级 ada92aea
......@@ -147,3 +147,13 @@ def dot(l, r):
raise NotImplementedError("Dot failed for the following reaons:", (e0, e1))
return rval
###
# Set a default logger
#
import logging
logging_default_handler = logging.StreamHandler()
logging.getLogger("theano").addHandler(logging_default_handler)
logging.getLogger("theano").setLevel(logging.WARNING)
......@@ -81,6 +81,11 @@ class TanhRnn(Op):
in which z[0] = z0.
"""
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, x, z0, A):
"""
......@@ -121,7 +126,7 @@ class TanhRnnGrad(Op):
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # __hash__ takes no argument besides self -- the earlier
    # `def __hash__(self, other):` signature was invalid and would raise
    # a TypeError whenever Python hashed the Op.
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, A, z, gz):
......
......@@ -26,10 +26,10 @@ import cmodule
import logging
_logger=logging.getLogger("theano.gof.cc")
def info(*args):
sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
#sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
_logger.info(' '.join(str(a) for a in args))
def debug(*args):
sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
#sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
_logger.debug(' '.join(str(a) for a in args))
def warning(*args):
sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n')
......@@ -367,6 +367,7 @@ class CLinker(link.Linker):
# The orphans field is listified to ensure a consistent order.
self.orphans = list(r for r in self.variables if isinstance(r, graph.Value) and r not in self.inputs) #list(env.orphans.difference(self.outputs))
self.temps = list(set(self.variables).difference(self.inputs).difference(self.outputs).difference(self.orphans))
self.consts = []
self.node_order = env.toposort()
def code_gen(self):
......@@ -390,7 +391,7 @@ class CLinker(link.Linker):
env = self.env
consts = []
self.consts = []
symbol = {}
......@@ -428,7 +429,7 @@ class CLinker(link.Linker):
if isinstance(variable, graph.Constant):
try:
symbol[variable] = "(" + variable.type.c_literal(variable.data) + ")"
consts.append(variable)
self.consts.append(variable)
self.orphans.remove(variable)
continue
except (utils.MethodNotDefined, NotImplementedError):
......@@ -530,7 +531,12 @@ class CLinker(link.Linker):
self.tasks = tasks
all = self.inputs + self.outputs + self.orphans
assert (self.init_tasks, self.tasks) == self.get_init_tasks()
if (self.init_tasks, self.tasks) != self.get_init_tasks():
print >> sys.stderr, "init_tasks\n", self.init_tasks
print >> sys.stderr, self.get_init_tasks()[0]
print >> sys.stderr, "tasks\n", self.tasks
print >> sys.stderr, self.get_init_tasks()[1]
assert (self.init_tasks, self.tasks) == self.get_init_tasks()
# List of indices that should be ignored when passing the arguments
# (basically, everything that the previous call to uniq eliminated)
......@@ -646,6 +652,14 @@ class CLinker(link.Linker):
tasks = []
id=1
for v in self.variables:
if v in self.consts:
continue
if v in self.orphans and isinstance(v, graph.Constant):
try:
v.type.c_literal(v.data) #constant will be inlined, no need to get
continue
except (utils.MethodNotDefined, NotImplementedError):
pass
init_tasks.append((v, 'init', id))
tasks.append((v, 'get', id+1))
id += 2
......@@ -687,7 +701,7 @@ class CLinker(link.Linker):
The signature has the following form:
{{{
'CLinker.cmodule_key',
'CLinker.cmodule_key', compilation args, libraries,
op0, (input0.type, input1.type, input0 pos, input1 pos)
op1, (...)
...
......@@ -717,6 +731,9 @@ class CLinker(link.Linker):
env_computed_set = set()
op_pos = {} # Apply -> topological position
rval = ['CLinker.cmodule_key'] # will be cast to tuple on return
rval.append(tuple(self.compile_args()))
rval.append(tuple(self.libraries()))
version = []
# assert that every input to every node is one of'
# - an env input
......@@ -735,12 +752,19 @@ class CLinker(link.Linker):
return (op_pos[i.owner], i.owner.outputs.index(i))
for opos, o in enumerate(order):
version.append(o.op.c_code_cache_version())
for i in o.inputs:
version.append(i.type.c_code_cache_version())
for i in o.outputs:
version.append(i.type.c_code_cache_version())
rval.append((o.op, tuple((i.type, graphpos(i)) for i in o.inputs)))
op_pos[o] = opos
env_computed_set.update(o.outputs)
rval = tuple(rval)
return rval
for v in version:
if not v: #one of the ops or types here is unversioned
return ((), tuple(rval))
return tuple(version), tuple(rval)
def compile_cmodule(self, location=None):
"""
......
"""Generate and compile C modules for Python
"""Generate and compile C modules for Python,
"""
import os, tempfile, StringIO, sys, logging, subprocess, cPickle, atexit
import os, tempfile, StringIO, sys, logging, subprocess, cPickle, atexit, time, shutil, stat
import compilelock # we will abuse the lockfile mechanism when reading and writing the registry
_logger=logging.getLogger("theano.gof.cmodule")
_logger.setLevel(logging.INFO)
def error(*args):
#sys.stderr.write('ERROR:'+ ' '.join(str(a) for a in args)+'\n')
_logger.error("ERROR: "+' '.join(str(a) for a in args))
def warning(*args):
#sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n')
_logger.warning(' '.join(str(a) for a in args))
_logger.warning("WARNING: "+' '.join(str(a) for a in args))
def info(*args):
#sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
_logger.info(' '.join(str(a) for a in args))
_logger.info("INFO: "+' '.join(str(a) for a in args))
def debug(*args):
#sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
_logger.debug(' '.join(str(a) for a in args))
_logger.debug("DEBUG: "+' '.join(str(a) for a in args))
METH_VARARGS="METH_VARARGS"
METH_NOARGS="METH_NOARGS"
......@@ -156,35 +162,158 @@ def dlimport(fullpath, suffix=None):
assert fullpath.startswith(rval.__file__)
return rval
def last_access_time(path):
    """Seconds since the epoch at which *path* was last accessed.

    Thin wrapper around ``os.stat``; the cache-aging logic compares this
    against a time threshold.
    """
    stat_result = os.stat(path)
    return stat_result[stat.ST_ATIME]
def module_name_from_dir(dirname):
    """Scan the contents of a cache directory and return full path of the dynamic lib in it.

    :param dirname: a cache subdirectory expected to contain exactly one '.so' file.
    :raises Exception: if `dirname` does not contain exactly one '.so' file.
    """
    # use 'entry' as the loop name: the original shadowed the builtin `file`
    names = [entry for entry in os.listdir(dirname) if entry.endswith('.so')]
    if len(names) != 1:
        raise Exception('Failed to load .so from dir', dirname)
    return os.path.join(dirname, names[0])
class ModuleCache(object):
def __init__(self, dirname, force_fresh=False):
"""Interface to the cache of dynamically compiled modules on disk
Note that this interface does not assume exclusive use of the cache directory.
It is built to handle the case where multiple programs are also using instances of this
class to manage the same directory.
The cache works on the basis of keys. Keys are used to uniquely identify a dynamic module.
Keys should be tuples of length 2: (version, rest)
The ``rest`` can be anything hashable and picklable, that uniquely identifies the
computation in the module.
The ``version`` should be a hierarchy of tuples of integers.
If the ``version`` is either 0 or (), then the corresponding module is unversioned, and
will be deleted in an atexit() handler.
If the ``version`` is neither 0 nor (), then the module will be kept in the cache between
processes, but it may be deleted if another key comes
along that has the same ``rest``, and a ``version`` that is considered higher than the
first one.
:todo: Versioning functionality is planned for implementation later, it is not implemented
yet.
"""
dirname = ""
"""The working directory that is managed by this interface"""
module_from_name = {}
"""maps module names to loaded module objects"""
entry_from_key = {}
"""Maps keys to the filename of a .so
"""
stats = []
"""A list with counters for the number of hits, loads, compiles issued by module_from_key()
"""
force_fresh = False
"""True -> Ignore previously-compiled modules
"""
loaded_key_pkl = set()
"""set of all key.pkl files that have been loaded.
"""
def __init__(self, dirname, force_fresh=None, check_for_broken_eq=True):
"""
:param check_for_broken_eq: A bad __eq__ implemenation can break this cache mechanism.
This option turns on a not-too-expensive sanity check during the load of an old cache
file.
"""
self.dirname = dirname
self.module_from_name = {}
self.name_from_key_filename = os.path.join(self.dirname, 'module_cache.pkl')
self.name_from_key = {}
self.module_from_name = dict(self.module_from_name)
self.entry_from_key = dict(self.entry_from_key)
self.stats = [0, 0, 0]
if not force_fresh:
try:
f = file(self.name_from_key_filename, 'r')
self.name_from_key = cPickle.load(f)
debug('ModuleCache loaded', len(self.name_from_key))
f.close()
except (IOError, EOFError):
debug('cache load failed. Using fresh cache')
pass
def persist(self):
f = file(self.name_from_key_filename, 'w')
cPickle.dump(self.name_from_key, f)
f.close()
self.force_fresh = self.force_fresh if force_fresh is None else force_fresh
self.loaded_key_pkl = set()
self.refresh()
if check_for_broken_eq:
for k0 in self.entry_from_key:
for k1 in self.entry_from_key:
if k0 == k1 and not (k0 is k1):
warning(("The __eq__ and __hash__ functions are broken for some element"
" in the following two keys. The cache mechanism will say that"
" graphs like this need recompiling, when they could have been"
" retrieved):"))
warning("Key 0:", k0)
warning("Key 1:", k1)
def refresh(self):
    """Update self.entry_from_key by walking the cache directory structure.

    Add entries that are not in the entry_from_key dictionary.

    Remove entries whose files have been removed from the filesystem.
    Broken 'key.pkl' files found on disk are erased.
    """
    compilelock.get_lock()
    try:
        # add entries that are not in the entry_from_key dictionary
        for root, dirs, files in os.walk(self.dirname):
            # skip directories whose key.pkl was already loaded
            if os.path.join(root, 'key.pkl') in self.loaded_key_pkl:
                continue
            if 'key.pkl' in files:
                key_pkl = os.path.join(root, 'key.pkl')
                debug('refresh adding', key_pkl)
                try:
                    key = cPickle.load(file(key_pkl))
                except:
                    # an unreadable key would be retried on every refresh:
                    # erase it instead
                    error("ModuleCache.refresh() Failed to unpickle cache key", key_pkl)
                    info("Erasing broken file", key_pkl)
                    os.remove(key_pkl)
                    continue
                if key not in self.entry_from_key:
                    entry = module_name_from_dir(root)
                    self.entry_from_key[key] = entry
                    # assert that we haven't already got this entry somehow
                    assert entry not in self.module_from_name
                self.loaded_key_pkl.add(key_pkl)

        # remove entries that are not in the filesystem
        items_copy = list(self.entry_from_key.iteritems())
        for key, entry in items_copy:
            try:
                # test to see that the file is [present and] readable
                open(entry).close()
                gone = False
            except IOError:
                gone = True
            if gone:
                # assert that we didn't have one of the deleted files
                # loaded up and in use.
                # If so, it should not have been deleted. This should be considered a
                # failure of the OTHER process, that deleted it.
                if entry in self.module_from_name:
                    # BUG FIX: the message formatted an undefined name `name`
                    # (NameError at runtime); `entry` is the variable in scope.
                    error("The module %s that was loaded by this ModuleCache can no longer be read from file... this could lead to problems." % entry)
                    del self.module_from_name[entry]
                info("deleting ModuleCache entry", entry)
                del self.entry_from_key[key]
                self.loaded_key_pkl.remove(os.path.join(os.path.dirname(entry), 'key.pkl'))
    finally:
        compilelock.release_lock()
def module_from_key(self, key, fn=None):
rval = None
if key in self.name_from_key:
try:
_version, _rest = key
except:
raise ValueError("Invalid key. key must have form (version, rest)", key)
if key in self.entry_from_key:
# we have seen this key either in this process or previously
#debug('OLD KEY HASH', hash(key), hash(key[1][0]), key[1][0])
name = self.name_from_key[key]
name = self.entry_from_key[key]
if name not in self.module_from_name:
#debug('loading name', name)
......@@ -199,49 +328,115 @@ class ModuleCache(object):
#debug("LOCATION*", location)
try:
module = fn(location=location) # WILL FAIL FOR BAD C CODE
finally:
# >>TODO: erase location
pass
debug('NEW KEY HASH', hash(key), hash(key[1][0]), key[1][0])
for k,n in self.name_from_key.iteritems():
if k == key:
debug("HASH OF RELOAD IS DIFFERENT", hash(k), hash(key))
print ''
print hash(k[0])
print hash(key[0])
print ''
print "OLD",
print hash(k[1][0])
print k[1][0].rehash()
print ""
print "NEW", hash(key[1][0]), key[1][0].rehash()
print ''
print hash(k[1][1])
print hash(key[1][1])
assert k != key
except Exception, e:
shutil.rmtree(location)
#try:
#except Exception, ee:
#error('failed to cleanup location', location, ee)
raise
name = module.__file__
#debug("LOCATION**", location)
#debug("NAME**", name)
assert name.startswith(location)
debug("Adding module to cache", key, name)
assert name.startswith(location)
assert name not in self.module_from_name
assert key not in self.name_from_key
self.name_from_key[key] = name
assert key not in self.entry_from_key
key_pkl = os.path.join(location, 'key.pkl')
key_file = file(key_pkl, 'w')
try:
cPickle.dump(key, key_file, cPickle.HIGHEST_PROTOCOL)
key_file.close()
key_broken = False
except cPickle.PicklingError:
key_file.close()
os.remove(key_pkl)
warning("Cache leak due to unpickle-able key", key)
key_broken = True
if _version and not key_broken:
key_from_file = cPickle.load(file(key_pkl))
if key != key_from_file:
raise Exception("key not equal to unpickled version (Hint: verify the __eq__ and __hash__ functions for your Ops", (key, key_from_file))
self.entry_from_key[key] = name
self.module_from_name[name] = module
self.loaded_key_pkl.add(key_pkl)
self.stats[2] += 1
rval = module
#debug('stats', self.stats, sum(self.stats))
return rval
age_thresh = 60*60*24*31
"""The default age threshold for `clear_old` (in seconds)
"""

def clear_old(self, age_thresh=None): #default to a 31-day age_threshold
    """
    Delete entries from the filesystem for cache entries that are too old.

    :param age_thresh: dynamic modules whose last access time is more than ``age_thresh``
    seconds ago will be erased.  Defaults to ``self.age_thresh`` (31 days).
    """
    age_thresh = self.age_thresh if age_thresh is None else age_thresh
    compilelock.get_lock()
    try:
        # update the age of modules that have been accessed by other processes
        self.refresh()
        time_now = time.time()

        # the .items() is important here:
        # we need to get a copy of the whole list of keys and entries
        items_copy = list(self.entry_from_key.iteritems())
        for key, entry in items_copy:
            age = time_now - last_access_time(entry)
            if age > age_thresh:
                # TODO: we are assuming that modules that haven't been accessed in over
                # age_thresh are not currently in use by other processes, but that could be
                # false for long-running jobs...
                # An entry that old should never be loaded in this process either.
                assert entry not in self.module_from_name
                del self.entry_from_key[key]
                parent = os.path.dirname(entry)
                # safety: only ever delete directories under <dirname>/tmp
                assert parent.startswith(os.path.join(self.dirname, 'tmp'))
                debug("Removing cache dir", parent)
                shutil.rmtree(parent)
    finally:
        compilelock.release_lock()
def clear(self):
    """
    Clear all the elements of the cache
    """
    # A negative age threshold makes every entry "too old", so clear_old
    # removes everything.
    return self.clear_old(-1.0)
def clear_unversioned(self):
    """Delete unversioned dynamic modules from the internal dictionaries and from the
    filesystem.

    Unversioned modules are those whose key has a falsy ``version`` component.

    NOTE(review): unlike `clear_old`, this does not acquire the compile lock --
    confirm that is safe when multiple processes share the cache directory.
    """
    # iterate over a copy: we delete from entry_from_key inside the loop
    items_copy = list(self.entry_from_key.iteritems())
    for key, entry in items_copy:
        version, rest = key
        if not version:
            del self.entry_from_key[key]

            # entry is guaranteed to be in this dictionary,
            # because an unversioned entry should never have been loaded via refresh
            assert entry in self.module_from_name

            del self.module_from_name[entry]

            parent = os.path.dirname(entry)
            # safety: only ever delete directories under <dirname>/tmp
            assert parent.startswith(os.path.join(self.dirname, 'tmp'))
            debug("Removing unversioned dir", parent)
            shutil.rmtree(parent)
def _on_atexit(self):
    # Registered with atexit by get_module_cache(): sync with the filesystem,
    # drop entries past the age threshold, then delete this process's
    # unversioned modules.
    self.refresh()
    self.clear_old()
    self.clear_unversioned()
_module_cache = None
def get_module_cache(dirname, force_fresh=None):
    """Return the process-wide ModuleCache rooted at `dirname`, creating it on first use.

    The first call fixes the cache directory; `dirname` is ignored by later calls.
    An atexit handler (`ModuleCache._on_atexit`) is registered so the cache is
    aged and its unversioned modules removed at interpreter exit.

    :param force_fresh: forwarded to `ModuleCache` on first creation.
    """
    global _module_cache
    if _module_cache is None:
        _module_cache = ModuleCache(dirname, force_fresh=force_fresh)
        atexit.register(_module_cache._on_atexit)
    return _module_cache
def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[],
......@@ -263,41 +458,36 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
debug('Writing module C++ code to', cppfilename)
ofiles = []
rval = None
try:
cppfile.write(src_code)
cppfile.close()
lib_filename = os.path.join(workdir, '%s.so'% module_name)
debug('Generating shared lib', lib_filename)
cmd = ['g++', '-shared', '-g']
if no_opt:
cmd.extend(p for p in preargs if not p.startswith('-O'))
else:
cmd.extend(preargs)
cmd.extend('-I%s'%idir for idir in include_dirs)
cmd.extend(['-o',lib_filename])
cmd.append(cppfilename)
cmd.extend(['-L%s'%ldir for ldir in lib_dirs])
cmd.extend(['-l%s'%l for l in libs])
debug('Running cmd', ' '.join(cmd))
cppfile.write(src_code)
cppfile.close()
p = subprocess.Popen(cmd)
status = p.wait()
lib_filename = os.path.join(workdir, '%s.so'% module_name)
if status:
error('g++ return status', status)
else:
#touch the __init__ file
file(os.path.join(workdir, "__init__.py"),'w').close()
debug('Generating shared lib', lib_filename)
cmd = ['g++', '-shared', '-g']
if no_opt:
cmd.extend(p for p in preargs if not p.startswith('-O'))
else:
cmd.extend(preargs)
cmd.extend('-I%s'%idir for idir in include_dirs)
cmd.extend(['-o',lib_filename])
cmd.append(cppfilename)
cmd.extend(['-L%s'%ldir for ldir in lib_dirs])
cmd.extend(['-l%s'%l for l in libs])
debug('Running cmd', ' '.join(cmd))
p = subprocess.Popen(cmd)
status = p.wait()
if status:
error('g++ return status', status)
else:
#touch the __init__ file
file(os.path.join(workdir, "__init__.py"),'w').close()
rval = dlimport(lib_filename)
rval = dlimport(lib_filename)
finally:
warning("TODO: cleanup")
#os.remove(cppfilename)
for ofile in ofiles:
#os.remove(ofiles[0])
pass
return rval
......
......@@ -162,6 +162,16 @@ class CLinkerOp(object):
raise utils.MethodNotDefined('%s.c_support_code' \
% self.__class__.__name__)
def c_code_cache_version(self):
    """Return a tuple of integers indicating the version of this Op.

    An empty tuple indicates an 'unversioned' Op that will not be cached between processes.

    The cache mechanism may erase cached modules that have been superseded by newer
    versions.  See `ModuleCache` for details.
    """
    return (1,)
class PureOp(object):
"""
An :term:`Op` is a type of operation.
......
......@@ -57,6 +57,9 @@ class TDouble(Type):
free(%(name)s_bad_thing);
""" % locals()
def c_code_cache_version(self):
    # Empty tuple: mark this Type as unversioned, so its modules are
    # never cached between processes.
    return ()
tdouble = TDouble()
def double(name):
......@@ -83,6 +86,8 @@ class MyOp(Op):
def perform(self, node, inputs, (out, )):
out[0] = self.impl(*inputs)
def c_code_cache_version(self):
    # Empty tuple: mark this Op as unversioned, so its modules are
    # never cached between processes.
    return ()
class Unary(MyOp):
......
......@@ -210,6 +210,16 @@ class CLinkerType(object):
"""
raise MethodNotDefined("c_support_code", type(self), self.__class__.__name__)
def c_code_cache_version(self):
    """Return a tuple of integers indicating the version of this Type.

    An empty tuple indicates an 'unversioned' Type that will not be cached between processes.

    The cache mechanism may erase cached modules that have been superseded by newer
    versions.  See `ModuleCache` for details.
    """
    return (1,)
class PureType(object):
"""Interface specification for variable type instances.
......
......@@ -444,6 +444,10 @@ class DenseFromSparse(gof.op.Op):
"""
sparse_grad = True
"""WRITEME"""
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, x):
x = as_sparse_variable(x)
......@@ -495,6 +499,10 @@ csc_from_dense = SparseFromDense('csc')
class Transpose(gof.op.Op):
format_map = {'csr' : 'csc',
'csc' : 'csr'}
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, x):
x = as_sparse_variable(x)
return gof.Apply(self,
......@@ -510,6 +518,10 @@ class Transpose(gof.op.Op):
transpose = Transpose()
class Neg(gof.op.Op):
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, x):
x = as_sparse_variable(x)
return gof.Apply(self, [x], [x.type()])
......@@ -523,6 +535,10 @@ neg = Neg()
class AddSS(gof.op.Op):
'''Add two sparse matrices '''
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, x, y):
x, y = map(as_sparse_variable, [x, y])
if x.type.dtype != y.type.dtype:
......@@ -545,6 +561,10 @@ class AddSS(gof.op.Op):
add_s_s = AddSS()
class AddSD(gof.op.Op):
''' Add a sparse and a dense matrix '''
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, x, y):
x, y = as_sparse_variable(x), tensor.as_tensor_variable(y)
if x.type.dtype != y.type.dtype:
......@@ -586,6 +606,10 @@ def sub(x,y):
class MulSS(gof.op.Op):
''' Elementwise multiply a sparse and a ndarray '''
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, x, y):
x, y = as_sparse_variable(x), as_sparse_variable(y)
if x.type != y.type:
......@@ -605,6 +629,10 @@ class MulSS(gof.op.Op):
mul_s_s = MulSS()
class MulSD(gof.op.Op):
''' Elementwise multiply a sparse and a ndarray '''
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, x, y):
x, y = as_sparse_variable(x), tensor.as_tensor_variable(y)
if x.type.dtype != y.type.dtype:
......@@ -686,6 +714,10 @@ class StructuredDot(gof.Op):
The output is presumed to be a dense matrix, and is represented by a TensorType instance.
"""
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, a, b):
if type(a) is not SparseVariable and type(a) is not SparseConstant:
raise TypeError('First argument must be of type SparseVariable or SparseConstant');
......@@ -750,6 +782,10 @@ def structured_dot(x, y):
return _structured_dot(y.T, x.T).T
class StructuredDotCSC(gof.Op):
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, a_val, a_ind, a_ptr, a_nrows, b):
dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype)
r = gof.Apply(self, [a_val, a_ind, a_ptr, a_nrows, b],
......@@ -900,6 +936,10 @@ class StructuredDotCSC(gof.Op):
sd_csc = StructuredDotCSC()
class StructuredDotCSR(gof.Op):
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    # NOTE(review): make_node below sets self.dtype_out, so instances can
    # carry state not reflected here -- confirm that is intended.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, a_val, a_ind, a_ptr, b):
self.dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype)
r = gof.Apply(self, [a_val, a_ind, a_ptr, b],
......@@ -1055,6 +1095,10 @@ def structured_dot_grad(sparse_A, dense_B, ga):
class StructuredDotGradCSC(gof.Op):
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, a_indices, a_indptr, b, g_ab):
return gof.Apply(self, [a_indices, a_indptr, b, g_ab],
[tensor.tensor(g_ab.dtype, (False,))])
......@@ -1155,6 +1199,10 @@ sdg_csc = StructuredDotGradCSC()
class StructuredDotGradCSR(gof.Op):
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def make_node(self, a_indices, a_indptr, b, g_ab):
return gof.Apply(self, [a_indices, a_indptr, b, g_ab], [tensor.tensor(b.dtype, (False,))])
......@@ -1256,3 +1304,4 @@ class StructuredDotGradCSR(gof.Op):
"""% dict(locals(), **sub)
sdg_csr = StructuredDotGradCSR()
......@@ -49,6 +49,10 @@ class GemmRelated(Op):
This class provides a kind of templated gemm Op.
"""
def __eq__(self, other):
    # Ops are equal iff they are instances of the same class.
    return (type(self) == type(other))

def __hash__(self):
    # Consistent with __eq__: hash on the class.
    return hash(type(self))
def c_support_code(self):
#return cblas_header_text()
mod_str = """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论