提交 1f2d68ea authored 作者: bergstra@tikuanyin's avatar bergstra@tikuanyin

clinker working with callcache

上级 9be3ec1a
import cPickle, logging, sys
_logger=logging.getLogger("theano.gof.callcache")
def warning(*args):
    """Echo a WARNING line to stderr and forward it to the module logger."""
    msg = ' '.join(str(a) for a in args)
    sys.stderr.write('WARNING:' + msg + '\n')
    _logger.warning(msg)
def error(*args):
    """Echo an ERROR line to stderr and forward it to the module logger."""
    msg = ' '.join(str(a) for a in args)
    sys.stderr.write('ERROR:' + msg + '\n')
    _logger.error(msg)
def info(*args):
    """Echo an INFO line to stderr and forward it to the module logger."""
    msg = ' '.join(str(a) for a in args)
    sys.stderr.write('INFO:' + msg + '\n')
    _logger.info(msg)
def debug(*args):
    """Echo a DEBUG line to stderr and forward it to the module logger."""
    msg = ' '.join(str(a) for a in args)
    sys.stderr.write('DEBUG:' + msg + '\n')
    _logger.debug(msg)
class CallCache(object):
    """Memoize calls to arbitrary functions, optionally persisting the
    memo table to disk with cPickle.

    The cache maps ``(fn, args)`` keys (or a caller-supplied key) to the
    function's return value.
    """

    def __init__(self, filename=None):
        """Load a previously persisted cache from `filename`; start empty
        when `filename` is None or the file cannot be read."""
        self.filename = filename
        try:
            if filename is None:
                raise IOError('bad filename')  # just goes to except
            # binary mode: pickle streams are byte data, not text
            f = open(filename, 'rb')
            self.cache = cPickle.load(f)
            f.close()
        except IOError:
            self.cache = {}

    def persist(self, filename=None):
        """Write the current cache to `filename` (default: the filename
        given at construction time)."""
        filename = self.filename if filename is None else filename
        f = open(filename, 'wb')
        try:
            cPickle.dump(self.cache, f)
        finally:
            # close the handle even if pickling raises
            f.close()

    def call(self, fn, args=(), key=None):
        """Return `fn(*args)`, computing it only on the first call for a
        given key.  `key` defaults to ``(fn, tuple(args))``."""
        key = (fn, tuple(args)) if key is None else key
        if key not in self.cache:
            debug('cache miss', len(self.cache))
            self.cache[key] = fn(*args)
        else:
            debug('cache hit', len(self.cache))
        return self.cache[key]

    def __del__(self):
        # Best-effort persistence at garbage-collection time; never raise
        # from a destructor.
        try:
            if self.filename:
                self.persist()
        except Exception as e:
            # BUGFIX: was `_logging.error(...)` (undefined name), and passed
            # extra positional args that Logger.error would try to use as
            # %-format arguments.
            _logger.error('persist failed %s %s' % (self.filename, e))
...@@ -6,7 +6,7 @@ Defines Linkers that deal with C implementations. ...@@ -6,7 +6,7 @@ Defines Linkers that deal with C implementations.
from copy import copy from copy import copy
import md5 import md5
import re #for set_compiledir import re #for set_compiledir
import os, sys, platform import os, sys, platform, StringIO, time
# weave import # weave import
from scipy import weave from scipy import weave
...@@ -21,6 +21,39 @@ import utils ...@@ -21,6 +21,39 @@ import utils
from compiledir import * from compiledir import *
from compilelock import get_lock, release_lock from compilelock import get_lock, release_lock
import cmodule
import logging
_logger=logging.getLogger("theano.gof.cc")
def info(*args):
    """Echo an INFO line to stderr and forward it to the module logger."""
    msg = ' '.join(str(a) for a in args)
    sys.stderr.write('INFO:' + msg + '\n')
    _logger.info(msg)
def debug(*args):
    """Forward a DEBUG message to the module logger (no stderr echo,
    unlike the other helpers in this module)."""
    _logger.debug(' '.join(map(str, args)))
