Commit c6398151, authored by lamblin

Merge pull request #1416 from nouiz/reduce_dtype_noaxis

[BUGFIX] Reduce dtype noaxis
...@@ -301,6 +301,12 @@ class Scalar(Type): ...@@ -301,6 +301,12 @@ class Scalar(Type):
ret.imag = -this->imag; ret.imag = -this->imag;
return ret; return ret;
} }
// Complex-vs-complex equality: true iff both the real and the imaginary
// parts compare equal under exact floating-point ==.
bool operator ==(const complex_type &y) const {
return (this->real == y.real) && (this->imag == y.imag);
}
// Complex-vs-real equality: a complex value equals a plain float of the
// same precision only when its imaginary part is exactly zero.
// NOTE: %(nbits)s is a Python %-format placeholder — this C++ snippet is a
// code template that is substituted (e.g. to npy_float64) before compilation.
bool operator ==(const npy_float%(nbits)s &y) const {
return (this->real == y) && (this->imag == 0);
}
complex_type operator -(const complex_type &y) const { complex_type operator -(const complex_type &y) const {
complex_type ret; complex_type ret;
ret.real = this->real - y.real; ret.real = this->real - y.real;
......
...@@ -1463,8 +1463,13 @@ class CAReduce(Op): ...@@ -1463,8 +1463,13 @@ class CAReduce(Op):
axis = range(len(input.type.broadcastable)) axis = range(len(input.type.broadcastable))
if len(axis) == 0: if len(axis) == 0:
op = Elemwise(scalar.identity) # The acc_dtype is never a downcast compared to the input dtype
return op._c_all(op.make_node(input), name, inames, onames, sub) # So we just need a cast to the output dtype.
var = theano.tensor.cast(input, node.outputs[0].dtype)
if var is input:
var = Elemwise(scalar.identity)(input)
assert var.dtype == node.outputs[0].dtype
return var.owner.op._c_all(var.owner, name, inames, onames, sub)
order1 = [i for i in xrange(input.type.ndim) if i not in axis] order1 = [i for i in xrange(input.type.ndim) if i not in axis]
order = order1 + list(axis) order = order1 + list(axis)
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment