Commit ada92aea authored by bergstra@tikuanyin

ModuleCache working

Parent 1f2d68ea
...@@ -29,6 +29,7 @@ def info(*args): ...@@ -29,6 +29,7 @@ def info(*args):
sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n') sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
_logger.info(' '.join(str(a) for a in args)) _logger.info(' '.join(str(a) for a in args))
def debug(*args): def debug(*args):
sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
_logger.debug(' '.join(str(a) for a in args)) _logger.debug(' '.join(str(a) for a in args))
def warning(*args): def warning(*args):
sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n') sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n')
...@@ -39,19 +40,14 @@ def error(*args): ...@@ -39,19 +40,14 @@ def error(*args):
from .callcache import CallCache from .callcache import CallCache
_timers = {}
_module_cache = None
def get_module_cache(): def get_module_cache():
global _module_cache return cmodule.get_module_cache(get_compiledir())
if _module_cache is None:
_module_cache = CallCache() #TODO: put a filename here for persistence
return _module_cache
_persistent_module_cache = None _persistent_module_cache = None
def get_persistent_module_cache(): def get_persistent_module_cache():
global _persistent_module_cache global _persistent_module_cache
if _persistent_module_cache is None: if _persistent_module_cache is None:
_persistent_module_cache = CallCache() #TODO: put a filename here for persistence _persistent_module_cache = CallCache(os.path.join(get_compiledir(), 'persistent_cache'))
return _persistent_module_cache return _persistent_module_cache
class CodeBlock: class CodeBlock:
...@@ -746,7 +742,29 @@ class CLinker(link.Linker): ...@@ -746,7 +742,29 @@ class CLinker(link.Linker):
rval = tuple(rval) rval = tuple(rval)
return rval return rval
def compile_cmodule(self, location=None):
    """Compile this linker's C module and return the imported dynamic module.

    This method is a callback for `ModuleCache.module_from_key`, which passes
    the freshly-created cache directory as `location`.

    :param location: directory in which to write and compile the module
        source; defaults to the compiledir when called outside the cache.
    :returns: the imported dynamic module object.
    """
    # Default to the global compile directory when no cache location is given.
    location = get_compiledir() if location is None else location
    mod = self.build_dynamic_module()
    # Serialize compilation: module_compile_str writes files into `location`.
    get_lock()
    try:
        debug("LOCATION", location)
        module = self.module_compile_str(
                module_name=mod.name,
                src_code=mod.code(),
                location=location,
                include_dirs=[],
                libs=self.libraries(),
                preargs=self.compile_args())
    finally:
        # Always release the compile lock, even if compilation raised.
        release_lock()
    return module