def warning(*args):
    """Echo a WARNING line to stderr and forward it to the module logger."""
    msg = ' '.join(str(a) for a in args)
    sys.stderr.write('WARNING:' + msg + '\n')
    _logger.warning(msg)
def error(*args):
    """Echo an ERROR line to stderr and forward it to the module logger."""
    msg = ' '.join(str(a) for a in args)
    sys.stderr.write('ERROR:' + msg + '\n')
    _logger.error(msg)
from .callcache import CallCache
# NOTE(review): _timers is never used in the visible code — presumably
# reserved for timing instrumentation; confirm before removing.
_timers = {}
# Lazily created singleton; see get_module_cache().
_module_cache = None
def get_module_cache():
    """Return the process-wide compilation CallCache, creating it lazily."""
    global _module_cache
    if _module_cache is None:
        #TODO: put a filename here for persistence
        _module_cache = CallCache()
    return _module_cache
# Lazily created singleton; see get_persistent_module_cache().
_persistent_module_cache = None
def get_persistent_module_cache():
    """Return the process-wide persistent CallCache, creating it lazily."""
    global _persistent_module_cache
    if _persistent_module_cache is None:
        #TODO: put a filename here for persistence
        _persistent_module_cache = CallCache()
    return _persistent_module_cache
class CodeBlock: class CodeBlock:
"""WRITEME """WRITEME
Represents a computation unit composed of declare, behavior, and cleanup. Represents a computation unit composed of declare, behavior, and cleanup.
...@@ -324,6 +357,7 @@ class CLinker(link.Linker): ...@@ -324,6 +357,7 @@ class CLinker(link.Linker):
self.env = env self.env = env
self.fetch_variables() self.fetch_variables()
self.no_recycling = no_recycling self.no_recycling = no_recycling
self.module_compile_str = cmodule.gcc_module_compile_str
return self return self
def fetch_variables(self): def fetch_variables(self):
...@@ -500,6 +534,8 @@ class CLinker(link.Linker): ...@@ -500,6 +534,8 @@ class CLinker(link.Linker):
self.tasks = tasks self.tasks = tasks
all = self.inputs + self.outputs + self.orphans all = self.inputs + self.outputs + self.orphans
assert (self.init_tasks, self.tasks) == self.get_init_tasks()
# List of indices that should be ignored when passing the arguments # List of indices that should be ignored when passing the arguments
# (basically, everything that the previous call to uniq eliminated) # (basically, everything that the previous call to uniq eliminated)
self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i] self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i]
...@@ -609,6 +645,19 @@ class CLinker(link.Linker): ...@@ -609,6 +645,19 @@ class CLinker(link.Linker):
[link.Container(output, storage, True) for output, storage in zip(self.env.outputs, output_storage)], \ [link.Container(output, storage, True) for output, storage in zip(self.env.outputs, output_storage)], \
error_storage error_storage
def get_init_tasks(self):
    """Build the (init_tasks, tasks) lists for this linker.

    Each variable in ``self.variables`` contributes an ``(var, 'init', i)``
    entry to init_tasks and a ``(var, 'get', i+1)`` entry to tasks; each
    node in ``self.node_order`` then contributes a ``(node, 'code', i)``
    entry.  Ids are assigned consecutively starting at 1.

    :returns: tuple ``(init_tasks, tasks)``.
    """
    init_tasks = []
    tasks = []
    # renamed from `id` — avoid shadowing the builtin
    task_id = 1
    for v in self.variables:
        init_tasks.append((v, 'init', task_id))
        tasks.append((v, 'get', task_id + 1))
        task_id += 2
    for node in self.node_order:
        tasks.append((node, 'code', task_id))
        task_id += 1
    return init_tasks, tasks
