提交 a489cb3e authored 作者: James Bergstra's avatar James Bergstra

added c implementation of DimShuffle, some small optimizations to the C opwise linker

上级 5cb8526d
aa.x : aa.cc aa.x : aa.cc
g++ -O3 -ffast-math -ftree-vectorize aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl -lmkl g++ -O3 -ffast-math aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl ${THEANO_BLAS_LDFLAGS}
#g++ aa.cc -o aa.x -L${PUB_PREFIX}/lib -lgsl -lmkl
clean : clean :
rm aa.x rm aa.x
...@@ -213,6 +213,7 @@ def local_sub_to_gemm(node): ...@@ -213,6 +213,7 @@ def local_sub_to_gemm(node):
#TODO: we actually want to get any scalar here, not necessarily a constant #TODO: we actually want to get any scalar here, not necessarily a constant
mulleft_const = opt.local_mul_canonizer.get_constant(mulleft) mulleft_const = opt.local_mul_canonizer.get_constant(mulleft)
if mulleft_const is not None: if mulleft_const is not None:
assert mulleft_const.size() == 1
mulleft_const = mulleft_const.flatten()[0] mulleft_const = mulleft_const.flatten()[0]
#subleft - (mulleft_const * ?) #subleft - (mulleft_const * ?)
if mulright.owner and (mulright.owner.op == T.add): if mulright.owner and (mulright.owner.op == T.add):
...@@ -422,8 +423,10 @@ class M(module.Module): ...@@ -422,8 +423,10 @@ class M(module.Module):
self.step = module.Method([x], err, updates=dict(updates)) self.step = module.Method([x], err, updates=dict(updates))
mod = M() mod = M()
#mode = 'FAST_RUN' mode = 'FAST_RUN'
mode = ProfileMode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker()) #mode = ProfileMode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker())
mode = Mode(optimizer='fast_run', linker=theano.gof.OpWiseCLinker(nice_errors=True))
mode = Mode(optimizer='fast_run', linker='c')
print mod.pretty(mode=mode) print mod.pretty(mode=mode)
m = mod.make(mode=mode) m = mod.make(mode=mode)
...@@ -443,6 +446,6 @@ try: ...@@ -443,6 +446,6 @@ try:
mode.print_summary() mode.print_summary()
pass pass
except: except:
raise pass
...@@ -686,14 +686,15 @@ class CLinker(link.Linker): ...@@ -686,14 +686,15 @@ class CLinker(link.Linker):
instantiate.customize.add_support_code(support_code) instantiate.customize.add_support_code(support_code)
instantiate.customize.add_support_code(self.struct_code) instantiate.customize.add_support_code(self.struct_code)
instantiate.customize.add_support_code(static) instantiate.customize.add_support_code(static)
for extra_arg in ("-w", #-w means supress all warnings for extra_arg in (
): "-O2",
#"-O3", "-ffast-math",
#"-ffast-math",
#"-fprefetch-loop-arrays", #"-fprefetch-loop-arrays",
#"-ftree-vect-loop-version", #"-ftree-vect-loop-version",
#"-ftree-loop-optimize", #"-ftree-loop-optimize",
#"-ftree-vectorize"): #"-ftree-vectorize"):
"-w" #-w means supress all warnings
):
instantiate.customize.add_extra_compile_arg(extra_arg) instantiate.customize.add_extra_compile_arg(extra_arg)
for arg in self.compile_args(): for arg in self.compile_args():
instantiate.customize.add_extra_compile_arg(arg) instantiate.customize.add_extra_compile_arg(arg)
...@@ -736,7 +737,6 @@ def _execute(cthunk, init_tasks, tasks, error_storage): ...@@ -736,7 +737,6 @@ def _execute(cthunk, init_tasks, tasks, error_storage):
else: else:
return tasks[failure_code - n] return tasks[failure_code - n]
def execute(): def execute():
execute.cthunk = cthunk
failure = cutils.run_cthunk(cthunk) failure = cutils.run_cthunk(cthunk)
if failure: if failure:
task, taskname, id = find_task(failure) task, taskname, id = find_task(failure)
...@@ -748,6 +748,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage): ...@@ -748,6 +748,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage):
exc_value = exc_type(_exc_value, task) exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared
raise exc_type, exc_value, exc_trace raise exc_type, exc_value, exc_trace
execute.cthunk = cthunk
return execute return execute
...@@ -770,9 +771,12 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -770,9 +771,12 @@ class OpWiseCLinker(link.LocalLinker):
__cache__ = {} __cache__ = {}
def __init__(self, fallback_on_perform = True): def __init__(self,
fallback_on_perform = True,
nice_errors = True):
self.env = None self.env = None
self.fallback_on_perform = fallback_on_perform self.fallback_on_perform = fallback_on_perform
self.nice_errors = nice_errors
def accept(self, env, no_recycling = []): def accept(self, env, no_recycling = []):
if self.env is not None and self.env is not env: if self.env is not None and self.env is not env:
...@@ -842,7 +846,9 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -842,7 +846,9 @@ class OpWiseCLinker(link.LocalLinker):
else: else:
no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs] no_recycling = [storage_map[r] for r in no_recycling if r not in env.inputs]
f = link.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler) f = link.streamline(env, thunks, order,
no_recycling = no_recycling,
nice_errors = self.nice_errors)
return f, [link.Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \ return f, [link.Container(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[link.Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \ [link.Container(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
...@@ -850,7 +856,6 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -850,7 +856,6 @@ class OpWiseCLinker(link.LocalLinker):
def _default_checker(x, y): def _default_checker(x, y):
"""WRITEME """WRITEME
Default checker for DualLinker. This checks that the Default checker for DualLinker. This checks that the
......
...@@ -5,6 +5,7 @@ from type import Type ...@@ -5,6 +5,7 @@ from type import Type
import sys, traceback import sys, traceback
from copy import copy from copy import copy
from cutils import run_cthunk
__excepthook = sys.excepthook __excepthook = sys.excepthook
...@@ -225,9 +226,27 @@ def clear_storage_thunk(stg): ...@@ -225,9 +226,27 @@ def clear_storage_thunk(stg):
thunk.inputs = [stg] thunk.inputs = [stg]
return thunk return thunk
def streamline(env, thunks, order, no_recycling = [], profiler = None): def streamline(env, thunks, order, no_recycling = [], profiler = None, nice_errors = True):
"""WRITEME""" """WRITEME
if profiler is None:
:param env:
:param thunks: the list of program instructions
:param order: the list of apply instances that gave rise to the thunks (same order as thunks)
:param no_recycling: storage elements that cannot be 'recycled' by repeatedly executing the
program. These storage elements are cleared before re-running.
:param profiler: deprecated
:param nice_errors: run in such a way that the double-traceback is printed. This costs a
bit of performance in the inner python loop.
"""
if profiler is not None:
raise NotImplementedError()
if nice_errors:
def f(): def f():
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
...@@ -237,14 +256,13 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None): ...@@ -237,14 +256,13 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None):
except: except:
raise_with_op(node) raise_with_op(node)
else: else:
# don't worry about raise_with_op, just go a little faster.
#there is a mix of python and c thunks
def f(): def f():
for x in no_recycling: for x in no_recycling:
x[0] = None x[0] = None
def g(): for thunk in thunks:
for thunk, node in zip(thunks, order): thunk()
profiler.profile_node(thunk, node)
profiler.profile_env(g, env)
f.profiler = profiler
return f return f
class LocalLinker(Linker): class LocalLinker(Linker):
......
...@@ -171,13 +171,15 @@ class ProfileMode(Mode): ...@@ -171,13 +171,15 @@ class ProfileMode(Mode):
%(max(0, len(atimes)-15), sum(t for t, a in atimes[15:])) %(max(0, len(atimes)-15), sum(t for t, a in atimes[15:]))
n_ops_to_print = 20
print 'Op-wise summary: <fraction of local_time spent on this kind of Op> <Op name>' print 'Op-wise summary: <fraction of local_time spent on this kind of Op> <Op name>'
otimes = [(t/local_time, a, self.op_cimpl[a]) for a, t in op_time.items()] otimes = [(t/local_time, a, self.op_cimpl[a]) for a, t in op_time.items()]
otimes.sort() otimes.sort()
otimes.reverse() otimes.reverse()
for t,a,ci in otimes[:15]: for t,a,ci in otimes[:n_ops_to_print]:
print '\t%.3f\t%s %s' % (t, '*' if ci else ' ', a) print '\t%.3f\t%s %s' % (t, '*' if ci else ' ', a)
print ' ... (remaining %i Ops account for %.2f of the runtime)'\ print ' ... (remaining %i Ops account for %.2f of the runtime)'\
%(max(0, len(otimes)-15), sum(t for t, a, ci in otimes[15:])) %(max(0, len(otimes)-n_ops_to_print), sum(t for t, a, ci in
otimes[n_ops_to_print:]))
print '(*) Op is running a c implementation' print '(*) Op is running a c implementation'
...@@ -103,16 +103,18 @@ class DimShuffle(Op): ...@@ -103,16 +103,18 @@ class DimShuffle(Op):
for i, b in enumerate(input_broadcastable): for i, b in enumerate(input_broadcastable):
if i not in new_order: if i not in new_order:
# we want to drop this dimension because it's not a value in new_order # we want to drop this dimension because it's not a value in new_order
if b == 1: if b == 1: # 1 aka True
self.drop.append(i) self.drop.append(i)
else: else:
# we cannot drop non-broadcastable dimensions # we cannot drop non-broadcastable dimensions
raise NotImplementedError("You cannot drop a non-broadcastable dimension.") raise ValueError("You cannot drop a non-broadcastable dimension.")
else: else:
i2j[i] = j i2j[i] = j
j += 1 j += 1
# transposition of non-broadcastable dimensions # transposition of non-broadcastable dimensions
# This is how the dimensions will be permuted, without accounting for the extra
# 'x' broadcastable dimensions to insert.
self.shuffle = [i2j[x] for x in new_order if x != 'x'] self.shuffle = [i2j[x] for x in new_order if x != 'x']
# list of dimensions of the output that are broadcastable and were not in the original input # list of dimensions of the output that are broadcastable and were not in the original input
...@@ -144,7 +146,8 @@ class DimShuffle(Op): ...@@ -144,7 +146,8 @@ class DimShuffle(Op):
and self.input_broadcastable == other.input_broadcastable and self.input_broadcastable == other.input_broadcastable
def __hash__(self): def __hash__(self):
return hash(self.inplace) ^ hash(self.new_order) ^ hash(self.input_broadcastable) return hash(type(self)) ^ hash(self.inplace) \
^ hash(self.new_order) ^ hash(self.input_broadcastable)
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
...@@ -175,6 +178,73 @@ class DimShuffle(Op): ...@@ -175,6 +178,73 @@ class DimShuffle(Op):
storage[0] = res storage[0] = res
def c_code(self, node, name, inp, out, sub):
    """Return the C source that implements this DimShuffle.

    The generated code:

    1. fails with ``NotImplementedError`` if the input's ``nd`` differs from
       ``len(self.input_broadcastable)``;
    2. releases whatever the output variable currently holds;
    3. fills ``dimensions`` and ``strides`` following ``self.new_order``
       (an ``'x'`` entry yields a dimension of 1 with stride 0, any other
       entry copies that axis of the input);
    4. picks a ``base`` array -- the input itself (with an extra reference)
       when ``self.inplace`` is true, otherwise a copy obtained through
       ``PyArray_FromAny(..., NPY_ALIGNED|NPY_ENSURECOPY, ...)``;
    5. builds the output with ``PyArray_New`` on top of ``base->data`` and
       stores ``base`` in the output's ``base`` field.

    :param node: not used by this method
    :param name: not used by this method
    :param inp: one-element sequence holding the C name of the input variable
    :param out: one-element sequence holding the C name of the output variable
    :param sub: substitution dictionary; must provide ``'fail'``.  Its entries
        take precedence over the ``input`` / ``res`` keys set here.
    :returns: the C code as a string
    """
    # Equivalent to the former tuple parameters ``(input,), (res,)``: callers
    # still pass one-element sequences positionally.
    input_name, = inp
    res_name, = out

    nd_in = len(self.input_broadcastable)
    nd_out = len(self.new_order)

    check_input_nd = [
        'if (%(input)s->nd != ' + str(nd_in) + ')'
        '{PyErr_SetString(PyExc_NotImplementedError, "input nd"); %(fail)s;}']

    clear_output = ['if (%(res)s) {Py_XDECREF(%(res)s);}']

    # One declaration followed by one assignment per output axis.
    shape_statements = ['npy_intp dimensions[%i]' % nd_out]
    strides_statements = ['npy_intp strides[%i]' % nd_out]
    for i, o in enumerate(self.new_order):
        if o != 'x':
            shape_statements.append(
                'dimensions[' + str(i) + '] = %(input)s->dimensions[' + str(o) + ']')
            strides_statements.append(
                'strides[' + str(i) + '] = %(input)s->strides[' + str(o) + ']')
        else:
            # inserted broadcastable axis: length 1, never advances in memory
            shape_statements.append('dimensions[' + str(i) + '] = 1')
            strides_statements.append('strides[' + str(i) + '] = 0')

    # The opening brace scopes `base`; it is closed at the end of alloc_output.
    if self.inplace:
        get_base = ['{ PyArrayObject * base = %(input)s',
                    'Py_INCREF((PyObject*)base)']
    else:
        get_base = [('{ PyArrayObject * base = (PyArrayObject*)PyArray_FromAny((PyObject*)%(input)s, NULL,'
                     '0, 0, NPY_ALIGNED|NPY_ENSURECOPY, NULL)')]

    # NOTE(review): the generated code does not test `base` or the result of
    # PyArray_New for NULL -- confirm whether a failure path is needed here.
    alloc_output = [('%(res)s = (PyArrayObject*)PyArray_New(&PyArray_Type, '
                     '' + str(nd_out) + ', dimensions, '
                     'PyArray_TYPE(base), strides, '
                     'base->data, base->descr->elsize, '
                     'PyArray_FLAGS(base), NULL)'),
                    '%(res)s->base = (PyObject*)base',
                    '}']

    pieces = (check_input_nd
              + clear_output
              + shape_statements
              + strides_statements
              + get_base
              + alloc_output)
    full_code = ';\n'.join(pieces) + ';'

    # Explicit substitution table; `sub` wins on key clashes, as before.
    substitutions = {'input': input_name, 'res': res_name}
    substitutions.update(sub)
    return full_code % substitutions
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
gz = as_tensor(gz) gz = as_tensor(gz)
grad_order = ['x'] * len(x.type.broadcastable) grad_order = ['x'] * len(x.type.broadcastable)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论