def build_dynamic_module(self):
"""Generate the code for this module, compile it, return the imported dynamic module. """Generate the code for this module, compile it, return the imported dynamic module.
""" """
self.code_gen() self.code_gen()
...@@ -755,18 +773,7 @@ class CLinker(link.Linker): ...@@ -755,18 +773,7 @@ class CLinker(link.Linker):
cthunk = object() # dummy so weave can get the type cthunk = object() # dummy so weave can get the type
mod = cmodule.DynamicModule(module_name) mod = cmodule.DynamicModule(module_name)
if 0:
# Eliminate duplicate inputs and outputs from the storage that we will pass to instantiate
out_storage = [x for i, x in enumerate(out_storage) if (i+len(in_storage)) not in self.dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in self.dupidx]
argnames = ["i%i" % i for i in xrange(len(in_storage))] \
+ ["o%i" % i for i in xrange(len(out_storage))] \
+ ["orph%i" % i for i in xrange(len(self.orphans))]
# The code of instantiate # The code of instantiate
#code = self.instantiate_code(1+len(argnames)) #the 1 is for error_storage
code = self.instantiate_code(1+len(self.args)) #the 1 is for error_storage code = self.instantiate_code(1+len(self.args)) #the 1 is for error_storage
instantiate = cmodule.ExtFunction('instantiate', code, method=cmodule.METH_VARARGS) instantiate = cmodule.ExtFunction('instantiate', code, method=cmodule.METH_VARARGS)
#['error_storage'] + argnames, #['error_storage'] + argnames,
...@@ -799,19 +806,7 @@ class CLinker(link.Linker): ...@@ -799,19 +806,7 @@ class CLinker(link.Linker):
for header in self.headers(): for header in self.headers():
mod.add_include(header) mod.add_include(header)
get_lock() return mod
try:
module = self.module_compile_str(
module_name=mod.name,
src_code = mod.code(),
location=get_compiledir(),
include_dirs=[],
libs=self.libraries(),
preargs=self.compile_args())
finally:
release_lock()
return module
def cthunk_factory(self, error_storage, in_storage, out_storage): def cthunk_factory(self, error_storage, in_storage, out_storage):
...@@ -831,9 +826,10 @@ class CLinker(link.Linker): ...@@ -831,9 +826,10 @@ class CLinker(link.Linker):
except KeyError: except KeyError:
key = None key = None
if key is None: if key is None:
#if we can't get a key, then forget the cache mechanism
module = self.compile_cmodule() module = self.compile_cmodule()
else: else:
module = get_module_cache().call(self.compile_cmodule, key=key) module = get_module_cache().module_from_key(key=key, fn=self.compile_cmodule)
vars = self.inputs + self.outputs + self.orphans vars = self.inputs + self.outputs + self.orphans
# List of indices that should be ignored when passing the arguments # List of indices that should be ignored when passing the arguments
......
"""Generate and compile C modules for Python """Generate and compile C modules for Python
""" """
import os, tempfile, StringIO, sys, logging, subprocess import os, tempfile, StringIO, sys, logging, subprocess, cPickle, atexit
_logger=logging.getLogger("theano.gof.cmodule") _logger=logging.getLogger("theano.gof.cmodule")
def warning(*args): def warning(*args):
sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n') #sys.stderr.write('WARNING:'+ ' '.join(str(a) for a in args)+'\n')
_logger.warning(' '.join(str(a) for a in args)) _logger.warning(' '.join(str(a) for a in args))
def info(*args): def info(*args):
sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n') #sys.stderr.write('INFO:'+ ' '.join(str(a) for a in args)+'\n')
_logger.info(' '.join(str(a) for a in args)) _logger.info(' '.join(str(a) for a in args))
def debug(*args): def debug(*args):
#sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n') #sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
...@@ -115,10 +115,138 @@ class DynamicModule(object): ...@@ -115,10 +115,138 @@ class DynamicModule(object):
#TODO: add_type #TODO: add_type
def dlimport(fullpath, suffix=None):
    """Dynamically load a .so, .dll, or .py file.

    :type fullpath: string
    :param fullpath: a fully-qualified path to a compiled python module
    :param suffix: a suffix to strip from the end of fullpath to get the
        import name; inferred from the filename when None.
    :type suffix: string
    :returns: the dynamically loaded module (from __import__)
    :raises ValueError: if `fullpath` does not end with `suffix`
    """
    if suffix is None:
        if fullpath.endswith('.so'):
            suffix = '.so'
        elif fullpath.endswith('.dll'):
            suffix = '.dll'
        elif fullpath.endswith('.py'):
            suffix = '.py'
        else:
            suffix = ''
    if not fullpath.endswith(suffix):
        raise ValueError('path has wrong suffix', (fullpath, suffix))
    # The dotted module name is built from the last two path components
    # (containing directory + filename).  BUG FIX: the original stripped the
    # suffix with [:-len(suffix)], which for an empty suffix is [:-0] and
    # empties the whole name (and then corrupts the workdir arithmetic
    # below); only strip when the suffix is non-empty.
    module_name = '.'.join(fullpath.split(os.path.sep)[-2:])
    if suffix:
        module_name = module_name[:-len(suffix)]
    # Everything before "<sep><dir>.<file><suffix>" is the import root.
    workdir = fullpath[:-len(module_name) - 1 - len(suffix)]
    rval = None
    # Temporarily make workdir the only import path so we load exactly this
    # file; restore sys.path no matter what the import does.
    pathcopy = list(sys.path)
    sys.path = [workdir]
    try:
        rval = __import__(module_name, {}, {}, [module_name])
        if not rval:
            error('__import__ failed', fullpath)
    finally:
        sys.path = pathcopy
    # Sanity check: the module we imported is the file we were asked for.
    assert fullpath.startswith(rval.__file__)
    return rval