def make_thunk(self, input_storage = None, output_storage = None): def make_thunk(self, input_storage = None, output_storage = None):
"""WRITEME """WRITEME
Compiles this linker's env and returns a function to perform the Compiles this linker's env and returns a function to perform the
...@@ -632,17 +681,138 @@ class CLinker(link.Linker): ...@@ -632,17 +681,138 @@ class CLinker(link.Linker):
f() f()
first_output = ostor[0].data first_output = ostor[0].data
""" """
# Note: acquiring the lock here may not be necessary. However, it is init_tasks, tasks = self.get_init_tasks()
# cheap enough that it should not matter. cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage)
res = _execute(cthunk, init_tasks, tasks, error_storage), in_storage, out_storage
return res
def cmodule_key(self):
    """Return a complete hashable signature of the module we compiled
    The signature has the following form:
    {{{
    'CLinker.cmodule_key',
    op0, (input0.type, input1.type, input0 pos, input1 pos)
    op1, (...)
    ...
    opK, (...)
    }}}
    The signature is a tuple of tuples.
    The outer tuple has one element for every node in the topological ordering of
    `self.env`.
    The inner tuple has one element for the op used at that node, and one element for the
    inputs to that node. The inputs are identified by their type and "graph position"
    The graph position of a typical variable is encoded by integer pairs ``(a,b)``:
    ``a`` is the topological position of the input's owner (-1 for graph inputs),
    ``b`` is the index of the variable in the owner's output list.
    The graph position of a Constant is defined as its signature.
    If the Op of any Apply in the Env does not have c_code_cache_ok()==True, then this
    function raises a KeyError exception.
    """
    # NOTE(review): the docstring promises a KeyError when some Op has
    # c_code_cache_ok() == False, but no such check appears in this body —
    # confirm whether the check lives elsewhere or is missing.
    order = list(self.env.toposort())
    # map: env input variable -> (-1, position among env inputs)
    env_inputs_set = dict((i, (-1, pos)) for pos, i in enumerate(self.env.inputs))
    # NOTE(review): env_computed_set is written to below but never read.
    env_computed_set = set()
    op_pos = {} # Apply -> topological position
    rval = ['CLinker.cmodule_key'] # will be cast to tuple on return
    # assert that every input to every node is one of'
    # - an env input
    # - an output from a node in the Env
    # - a Constant
    def graphpos(i):
        # Encode the "graph position" of variable `i` (see docstring).
        if isinstance(i, graph.Constant):
            return i.signature()
        elif i in env_inputs_set:
            return env_inputs_set[i]
        else:
            if i.owner is None:
                # ownerless, non-input, non-constant variable: should be
                # impossible — the asserts document the invariants checked
                # before failing loudly.
                assert all( all(out is not None for out in o.outputs) for o in order)
                assert all( input.owner is None for input in self.env.inputs)
                raise Exception('what is this?', (i, type(i), i.clients, self.env))
            # produced by an earlier node: (owner's topological pos, output index)
            return (op_pos[i.owner], i.owner.outputs.index(i))
    for opos, o in enumerate(order):
        # one (op, ((input type, input graph-position), ...)) entry per node
        rval.append((o.op, tuple((i.type, graphpos(i)) for i in o.inputs)))
        op_pos[o] = opos
        env_computed_set.update(o.outputs)
    rval = tuple(rval)
    return rval
