提交 b64abd85 authored 作者: nouiz's avatar nouiz

Merge pull request #956 from lamblin/fix_dtype_32bit

Make sure the output of perform has the exact right dtype
...@@ -2298,7 +2298,8 @@ class MaxAndArgmax(Op): ...@@ -2298,7 +2298,8 @@ class MaxAndArgmax(Op):
max, max_idx = outs max, max_idx = outs
if python_all(axis == range(x.ndim)): if python_all(axis == range(x.ndim)):
axis = None axis = None
max[0] = numpy.asarray(numpy.max(x, axis)) max[0] = theano._asarray(numpy.max(x, axis),
dtype=node.outputs[0].dtype)
max_idx[0] = theano._asarray(numpy.argmax(x, axis), dtype='int64') max_idx[0] = theano._asarray(numpy.argmax(x, axis), dtype='int64')
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
......
...@@ -151,11 +151,14 @@ class BinCountOp(theano.Op): ...@@ -151,11 +151,14 @@ class BinCountOp(theano.Op):
if weights is not None and weights.shape != x.shape: if weights is not None and weights.shape != x.shape:
raise TypeError("All inputs must have the same shape.") raise TypeError("All inputs must have the same shape.")
#Needed for numpy 1.4.1 compatibility #Needed for numpy 1.4.1 compatibility
if self.minlength: if self.minlength:
z[0] = np.bincount(x, weights=weights, minlength=self.minlength) out = np.bincount(x, weights=weights, minlength=self.minlength)
else: else:
z[0] = np.bincount(x, weights=weights) out = np.bincount(x, weights=weights)
z[0] = theano._asarray(out, dtype=node.outputs[0].dtype)
def grad(self, inputs, outputs_gradients): def grad(self, inputs, outputs_gradients):
output = self(*inputs) output = self(*inputs)
......
...@@ -128,7 +128,9 @@ class ArgSortOp(theano.Op): ...@@ -128,7 +128,9 @@ class ArgSortOp(theano.Op):
a = inputs[0] a = inputs[0]
axis = inputs[1] axis = inputs[1]
z = output_storage[0] z = output_storage[0]
z[0] = np.argsort(a, axis, self.kind, self.order) z[0] = theano._asarray(
np.argsort(a, axis, self.kind, self.order),
dtype=node.outputs[0].dtype)
def infer_shape(self, node, inputs_shapes): def infer_shape(self, node, inputs_shapes):
if (isinstance(node.inputs[1], theano.Constant) and if (isinstance(node.inputs[1], theano.Constant) and
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论