class ModuleCache(object):
    """Two-level cache mapping graph keys to compiled dynamic modules.

    `name_from_key` maps a hashable key to the filesystem path of a compiled
    module and is persisted to ``module_cache.pkl`` in `dirname`;
    `module_from_name` maps that path to the already-imported module object
    for the current process.

    `stats` counts, in order: [in-memory hits, loaded-from-disk hits,
    fresh compiles].
    """
    def __init__(self, dirname, force_fresh=False):
        self.dirname = dirname
        self.module_from_name = {}
        self.name_from_key_filename = os.path.join(self.dirname, 'module_cache.pkl')
        self.name_from_key = {}
        self.stats = [0, 0, 0]
        if not force_fresh:
            # Best-effort load of the persisted key->path table; a missing or
            # truncated pickle just means we start with an empty cache.
            # (Python 2: `file` is the built-in open, cPickle the C pickler.)
            try:
                f = file(self.name_from_key_filename, 'r')
                self.name_from_key = cPickle.load(f)
                debug('ModuleCache loaded', len(self.name_from_key))
                f.close()
            except (IOError, EOFError):
                debug('cache load failed. Using fresh cache')
                pass

    def persist(self):
        """Write the key->path table back to disk (registered via atexit)."""
        f = file(self.name_from_key_filename, 'w')
        cPickle.dump(self.name_from_key, f)
        f.close()

    def module_from_key(self, key, fn=None):
        """Return the compiled module for `key`, compiling via `fn` on a miss.

        On a miss, `fn` is called as ``fn(location=...)`` with a fresh
        temporary directory inside the cache dir and must return the
        imported module (see `CLinker.compile_cmodule`).
        """
        rval = None
        if key in self.name_from_key:
            # we have seen this key either in this process or previously
            #debug('OLD KEY HASH', hash(key), hash(key[1][0]), key[1][0])
            name = self.name_from_key[key]
            if name not in self.module_from_name:
                # Known on disk but not yet imported in this process.
                #debug('loading name', name)
                self.module_from_name[name] = dlimport(name)
                self.stats[1] += 1
            else:
                self.stats[0] += 1
            rval = self.module_from_name[name]
        else:
            # we have never seen this key before
            location = tempfile.mkdtemp(dir=self.dirname)
            #debug("LOCATION*", location)
            try:
                module = fn(location=location) # WILL FAIL FOR BAD C CODE
            finally:
                # >>TODO: erase location
                pass
            debug('NEW KEY HASH', hash(key), hash(key[1][0]), key[1][0])
            # Diagnostic sweep: if an equal key is already stored but was not
            # found by the dict lookup above, its hash must have changed
            # since pickling -- dump both hashes and fail loudly.
            for k,n in self.name_from_key.iteritems():
                if k == key:
                    debug("HASH OF RELOAD IS DIFFERENT", hash(k), hash(key))
                    print ''
                    print hash(k[0])
                    print hash(key[0])
                    print ''
                    print "OLD",
                    print hash(k[1][0])
                    print k[1][0].rehash()
                    print ""
                    print "NEW", hash(key[1][0]), key[1][0].rehash()
                    print ''
                    print hash(k[1][1])
                    print hash(key[1][1])
                    # Deliberately fails here: equal-but-unhashed keys mean
                    # some type's __hash__ is not stable across processes.
                    assert k != key
            name = module.__file__
            #debug("LOCATION**", location)
            #debug("NAME**", name)
            assert name.startswith(location)
            assert name not in self.module_from_name
            assert key not in self.name_from_key
            self.name_from_key[key] = name
            self.module_from_name[name] = module
            self.stats[2] += 1
            rval = module
        #debug('stats', self.stats, sum(self.stats))
        return rval