def compile_cmodule(self):
    """Generate the code for this module, compile it, return the imported dynamic module.

    Builds a cmodule.DynamicModule containing the struct code, its support
    code, static executor/destructor wrappers, and an ``instantiate``
    entry point, then compiles and imports it under the compilation lock.
    """
    self.code_gen()
    module_name = self.hash
    mod = cmodule.DynamicModule(module_name)

    # The code of instantiate.  The +1 accounts for error_storage.
    # (Removed dead `if 0:` block that referenced undefined in_storage /
    # out_storage, an unused weave `cthunk` dummy, and an unused argnames
    # list left over from the weave-based implementation.)
    code = self.instantiate_code(1 + len(self.args))
    instantiate = cmodule.ExtFunction('instantiate', code, method=cmodule.METH_VARARGS)

    # Static methods that can run and destroy the struct built by instantiate.
    static = """
    int %(struct_name)s_executor(%(struct_name)s* self) {
        return self->run();
    }

    void %(struct_name)s_destructor(void* executor, void* self) {
        //printf("doing cleanup\\n");
        //fflush(stdout);
        // ((%(struct_name)s*)self)->cleanup();
        // free(self);
        delete ((%(struct_name)s*)self);
        //printf("done cleanup\\n");
        //fflush(stdout);
    }
    """ % dict(struct_name = self.struct_name)

    # We add all the support code, compile args, headers and libs we need.
    for support_code in self.support_code():
        mod.add_support_code(support_code)
    mod.add_support_code(self.struct_code)
    mod.add_support_code(static)
    mod.add_function(instantiate)
    for header in self.headers():
        mod.add_include(header)

    # Hold the compilation-directory lock for the compile+import; always
    # release it, even on failure.
    get_lock()
    try:
        module = self.module_compile_str(
                module_name=mod.name,
                src_code=mod.code(),
                location=get_compiledir(),
                include_dirs=[],
                libs=self.libraries(),
                preargs=self.compile_args())
    finally:
        release_lock()
    return module
