提交 1f5892b3 authored 作者: Frederic's avatar Frederic

Fix BinCount output dtype. It is intp of double depending of the input.

上级 6e9e6d6b
...@@ -141,6 +141,11 @@ class BinCountOp(theano.Op): ...@@ -141,6 +141,11 @@ class BinCountOp(theano.Op):
numpy_unsupported_dtypes = ('uint64',) numpy_unsupported_dtypes = ('uint64',)
if int_bitwidth == 32: if int_bitwidth == 32:
numpy_unsupported_dtypes = ('uint32', 'int64', 'uint64') numpy_unsupported_dtypes = ('uint32', 'int64', 'uint64')
intp_bitwidth = theano.gof.cmodule.local_bitwidth()
if intp_bitwidth == 32:
out_type = basic.ivector()
elif intp_bitwidth == 64:
out_type = basic.lvector()
if x.dtype in numpy_unsupported_dtypes: if x.dtype in numpy_unsupported_dtypes:
raise TypeError( raise TypeError(
...@@ -152,10 +157,9 @@ class BinCountOp(theano.Op): ...@@ -152,10 +157,9 @@ class BinCountOp(theano.Op):
if weights is None: if weights is None:
weights = theano.gof.Constant(theano.gof.Generic(), None) weights = theano.gof.Constant(theano.gof.Generic(), None)
out_type = x.type()
else: else:
weights = basic.as_tensor_variable(weights) weights = basic.as_tensor_variable(weights)
out_type = weights.type() out_type = basic.dvector()
if weights.ndim != 1: if weights.ndim != 1:
raise TypeError("Weights cannot have a number of" raise TypeError("Weights cannot have a number of"
"dimension different of 1.") "dimension different of 1.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论