_module_cache = None

def get_module_cache(dirname):
    """Return the process-wide ModuleCache singleton rooted at `dirname`.

    The cache is created lazily on first call and its `persist` method is
    registered to run at interpreter exit so the key table survives the
    process.

    NOTE(review): later calls ignore `dirname` and return the cache created
    for the first one -- confirm all callers pass the same compiledir.
    """
    global _module_cache
    if _module_cache is not None:
        return _module_cache
    _module_cache = ModuleCache(dirname, force_fresh=False)
    atexit.register(_module_cache.persist)
    return _module_cache
def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[], def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[],
preargs=[], tmpdir=None): preargs=[], tmpdir=None):
#TODO: don't to the dlimport in this function
preargs= [] if preargs is None else list(preargs) preargs= [] if preargs is None else list(preargs)
preargs.append('-fPIC') preargs.append('-fPIC')
no_opt = False no_opt = False
...@@ -127,7 +255,7 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -127,7 +255,7 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
include_dirs = ['/usr/include/python2.6'] + include_dirs include_dirs = ['/usr/include/python2.6'] + include_dirs
libs = ['python2.6'] + libs libs = ['python2.6'] + libs
workdir = tempfile.mkdtemp(dir=location) workdir = location
cppfilename = os.path.join(workdir, 'mod.cpp') cppfilename = os.path.join(workdir, 'mod.cpp')
cppfile = file(cppfilename, 'w') cppfile = file(cppfilename, 'w')
...@@ -157,19 +285,12 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[] ...@@ -157,19 +285,12 @@ def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[]
status = p.wait() status = p.wait()
if status: if status:
warning('g++ return status', status) error('g++ return status', status)
else: else:
#touch the __init__ file #touch the __init__ file
file(os.path.join(workdir, "__init__.py"),'w').close() file(os.path.join(workdir, "__init__.py"),'w').close()
#load the module rval = dlimport(lib_filename)
sys.path.insert(0, workdir)
try:
rval = __import__(module_name, {}, {}, [module_name])
if not rval:
debug('__import__ failed')
finally:
del sys.path[0]
finally: finally:
warning("TODO: cleanup") warning("TODO: cleanup")
...@@ -228,7 +349,6 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[ ...@@ -228,7 +349,6 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[
file(os.path.join(workdir, "__init__.py"),'w').close() file(os.path.join(workdir, "__init__.py"),'w').close()
#load the module #load the module
pathcopy = list(sys.path)
sys.path.insert(0, workdir) sys.path.insert(0, workdir)
try: try:
rval = __import__(module_name, {}, {}, [module_name]) rval = __import__(module_name, {}, {}, [module_name])
......
...@@ -59,7 +59,7 @@ class Scalar(Type): ...@@ -59,7 +59,7 @@ class Scalar(Type):
return type(self) == type(other) and other.dtype == self.dtype return type(self) == type(other) and other.dtype == self.dtype
def __hash__(self): def __hash__(self):
return hash(self.dtype) return hash('theano.scalar.Scalar') ^ hash(self.dtype)
def dtype_specs(self): def dtype_specs(self):
try: try:
...@@ -348,7 +348,7 @@ class ScalarOp(Op): ...@@ -348,7 +348,7 @@ class ScalarOp(Op):
return test return test
def __hash__(self): def __hash__(self):
return hash(getattr(self, 'output_types_preference', 0)) return hash(type(self).__name__) ^ hash(getattr(self, 'output_types_preference', 0))
def __str__(self): def __str__(self):
if hasattr(self, 'name') and self.name: if hasattr(self, 'name') and self.name:
......
...@@ -41,6 +41,10 @@ def check_equal_numpy(x, y): ...@@ -41,6 +41,10 @@ def check_equal_numpy(x, y):
compile.register_checker(check_equal_numpy) compile.register_checker(check_equal_numpy)
def hashtype(self):
    """Hash an object's type by class name and module name.

    Unlike ``hash(type(self))``, which depends on the type object's memory
    address, this value is the same for all instances of a class and is
    reproducible across interpreter runs within one hash seed.
    """
    cls = type(self)
    return hash(cls.__name__) ^ hash(cls.__module__)
elemwise.hashtype = hashtype
__oplist_constructor_list = [] __oplist_constructor_list = []
...@@ -305,7 +309,7 @@ class TensorType(Type): ...@@ -305,7 +309,7 @@ class TensorType(Type):
def __hash__(self): def __hash__(self):
"""Hash equal for same kinds of TensorType""" """Hash equal for same kinds of TensorType"""
return hash(type(self)) ^ hash(self.dtype) ^ hash(self.broadcastable) return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable)
ndim = property(lambda self: len(self.broadcastable), doc = "number of dimensions") ndim = property(lambda self: len(self.broadcastable), doc = "number of dimensions")
"""Number of dimensions """Number of dimensions
...@@ -732,7 +736,7 @@ class TensorConstantSignature(tuple): ...@@ -732,7 +736,7 @@ class TensorConstantSignature(tuple):
return (x == a) and (b.shape == y.shape) and (numpy.all(b == y)) return (x == a) and (b.shape == y.shape) and (numpy.all(b == y))
def __hash__(self): def __hash__(self):
a, b = self a, b = self
return hash(type(self)) ^ hash(a) ^ hash(b.shape) return hashtype(self) ^ hash(a) ^ hash(b.shape)
class TensorConstant(Constant, _tensor_py_operators): class TensorConstant(Constant, _tensor_py_operators):
"""Subclass to add the tensor operators to the basic `Constant` class. """Subclass to add the tensor operators to the basic `Constant` class.
...@@ -1607,7 +1611,7 @@ class SetSubtensor(Op): ...@@ -1607,7 +1611,7 @@ class SetSubtensor(Op):
if isinstance(entry, slice) if isinstance(entry, slice)
else entry else entry
for entry in self.idx_list) for entry in self.idx_list)
return hash(type(self)) ^ hash(idx_list) ^ hash(self.inplace) return hashtype(self) ^ hash(idx_list) ^ hash(self.inplace)
def __str__(self): def __str__(self):
indices = [] indices = []
...@@ -2125,7 +2129,7 @@ class Flatten(Op): ...@@ -2125,7 +2129,7 @@ class Flatten(Op):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.outdim == other.outdim return type(self) == type(other) and self.outdim == other.outdim
def __hash__(self): def __hash__(self):
return hash(type(self))^hash(self.outdim) return hashtype(self)^hash(self.outdim)
def make_node(self, x): def make_node(self, x):
t_x = as_tensor_variable(x) t_x = as_tensor_variable(x)
if self.outdim < 1 or (x.ndim and self.outdim > x.ndim): if self.outdim < 1 or (x.ndim and self.outdim > x.ndim):
...@@ -2277,7 +2281,7 @@ class TensorDotGrad(Op): ...@@ -2277,7 +2281,7 @@ class TensorDotGrad(Op):
return type(self) == type(other) and self.axes == other.axes return type(self) == type(other) and self.axes == other.axes
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.axes) ^ 89234 return hashtype(self) ^ hash(self.axes) ^ 89234
def make_node(self, x, y, gz): def make_node(self, x, y, gz):
assert isinstance(x, Variable) assert isinstance(x, Variable)
...@@ -2324,7 +2328,7 @@ class TensorDot(Op): ...@@ -2324,7 +2328,7 @@ class TensorDot(Op):
return type(self) == type(other) and self.axes == other.axes return type(self) == type(other) and self.axes == other.axes
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ hash(self.axes) ^ 89234 return hashtype(self) ^ hash(self.axes) ^ 89234
def make_node(self, x, y): def make_node(self, x, y):
......
...@@ -123,8 +123,15 @@ class DimShuffle(Op): ...@@ -123,8 +123,15 @@ class DimShuffle(Op):
if self.inplace: if self.inplace:
self.view_map = {0: [0]} self.view_map = {0: [0]}
self._hashval = hash(type(self)) ^ hash(self.inplace) \ self._rehash()
^ hash(self.new_order) ^ hash(self.input_broadcastable)
def __getstate__(self):
    """Pickle support: drop the cached hash value.

    `_hashval` is recomputed by `__setstate__` via `_rehash`, so it must not
    be serialized -- a stale pickled hash could disagree with the
    name/module-based scheme used for cache keys.
    """
    d = dict(self.__dict__)
    del d['_hashval']
    return d