def cthunk_factory(self, error_storage, in_storage, out_storage): def cthunk_factory(self, error_storage, in_storage, out_storage):
"""WRITEME """WRITEME
...@@ -656,106 +826,43 @@ class CLinker(link.Linker): ...@@ -656,106 +826,43 @@ class CLinker(link.Linker):
outputs in out_storage and if an error occurs will put the outputs in out_storage and if an error occurs will put the
type, value and traceback of the exception in error_storage. type, value and traceback of the exception in error_storage.
""" """
try:
key = self.cmodule_key()
except KeyError:
key = None
if key is None:
module = self.compile_cmodule()
else:
module = get_module_cache().call(self.compile_cmodule, key=key)
get_lock() vars = self.inputs + self.outputs + self.orphans
# List of indices that should be ignored when passing the arguments
# (basically, everything that the previous call to uniq eliminated)
dupidx = [i for i, x in enumerate(vars) if vars.count(x) > 1 and vars.index(x) != i]
try: out_storage = [x for i, x in enumerate(out_storage) if (i+len(in_storage)) not in dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in dupidx]
orphd = [[orphan.data] for orphan in self.orphans]
# check if we already compiled this ret = module.instantiate(error_storage, *(in_storage + out_storage + orphd))
if not getattr(self, 'instantiate', False):
self.code_gen()
module_name = self.hash
# Eliminate duplicate inputs and outputs from the storage that we will pass to instantiate
out_storage = [x for i, x in enumerate(out_storage) if (i+len(in_storage)) not in self.dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in self.dupidx]
cthunk = object() # dummy so weave can get the type
mod = weave.ext_tools.ext_module(module_name)
argnames = ["i%i" % i for i in xrange(len(in_storage))] \
+ ["o%i" % i for i in xrange(len(out_storage))] \
+ ["orph%i" % i for i in xrange(len(self.orphans))]
# The code of instantiate
code = """
%(struct_name)s* struct_ptr = new %(struct_name)s();
struct_ptr->init(error_storage, %(args)s);
PyObject* thunk = PyCObject_FromVoidPtrAndDesc((void*)(&%(struct_name)s_executor), struct_ptr, %(struct_name)s_destructor);
return thunk;
// return_val = thunk; // oh my god weave why does this leak >:\
""" % dict(struct_name = self.struct_name,
args = ", ".join(argnames))
d = dict(error_storage = object())
for argname in argnames:
d[argname] = object()
instantiate = weave.ext_tools.ext_function('instantiate',
code,
['error_storage'] + argnames,
local_dict = d,
global_dict = {})
# Static methods that can run and destroy the struct built by instantiate.
static = """
int %(struct_name)s_executor(%(struct_name)s* self) {
return self->run();
}
void %(struct_name)s_destructor(void* executor, void* self) {
//printf("doing cleanup\\n");
//fflush(stdout);
((%(struct_name)s*)self)->cleanup();
free(self);
//printf("done cleanup\\n");
//fflush(stdout);
}
""" % dict(struct_name = self.struct_name)
# We add all the support code, compile args, headers and libs we need.
for support_code in self.support_code():
instantiate.customize.add_support_code(support_code)
instantiate.customize.add_support_code(self.struct_code)
instantiate.customize.add_support_code(static)
for arg in self.compile_args():
instantiate.customize.add_extra_compile_arg(arg)
for header in self.headers():
instantiate.customize.add_header(header)
for lib in self.libraries():
instantiate.customize.add_library(lib)
mod.add_function(instantiate)
mod.compile(location = get_compiledir())
module = __import__("%s" % (module_name), {}, {}, [module_name])
self.instantiate = module.instantiate
else:
# Eliminate duplicate inputs and outputs from the storage that we will pass to instantiate
out_storage = [x for i, x in enumerate(out_storage) if (i+len(in_storage)) not in self.dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in self.dupidx]
module_name = self.hash
module = __import__("%s" % (module_name), {}, {}, [module_name])
orphd = [[orphan.data] for orphan in self.orphans]
ret = module.instantiate(error_storage, *(in_storage + out_storage + orphd))
#win pdb add 3 ref count, so we disable it by default.
#assert sys.getrefcount(ret) == 2 # refcount leak check
except:
release_lock()
raise
release_lock()
return ret return ret
def instantiate_code(self, n_args):
    """Return the C source of the module's ``instantiate`` entry point.

    The generated function checks that exactly `n_args` arguments were
    passed, INCREFs each one, allocates the struct, calls its ``init``
    with the tuple items, and wraps the struct in a PyCObject whose
    executor/destructor are the static wrappers emitted elsewhere.

    :param n_args: number of PyObject* arguments instantiate expects
        (error_storage plus inputs/outputs/orphans).
    :returns: C source code as a string.
    """
    # Replaced py2-only `print >> stream` statements with list building;
    # the same lines are produced (C whitespace differences are harmless).
    struct_name = self.struct_name
    lines = []
    lines.append("static PyObject * instantiate(PyObject * self, PyObject *argtuple) {")
    lines.append('  assert(PyTuple_Check(argtuple));')
    lines.append('  if (%(n_args)i != PyTuple_Size(argtuple)){ ' % locals())
    lines.append('    PyErr_Format(PyExc_TypeError, "Wrong number of arguments, expected %(n_args)i, got %%i", (int)PyTuple_Size(argtuple));' % locals())
    lines.append('    return NULL;')
    lines.append('  }')
    lines.append('  %(struct_name)s* struct_ptr = new %(struct_name)s();' % locals())
    # INCREF every argument: the struct keeps borrowed pointers alive.
    lines.append('  ' + ''.join('Py_INCREF(PyTuple_GET_ITEM(argtuple, %i));' % n
                                for n in range(n_args)))
    lines.append('  struct_ptr->init(' +
                 ','.join('PyTuple_GET_ITEM(argtuple, %i)' % n for n in range(n_args)) +
                 ');')
    lines.append('  PyObject* thunk = PyCObject_FromVoidPtrAndDesc((void*)(&%(struct_name)s_executor), struct_ptr, %(struct_name)s_destructor);' % locals())
    lines.append("  return thunk; }")
    return '\n'.join(lines) + '\n'
