提交 f707ea95 authored 作者: lamblin's avatar lamblin

Merge pull request #427 from nouiz/fix_reduce

Fix reduce
...@@ -1026,6 +1026,9 @@ class CAReduce(Op): ...@@ -1026,6 +1026,9 @@ class CAReduce(Op):
self.axis.sort() self.axis.sort()
self.axis = tuple(self.axis) self.axis = tuple(self.axis)
self.set_ufunc(scalar_op)
def set_ufunc(self, scalar_op):
# This is probably a speed up of the implementation # This is probably a speed up of the implementation
if isinstance(scalar_op, theano.scalar.basic.Add): if isinstance(scalar_op, theano.scalar.basic.Add):
self.ufunc = numpy.add self.ufunc = numpy.add
...@@ -1078,12 +1081,12 @@ class CAReduce(Op): ...@@ -1078,12 +1081,12 @@ class CAReduce(Op):
def __getstate__(self): def __getstate__(self):
d = copy(self.__dict__) d = copy(self.__dict__)
d.pop('ufunc') d.pop('ufunc', None)
return d return d
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
self.ufunc = numpy.frompyfunc(self.scalar_op.impl, 2, 1) self.set_ufunc(self.scalar_op)
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.scalar_op == other.scalar_op and self.axis == other.axis return type(self) == type(other) and self.scalar_op == other.scalar_op and self.axis == other.axis
...@@ -1357,8 +1360,25 @@ class CAReduceDtype(CAReduce): ...@@ -1357,8 +1360,25 @@ class CAReduceDtype(CAReduce):
def __hash__(self): def __hash__(self):
return CAReduce.__hash__(self) ^ hash(self.dtype) return CAReduce.__hash__(self) ^ hash(self.dtype)
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, "dtype"):
# This is needed as old pickled will crash otherwise.
# We need to keep the old dtype behavior as the op
# could be in an apply node with a specified dtype.
self.dtype = "OLD"
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
dtype = self.dtype dtype = self.dtype
if dtype == "OLD":
return dict(
int8='int32',
int16='int32',
int32='int64',
uint8='uint32',
uint16='uint32',
uint32='uint64',
).get(idtype, idtype)
if dtype is None: if dtype is None:
# If input has a discrete dtype, upcast it to 64 # If input has a discrete dtype, upcast it to 64
return dict( return dict(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论