def __setstate__(self, d):
    """Unpickle support: restore attributes, then recompute the cached hash."""
    self.__dict__.update(d)
    # _hashval was stripped by __getstate__; rebuild it from the restored
    # attributes so __hash__ works again.
    self._rehash()
def make_node(self, input): def make_node(self, input):
ib = tuple(input.type.broadcastable) ib = tuple(input.type.broadcastable)
...@@ -148,6 +155,10 @@ class DimShuffle(Op): ...@@ -148,6 +155,10 @@ class DimShuffle(Op):
and self.new_order == other.new_order \ and self.new_order == other.new_order \
and self.input_broadcastable == other.input_broadcastable and self.input_broadcastable == other.input_broadcastable
def _rehash(self):
    # Precompute the hash from the class's name/module strings (not the type
    # object itself) plus the instance configuration, so the value is
    # reproducible when the op is unpickled in another process -- needed for
    # the persistent module-cache keys introduced in this commit.
    self._hashval = hash(type(self).__name__) ^ hash(type(self).__module__) ^ hash(self.inplace) \
        ^ hash(self.new_order) ^ hash(self.input_broadcastable)
def __hash__(self): def __hash__(self):
return self._hashval return self._hashval
...@@ -353,15 +364,13 @@ class Elemwise(Op): ...@@ -353,15 +364,13 @@ class Elemwise(Op):
self.ufunc = None self.ufunc = None
#precompute the hash of this node #precompute the hash of this node
items = self.inplace_pattern.items() self._rehash()
items.sort()
tuple_items = tuple([k for k,v in items] + [(tuple(v) if isinstance(v, (tuple, list)) else v) for k,v in items])
self._hashval = hash(self.scalar_op) ^ hash(tuple_items)
def __getstate__(self): def __getstate__(self):
d = copy(self.__dict__) d = copy(self.__dict__)
d.pop('ufunc') d.pop('ufunc')
d.pop('__epydoc_asRoutine', None) d.pop('__epydoc_asRoutine', None)
d.pop('_hashval')
return d return d
def __setstate__(self, d): def __setstate__(self, d):
...@@ -370,6 +379,7 @@ class Elemwise(Op): ...@@ -370,6 +379,7 @@ class Elemwise(Op):
self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.scalar_op.nin, self.scalar_op.nout) self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.scalar_op.nin, self.scalar_op.nout)
else: else:
self.ufunc = None self.ufunc = None
self._rehash()
def make_node(self, *inputs): def make_node(self, *inputs):
""" """
...@@ -429,6 +439,14 @@ class Elemwise(Op): ...@@ -429,6 +439,14 @@ class Elemwise(Op):
return rval return rval
return False return False
def _rehash(self):
    """Recompute and cache this op's hash value in `_hashval`."""
    # Python 2: dict.items() returns a list, sorted in place for a
    # deterministic ordering of the inplace_pattern entries.
    items = self.inplace_pattern.items()
    items.sort()
    # Flatten keys then values into one hashable tuple; list values are
    # converted to tuples so the whole thing is hashable.
    tuple_items = tuple([k for k,v in items] + [(tuple(v) if isinstance(v, (tuple, list)) else v) for k,v in items])
    # String literal 'Elemwise' (not hash(type(...))) keeps the value stable
    # across processes, matching the hashtype() scheme used elsewhere.
    h = hash('Elemwise') ^ hash(self.scalar_op) ^ hash(tuple_items)
    # Sanity check: if a hash was already cached (e.g. before re-hashing),
    # the recomputed value must agree with it.
    assert h == getattr(self,'_hashval', h)
    self._hashval = h
def __hash__(self): def __hash__(self):
return self._hashval return self._hashval
......
...@@ -94,6 +94,11 @@ class SoftmaxWithBias(gof.Op): ...@@ -94,6 +94,11 @@ class SoftmaxWithBias(gof.Op):
def __init__(self, **kwargs): def __init__(self, **kwargs):
gof.Op.__init__(self, **kwargs) gof.Op.__init__(self, **kwargs)
def __eq__(self, other):
    # This op carries no instance configuration, so any two instances of the
    # same class are interchangeable.
    return type(self) == type(other)

def __hash__(self):
    # hashtype() hashes the class name/module strings instead of the type
    # object, giving a value reproducible across processes (cache keys).
    return tensor.hashtype(self)
def make_node(self, x, b): def make_node(self, x, b):
x = tensor.as_tensor_variable(x) x = tensor.as_tensor_variable(x)
b = tensor.as_tensor_variable(b) b = tensor.as_tensor_variable(b)
...@@ -266,8 +271,9 @@ class SoftmaxGrad(gof.Op): ...@@ -266,8 +271,9 @@ class SoftmaxGrad(gof.Op):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return tensor.hashtype(self)
def make_node(self, dy, sm, **kwargs): def make_node(self, dy, sm, **kwargs):
dy = tensor.as_tensor_variable(dy) dy = tensor.as_tensor_variable(dy)
...@@ -437,6 +443,10 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -437,6 +443,10 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
nout=3 nout=3
def __init__(self, **kwargs): def __init__(self, **kwargs):
gof.Op.__init__(self, **kwargs) gof.Op.__init__(self, **kwargs)
def __eq__(self, other):
    # No instance configuration: equality is by class alone.
    return type(self) == type(other)

