提交 9adfd678 authored 作者: James Bergstra's avatar James Bergstra

merge

from .sharedvalue import shared, shared_constructor from theano.compile.sandbox.sharedvalue import shared, shared_constructor
from .pfunc import pfunc from theano.compile.sandbox.pfunc import pfunc
...@@ -792,14 +792,24 @@ class CLinker(link.Linker): ...@@ -792,14 +792,24 @@ class CLinker(link.Linker):
function raises a KeyError exception. function raises a KeyError exception.
""" """
order = list(self.env.toposort()) return self.cmodule_key_(self.env, self.no_recycling,
env_inputs_dict = dict((i, [-1, pos]) for pos, i in enumerate(self.env.inputs)) compile_args=self.compile_args(),
libraries=self.libraries()
)
@staticmethod
def cmodule_key_(env, no_recycling, compile_args=None, libraries=None):
"""
Do the actual computation of cmodule_key in a static method
to allow it to be reused in scalar.Composite.__eq__
"""
order = list(env.toposort())
env_computed_set = set() env_computed_set = set()
env_inputs_dict = dict((i, [-1, pos]) for pos, i in enumerate(env.inputs))
constant_ids = dict() constant_ids = dict()
op_pos = {} # Apply -> topological position op_pos = {} # Apply -> topological position
rval = ['CLinker.cmodule_key'] # will be cast to tuple on return rval = ['CLinker.cmodule_key'] # will be cast to tuple on return
rval.append(tuple(self.compile_args())) if compile_args is not None: rval.append(tuple(compile_args))
rval.append(tuple(self.libraries())) if libraries is not None: rval.append(tuple(libraries))
version = [] version = []
# assert that every input to every node is one of' # assert that every input to every node is one of'
...@@ -822,16 +832,16 @@ class CLinker(link.Linker): ...@@ -822,16 +832,16 @@ class CLinker(link.Linker):
else: else:
if i.owner is None: if i.owner is None:
assert all( all(out is not None for out in o.outputs) for o in order) assert all( all(out is not None for out in o.outputs) for o in order)
assert all( input.owner is None for input in self.env.inputs) assert all( input.owner is None for input in env.inputs)
raise Exception('what is this?', (i, type(i), i.clients, self.env)) raise Exception('what is this?', (i, type(i), i.clients, env))
if i in self.env.outputs: if i in env.outputs:
rval += [op_pos[i.owner], # outputs rval += [op_pos[i.owner], # outputs
i.owner.outputs.index(i), i.owner.outputs.index(i),
self.env.outputs.index(i)] env.outputs.index(i)]
else: else:
rval += [op_pos[i.owner], i.owner.outputs.index(i)] # temps rval += [op_pos[i.owner], i.owner.outputs.index(i)] # temps
assert rval assert rval
rval.append(i in self.no_recycling) rval.append(i in no_recycling)
return tuple(rval) return tuple(rval)
for node_pos, node in enumerate(order): for node_pos, node in enumerate(order):
......
...@@ -386,7 +386,7 @@ class ModuleCache(object): ...@@ -386,7 +386,7 @@ class ModuleCache(object):
try: try:
module = fn(location=location) # WILL FAIL FOR BAD C CODE module = fn(location=location) # WILL FAIL FOR BAD C CODE
except Exception, e: except Exception, e:
shutil.rmtree(location) _rmtree(location)
#try: #try:
#except Exception, ee: #except Exception, ee:
#error('failed to cleanup location', location, ee) #error('failed to cleanup location', location, ee)
...@@ -515,7 +515,8 @@ class ModuleCache(object): ...@@ -515,7 +515,8 @@ class ModuleCache(object):
def _rmtree(parent): def _rmtree(parent):
try: try:
shutil.rmtree(parent) if not os.getenv('THEANO_NOCLEANUP',0):
shutil.rmtree(parent)
except Exception, e: except Exception, e:
try: try:
# mark this directory for deletion by a future refresh() # mark this directory for deletion by a future refresh()
......
from .. import gof from theano import gof
import sys import sys
......
...@@ -348,6 +348,9 @@ def int_out(*types): ...@@ -348,6 +348,9 @@ def int_out(*types):
def float_out(*types): def float_out(*types):
return float64, return float64,
def upgrade_to_float(*types): def upgrade_to_float(*types):
"""
This upgrade the types to float32 or float64 to don't loose any precision.
"""
conv = {int8: float32, conv = {int8: float32,
int16: float32, int16: float32,
int32: float64, int32: float64,
...@@ -370,8 +373,8 @@ class ScalarOp(Op): ...@@ -370,8 +373,8 @@ class ScalarOp(Op):
def make_node(self, *inputs): def make_node(self, *inputs):
if self.nin >= 0: if self.nin >= 0:
if len(inputs) != self.nin: if len(inputs) != self.nin:
raise TypeError("Wrong number of inputs for %s.make_node (got %i, expected %i)" \ raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" \
% (self, len(inputs), self.nin)) % (self, len(inputs), str(inputs), self.nin))
inputs = [as_scalar(input) for input in inputs] inputs = [as_scalar(input) for input in inputs]
outputs = [t() for t in self.output_types([input.type for input in inputs])] outputs = [t() for t in self.output_types([input.type for input in inputs])]
if len(outputs) != self.nout: if len(outputs) != self.nout:
...@@ -977,6 +980,7 @@ class Inv(UnaryScalarOp): ...@@ -977,6 +980,7 @@ class Inv(UnaryScalarOp):
inv = Inv(upgrade_to_float, name = 'inv') inv = Inv(upgrade_to_float, name = 'inv')
class Log(UnaryScalarOp): class Log(UnaryScalarOp):
""" log base e """
def impl(self, x): def impl(self, x):
return math.log(x) return math.log(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -994,6 +998,7 @@ class Log(UnaryScalarOp): ...@@ -994,6 +998,7 @@ class Log(UnaryScalarOp):
log = Log(upgrade_to_float, name = 'log') log = Log(upgrade_to_float, name = 'log')
class Log2(UnaryScalarOp): class Log2(UnaryScalarOp):
""" log base 2 """
def impl(self, x): def impl(self, x):
return numpy.log2(x) return numpy.log2(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -1009,6 +1014,7 @@ class Log2(UnaryScalarOp): ...@@ -1009,6 +1014,7 @@ class Log2(UnaryScalarOp):
log2 = Log2(upgrade_to_float, name = 'log2') log2 = Log2(upgrade_to_float, name = 'log2')
class Log10(UnaryScalarOp): class Log10(UnaryScalarOp):
""" log base 10 """
def impl(self, x): def impl(self, x):
return numpy.log10(x) return numpy.log10(x)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -1170,6 +1176,14 @@ class Composite(ScalarOp): ...@@ -1170,6 +1176,14 @@ class Composite(ScalarOp):
implement the loop fusion optimizer (which I have yet to do implement the loop fusion optimizer (which I have yet to do
someday...) someday...)
""" """
def __str__(self):
if hasattr(self, 'name') and self.name:
return self.name
else:
return "%s{%s}" % (self.__class__.__name__, ", ".join(
"%s=%s" % (k, v) for k, v in self.__dict__.items()
if k not in ["name","env","_c_code"] ))
def __init__(self, inputs, outputs): def __init__(self, inputs, outputs):
env = Env(*gof.graph.clone(inputs, outputs)) env = Env(*gof.graph.clone(inputs, outputs))
gof.MergeOptimizer().optimize(env) gof.MergeOptimizer().optimize(env)
...@@ -1233,12 +1247,15 @@ class Composite(ScalarOp): ...@@ -1233,12 +1247,15 @@ class Composite(ScalarOp):
self.nin = len(inputs) self.nin = len(inputs)
self.nout = len(outputs) self.nout = len(outputs)
self.env = env self.env = env
self.inputs_type = tuple([input.type for input in self.env.inputs])
self.outputs_type = tuple([output.type for output in self.env.outputs])
self._rehash()
def output_types(self, input_types): def output_types(self, input_types):
if tuple(input_types) != tuple([input.type for input in self.env.inputs]): if tuple(input_types) != self.inputs_type:
raise TypeError("Wrong types for Composite. Expected %s, got %s." raise TypeError("Wrong types for Composite. Expected %s, got %s."
% (tuple([input.type for input in self.env.inputs]), tuple(input_types))) % (self.inputs_type, tuple(input_types)))
return [output.type for output in self.env.outputs] return self.outputs_type
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
for storage, impl in zip(output_storage, self._impls): for storage, impl in zip(output_storage, self._impls):
...@@ -1259,10 +1276,36 @@ class Composite(ScalarOp): ...@@ -1259,10 +1276,36 @@ class Composite(ScalarOp):
onames), onames),
**sub) **sub)
d['name'] = name d['name'] = name
if not sub.has_key('id'):
#The use of a dummy id is safe as the code is in a separate block.
#It won't generate conflicting variable name.
d['id']='_DUMMY_ID_'
return self._c_code % d return self._c_code % d
def __eq__(self, other): def __eq__(self, other):
return self is other if self is other: return True
if not isinstance(other, self.__class__): return False
if self.nin!=other.nin or self.nout != other.nout: return False
return self._hashval == other._hashval
return self._cmodule_key == other._cmodule_key
def _rehash(self):
#TODO: What no_recycling is used for? What I need to put their?
# no_recycling = []
self._cmodule_key = gof.CLinker.cmodule_key_(self.env, [])
self._hashval = hash(self._cmodule_key)
def __hash__(self): def __hash__(self):
return id(self) return self._hashval
# def __getstate__(self):
# d = copy(self.__dict__)
# d.pop('env')
# d.pop('_impls')
# #TODO: the self._impls must be restored to allow the perform to work.(c version continue to work.
# return d
# def __setstate__(self, d):
# self.__dict__.update(d)
# #TODO: how to restore the _impls?
...@@ -1227,6 +1227,68 @@ register_canonicalize(local_transposed_dot, name='local_transposed_dot') ...@@ -1227,6 +1227,68 @@ register_canonicalize(local_transposed_dot, name='local_transposed_dot')
# # Loop fusion # # # Loop fusion #
# ############### # ###############
@gof.local_optimizer([T.Elemwise, T.Elemwise])
def local_elemwise_fusion(node):
"""
As part of specialisation, we fusion two consecutif elemwise op of the same shape.
For mixed dtype, we let the Compise op do the cast. It let the C compile do the cast.
The number of dimension is validated at call time by theano itself.
TODO:The broadcast flag?
"""
# TODO:implement Composite.__eq__ by using CLinker.cmodule_key() to compare the graph.
#TODO: Merge when nb_clients>1? When this optimisation could introduce duplication of computation? When this will be faster?
if not isinstance(node.op, T.Elemwise):
return False
nb_elemwise=0
inputs=[]#inputs of the new Elemwise op.
s_inputs = []#inputs of the new scalar op.
s_g=[]#graph of scalar, what will by done in the inner loop.
for i in node.inputs:
if i.owner and isinstance(i.owner.op,T.Elemwise) and len(i.clients)<=1:
if len(i.clients)>1:
#should we put this in the first if, then we would go to the elif to don't fuse it?
#if one of the inputs have more then 1 clients and it is an intermediate result. We don't fuse.
print "local_elemwise_fusion: Elemwise inputs have more then 1 client. Don't optimise for now"
return False
nb_elemwise+=1
inputs.extend(i.owner.inputs)
s_input = [scalar.Scalar(x.dtype).make_variable() for x in i.owner.inputs]
s_inputs.extend(s_input)
s_op=i.owner.op.scalar_op(*s_input)
s_g.append(s_op)
else:
if i.owner and isinstance(i.owner.op,T.Elemwise) and len(i.clients)>1:
#should we put this in the first if, then we would go to the elif to don't fuse it?
print "local_elemwise_fusion: inputs have more then 1 client. Don't fuse it for now.!"
return False
inputs.append(i)
s=scalar.Scalar(i.dtype).make_variable()
s_inputs.append(s)
s_g.append(s)
#if no inputs have are an elemwise, their is nothing to fuse.
if nb_elemwise==0:
# print "local_elemwise_fusion: no elemwise in inputs. Nothing to fuse."
return False
otype = node.outputs[0].type
s_new_out=node.op.scalar_op(*s_g)
#create the composite op.
C = scalar.Composite(s_inputs,[s_new_out])
#create the new node.
n=T.Elemwise(C).make_node(*inputs)
assert len(n.outputs)==1
assert node.outputs[0].dtype==n.outputs[0].dtype
# print "local_elemwise_fusion: FUSED",nb_elemwise+1,"elemwise!"
return n.outputs
#register_specialize(local_elemwise_fusion)
# def make_composite(inputs, outputs): # def make_composite(inputs, outputs):
# scalar_inputs = [scalar.Scalar(dtype = i.type.dtype)() for i in inputs] # scalar_inputs = [scalar.Scalar(dtype = i.type.dtype)() for i in inputs]
# def transform(r): # def transform(r):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论