def _execute(cthunk, init_tasks, tasks, error_storage): def _execute(cthunk, init_tasks, tasks, error_storage):
"""WRITEME""" """WRITEME"""
...@@ -828,7 +935,8 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -828,7 +935,8 @@ class OpWiseCLinker(link.LocalLinker):
def make_all(self, profiler = None, input_storage = None, output_storage = None): def make_all(self, profiler = None, input_storage = None, output_storage = None):
# Acquire lock on compilation directory. # Acquire lock on compilation directory, and
# hold it throughout the compilation of all internal nodes.
get_lock() get_lock()
try: try:
...@@ -849,7 +957,6 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -849,7 +957,6 @@ class OpWiseCLinker(link.LocalLinker):
node_output_storage = [storage_map[r] for r in node.outputs] node_output_storage = [storage_map[r] for r in node.outputs]
try: try:
e = Env(*graph.clone(node.inputs, node.outputs)) e = Env(*graph.clone(node.inputs, node.outputs))
e.toposort = lambda: e.nodes
if any(isinstance(input, graph.Value) for input in node.inputs): if any(isinstance(input, graph.Value) for input in node.inputs):
desc = None desc = None
......
"""Generate and compile C modules for Python
"""
import os, tempfile, StringIO, sys, logging, subprocess
_logger=logging.getLogger("theano.gof.cmodule")
def warning(*args):
    """Echo a WARNING line to stderr and forward it to the module logger."""
    msg = ' '.join(str(a) for a in args)
    sys.stderr.write('WARNING:' + msg + '\n')
    _logger.warning(msg)
def info(*args):
    """Echo an INFO line to stderr and forward it to the module logger."""
    msg = ' '.join(str(a) for a in args)
    sys.stderr.write('INFO:' + msg + '\n')
    _logger.info(msg)
def debug(*args):
    """Forward a DEBUG message to the module logger.

    The stderr echo used by the other helpers is deliberately disabled here.
    """
    #sys.stderr.write('DEBUG:'+ ' '.join(str(a) for a in args)+'\n')
    _logger.debug(' '.join(map(str, args)))
# Names of the CPython calling-convention flags, emitted verbatim into the
# generated module's PyMethodDef table.
METH_VARARGS="METH_VARARGS"
METH_NOARGS="METH_NOARGS"
class ExtFunction(object):
    """A C function to put into a DynamicModule """

    # Class-level defaults; instances overwrite all four in __init__.
    name = ""        # function's name
    code_block = ""  # entire C source: static PyObject* <name>(...) { ... }
    method = ""      # calling convention, e.g. 'METH_VARARGS' or 'METH_NOARGS'
    doc = ""         # documentation string for the function

    def __init__(self, name, code_block, method, doc="undocumented"):
        self.name = name
        self.code_block = code_block
        self.method = method
        self.doc = doc

    def method_decl(self):
        """Returns the signature for this function that goes into the DynamicModule's method table"""
        return '\t{{"{0}", {0}, {1}, "{2}"}}'.format(self.name, self.method, self.doc)
class DynamicModule(object):
    """Assembles the C source of a Python extension module: includes,
    support code, ExtFunction entries, and the module init boilerplate.

    Replaced the py2-only ``print >> stream`` statements with equivalent
    ``stream.write`` calls (identical output, still py2.6-compatible) and
    dropped the StringIO dependency in :meth:`code`.
    """

    def __init__(self, name):
        self.name = name
        self.support_code = []
        self.functions = []
        self.includes = ["<Python.h>", "<iostream>"]
        self.includes.append('<numpy/arrayobject.h>') #TODO: this should come from TensorType
        self.init_blocks = ['import_array();'] #TODO: from TensorType

    def print_methoddef(self, stream):
        """Write the PyMethodDef table for all registered functions."""
        stream.write("static PyMethodDef MyMethods[] = {\n")
        for f in self.functions:
            stream.write('%s ,\n' % f.method_decl())
        stream.write("\t{NULL, NULL, 0, NULL}\n")
        stream.write("};\n")

    def print_init(self, stream):
        """Write the PyMODINIT_FUNC init<name> function."""
        stream.write("PyMODINIT_FUNC init%s(void){\n" % self.name)
        for b in self.init_blocks:
            stream.write('  %s\n' % b)
        stream.write('  (void) Py_InitModule("%s", MyMethods);\n' % self.name)
        stream.write("}\n")

    def add_include(self, str):
        self.includes.append(str)

    def add_init_code(self, code):
        self.init_blocks.append(code)

    def add_support_code(self, code):
        if code not in self.support_code: #TODO: KLUDGE
            self.support_code.append(code)

    def add_function(self, fn):
        self.functions.append(fn)

    def code(self):
        """Return the complete C source of the module as a string."""
        # simple write-collecting sink; avoids the StringIO module
        parts = []
        sio = type('_Sink', (object,), {'write': lambda self, s: parts.append(s)})()
        for inc in self.includes:
            sio.write("#include %s\n" % inc)
        sio.write("//////////////////////\n")
        sio.write("////  Support Code\n")
        sio.write("//////////////////////\n")
        for sc in self.support_code:
            sio.write('%s\n' % sc)
        sio.write("//////////////////////\n")
        sio.write("////  Functions\n")
        sio.write("//////////////////////\n")
        for f in self.functions:
            sio.write('%s\n' % f.code_block)
        sio.write("//////////////////////\n")
        sio.write("////  Module init\n")
        sio.write("//////////////////////\n")
        self.print_methoddef(sio)
        self.print_init(sio)
        return ''.join(parts)

    def list_code(self, ofile=sys.stdout):
        """Print out the code with line numbers to `ofile` """
        for i, line in enumerate(self.code().split('\n')):
            ofile.write('%4i %s\n' % (i + 1, line))
        ofile.flush()

    #TODO: add_type
def gcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[],
        preargs=[], tmpdir=None):
    # Compile `src_code` with g++ into a shared library inside a fresh temp
    # directory under `location`, then import and return the resulting
    # extension module (None on compile/import failure).
    # NOTE(review): the [] default args are never mutated in place except via
    # the re-binding below, so the shared-mutable-default pitfall is avoided;
    # `tmpdir` is accepted but unused — confirm intent.
    preargs= [] if preargs is None else list(preargs)
    preargs.append('-fPIC')
    no_opt = False
    #TODO: where to find these strings? sys? distutils?
    # hard-coded Python 2.6 include/lib names — not portable across versions
    include_dirs = ['/usr/include/python2.6'] + include_dirs
    libs = ['python2.6'] + libs
    workdir = tempfile.mkdtemp(dir=location)
    cppfilename = os.path.join(workdir, 'mod.cpp')
    cppfile = file(cppfilename, 'w')
    debug('Writing module C++ code to', cppfilename)
    # NOTE(review): ofiles is never populated, so the cleanup loop in the
    # finally block is a no-op placeholder.
    ofiles = []
    rval = None
    try:
        cppfile.write(src_code)
        cppfile.close()
        lib_filename = os.path.join(workdir, '%s.so'% module_name)
        debug('Generating shared lib', lib_filename)
        cmd = ['g++', '-shared', '-g']
        if no_opt:
            # strip optimization flags when no_opt is requested
            cmd.extend(p for p in preargs if not p.startswith('-O'))
        else:
            cmd.extend(preargs)
        cmd.extend('-I%s'%idir for idir in include_dirs)
        cmd.extend(['-o',lib_filename])
        cmd.append(cppfilename)
        cmd.extend(['-L%s'%ldir for ldir in lib_dirs])
        cmd.extend(['-l%s'%l for l in libs])
        debug('Running cmd', ' '.join(cmd))
        p = subprocess.Popen(cmd)
        status = p.wait()
        if status:
            warning('g++ return status', status)
        else:
            #touch the __init__ file
            file(os.path.join(workdir, "__init__.py"),'w').close()
            #load the module
            # temporarily prepend workdir so __import__ can find the .so
            sys.path.insert(0, workdir)
            try:
                rval = __import__(module_name, {}, {}, [module_name])
                if not rval:
                    debug('__import__ failed')
            finally:
                del sys.path[0]
    finally:
        warning("TODO: cleanup")
        #os.remove(cppfilename)
        for ofile in ofiles:
            #os.remove(ofiles[0])
            pass
    return rval