def __hash__(self):
    # Name/module-based type hash, stable across processes (see hashtype).
    return tensor.hashtype(self)
def make_node(self, x, b, y_idx): def make_node(self, x, b, y_idx):
x = tensor.as_tensor_variable(x) x = tensor.as_tensor_variable(x)
...@@ -608,6 +618,10 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op): ...@@ -608,6 +618,10 @@ class CrossentropySoftmax1HotWithBiasDx (gof.Op):
"""Gradient wrt x of the CrossentropySoftmax1Hot Op""" """Gradient wrt x of the CrossentropySoftmax1Hot Op"""
def __init__(self, **kwargs): def __init__(self, **kwargs):
gof.Op.__init__(self,**kwargs) gof.Op.__init__(self,**kwargs)
def __eq__(self, other):
    # No instance configuration: equality is by class alone.
    return type(self) == type(other)

def __hash__(self):
    # Name/module-based type hash, stable across processes (see hashtype).
    return tensor.hashtype(self)
def make_node(self, dy, sm, y_idx,**kwargs): def make_node(self, dy, sm, y_idx,**kwargs):
dy = tensor.as_tensor_variable(dy) dy = tensor.as_tensor_variable(dy)
sm = tensor.as_tensor_variable(sm) sm = tensor.as_tensor_variable(sm)
...@@ -728,7 +742,7 @@ class CrossentropyCategorical1HotGrad(gof.Op): ...@@ -728,7 +742,7 @@ class CrossentropyCategorical1HotGrad(gof.Op):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return tensor.hashtype(self)
def make_node(self, g_y, coding_dist, true_one_of_n): def make_node(self, g_y, coding_dist, true_one_of_n):
return gof.Apply(self, [g_y, coding_dist, true_one_of_n], [coding_dist.type()]) return gof.Apply(self, [g_y, coding_dist, true_one_of_n], [coding_dist.type()])
def perform(self, node, (g_y, coding_dist, true_one_of_n), (g_coding_strg,)): def perform(self, node, (g_y, coding_dist, true_one_of_n), (g_coding_strg,)):
...@@ -741,10 +755,6 @@ crossentropy_categorical_1hot_grad = CrossentropyCategorical1HotGrad() ...@@ -741,10 +755,6 @@ crossentropy_categorical_1hot_grad = CrossentropyCategorical1HotGrad()
class CrossentropyCategorical1Hot(gof.Op): class CrossentropyCategorical1Hot(gof.Op):
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
"""Compute the cross entropy between a coding distribution and """Compute the cross entropy between a coding distribution and
a true distribution of the form [0, 0, ... 0, 1, 0, ..., 0] a true distribution of the form [0, 0, ... 0, 1, 0, ..., 0]
...@@ -758,6 +768,11 @@ class CrossentropyCategorical1Hot(gof.Op): ...@@ -758,6 +768,11 @@ class CrossentropyCategorical1Hot(gof.Op):
Op will probably be optimized away in favour of one with a C implementation. Op will probably be optimized away in favour of one with a C implementation.
""" """
def __eq__(self, other):
    # No instance configuration: equality is by class alone.
    return type(self) == type(other)

