提交 d01ddc02 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1658 from lamblin/fix_repeatop_bitwidth

In RepeatOp, fix detection of numpy-supported dtypes
......@@ -246,17 +246,17 @@ class RepeatOp(theano.Op):
# Some dtypes are not supported by numpy's implementation of repeat.
# Until another one is available, we should fail at graph construction
# time, not wait for execution.
int_bitwidth = theano.gof.python_int_bitwidth()
if int_bitwidth == 64:
ptr_bitwidth = theano.gof.local_bitwidth()
if ptr_bitwidth == 64:
numpy_unsupported_dtypes = ('uint64',)
if int_bitwidth == 32:
if ptr_bitwidth == 32:
numpy_unsupported_dtypes = ('uint32', 'int64', 'uint64')
if repeats.dtype in numpy_unsupported_dtypes:
raise TypeError(
("dtypes %s are not supported by numpy.repeat "
"for the 'repeats' parameter, "
% numpy_unsupported_dtypes), repeats.dtype)
% str(numpy_unsupported_dtypes)), repeats.dtype)
if self.axis is None:
broadcastable = [False]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论