def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[], lib_dirs=[], libs=[],
        preargs=[], tmpdir=None):
    # Like gcc_module_compile_str, but compiles with nvcc (CUDA) and links
    # against cudart; returns the imported module or None on failure.
    # NOTE(review): `tmpdir` is accepted but unused — confirm intent.
    preargs= [] if preargs is None else list(preargs)
    preargs.append('-fPIC')
    # NOTE(review): no_opt is defined but never read in this variant.
    no_opt = False
    #TODO: -O preargs should be passed globally, not to -Xcompiler
    #TODO: where to find these strings? sys? distutils?
    # hard-coded Python 2.6 and CUDA paths — not portable
    include_dirs = ['/usr/include/python2.6'] + include_dirs
    libs = ['python2.6', 'cudart'] + libs
    lib_dirs = ['/usr/local/cuda/lib']+lib_dirs
    workdir = tempfile.mkdtemp(dir=location)
    # the second assignment deliberately overrides the first: .cu selects
    # the nvopencc toolchain instead of g++
    cppfilename = os.path.join(workdir, 'mod.cpp') #.cpp to use g++
    cppfilename = os.path.join(workdir, 'mod.cu') #.cu to use nvopencc
    cppfile = file(cppfilename, 'w')
    debug('Writing module C++ code to', cppfilename)
    # NOTE(review): ofiles is never populated; cleanup loop below is a no-op.
    ofiles = []
    rval = None
    try:
        cppfile.write(src_code)
        cppfile.close()
        lib_filename = os.path.join(workdir, '%s.so'% module_name)
        debug('Generating shared lib', lib_filename)
        cmd = ['nvcc', '-shared', '-g']
        # host-compiler flags must be forwarded through -Xcompiler
        cmd.extend(['-Xcompiler', ','.join(preargs)])
        cmd.extend('-I%s'%idir for idir in include_dirs)
        cmd.extend(['-o',lib_filename])
        cmd.append(cppfilename)
        cmd.extend(['-L%s'%ldir for ldir in lib_dirs])
        cmd.extend(['-l%s'%l for l in libs])
        debug('Running cmd', ' '.join(cmd))
        p = subprocess.Popen(cmd)
        status = p.wait()
        if status:
            warning('nvcc return status', status)
        else:
            #touch the __init__ file
            file(os.path.join(workdir, "__init__.py"),'w').close()
            #load the module
            pathcopy = list(sys.path)
            # temporarily prepend workdir so __import__ can find the .so
            sys.path.insert(0, workdir)
            try:
                rval = __import__(module_name, {}, {}, [module_name])
                if not rval:
                    debug('__import__ failed')
            finally:
                del sys.path[0]
                # sanity check: sys.path restored exactly
                assert pathcopy == sys.path
    finally:
        warning("TODO: cleanup")
        #os.remove(cppfilename)
        for ofile in ofiles:
            #os.remove(ofiles[0])
            pass
    return rval
def icc_module_compile_str(*args):
    """Placeholder for Intel C compiler support; not implemented yet."""
    raise NotImplementedError()
...@@ -434,6 +434,10 @@ class Env(utils.object2): ...@@ -434,6 +434,10 @@ class Env(utils.object2):
{node: predecessors} where predecessors is a list of nodes {node: predecessors} where predecessors is a list of nodes
that should be computed before the key node. that should be computed before the key node.
""" """
if len(self.nodes) < 2:
# optimization
# when there are 0 or 1 nodes, no sorting is necessary
return list(self.nodes)
env = self env = self
ords = {} ords = {}
for feature in env._features: for feature in env._features:
......
...@@ -126,7 +126,7 @@ class SoftmaxWithBias(gof.Op): ...@@ -126,7 +126,7 @@ class SoftmaxWithBias(gof.Op):
return dx, db return dx, db
def c_headers(self): def c_headers(self):
return ['<iostream> <math>'] return ['<iostream>','<cmath>']
@staticmethod @staticmethod
def c_code_template(): def c_code_template():
...@@ -514,7 +514,8 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op): ...@@ -514,7 +514,8 @@ class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
db = tensor.sum(dx, axis = [0]) db = tensor.sum(dx, axis = [0])
return dx, db, None return dx, db, None
def c_headers(self): return ['<iostream>'] def c_headers(self):
return ['<iostream>', '<cmath>']
@staticmethod @staticmethod
def c_code_template(): def c_code_template():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论