def __hash__(self):
    # Name/module-based type hash, stable across processes (see hashtype).
    return tensor.hashtype(self)
def make_node(self, coding_dist, true_one_of_n): def make_node(self, coding_dist, true_one_of_n):
""" """
:type coding_dist: dense matrix :type coding_dist: dense matrix
...@@ -906,6 +921,11 @@ class Prepend_scalar_constant_to_each_row(gof.Op): ...@@ -906,6 +921,11 @@ class Prepend_scalar_constant_to_each_row(gof.Op):
val = scalar.constant(val) val = scalar.constant(val)
self.val = val self.val = val
def __eq__(self, other):
    # Ops are equal when they share a class and prepend the same constant.
    return (type(self) == type(other)) and (self.val == other.val)

def __hash__(self):
    # Combine the stable type hash with the scalar constant's value so
    # equal ops hash equal across processes.
    return tensor.hashtype(self) ^ hash(self.val.value)
def make_node(self, mat): def make_node(self, mat):
#check type of input #check type of input
if not isinstance(mat,gof.Variable) or not mat.type==tensor.matrix().type: if not isinstance(mat,gof.Variable) or not mat.type==tensor.matrix().type:
...@@ -938,6 +958,11 @@ class Prepend_scalar_constant_to_each_row(gof.Op): ...@@ -938,6 +958,11 @@ class Prepend_scalar_constant_to_each_row(gof.Op):
return goutput[:,1:] return goutput[:,1:]
class Prepend_scalar_to_each_row(gof.Op): class Prepend_scalar_to_each_row(gof.Op):
def __eq__(self, other):
    # No instance configuration: equality is by class alone.
    return (type(self) == type(other))

def __hash__(self):
    # Name/module-based type hash, stable across processes (see hashtype).
    return tensor.hashtype(self)
def make_node(self, val, mat): def make_node(self, val, mat):
#check type of input #check type of input
if isinstance(val, float): if isinstance(val, float):
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment