提交 5e6ce10a authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: David Warde-Farley

Ensure that indices are suitable for numpy.take.

On a 32-bit machine, numpy.take may refuse to use int64 arrays for indices. (incl. typo fix by @dwf)
上级 9bd1bff5
......@@ -39,16 +39,18 @@ def _asarray(a, dtype, order=None):
if rval.dtype.num != dtype.num:
# Type mismatch between the data type we asked for, and the one
# returned by numpy.asarray.
if (dtype.num == numpy.dtype(numpy.int32).num or
dtype.num == numpy.dtype(numpy.int64).num):
# If both types have the same string description (byte order, basic
# type, and number of bytes), then it is safe to return a view.
if (dtype.str == rval.dtype.str):
# Silent fix.
return rval.view(dtype=dtype)
else:
# Unexpected mismatch: better know what is going on!
raise TypeError('numpy.array did not return the data type we '
'asked for (%s #%s), instead it returned type %s #%s: function '
'theano._asarray may need to be extended to handle this '
'data type as well.' %
(dtype, dtype.num, rval.dtype, rval.dtype.num))
'asked for (%s %s #%s), instead it returned type '
'%s %s #%s: function '
'theano._asarray may need to be modified to handle this '
'data type.' %
(dtype, dtype.str, dtype.num, rval.dtype, rval.str, rval.dtype.num))
else:
return rval
......@@ -4909,8 +4909,13 @@ class AdvancedSubtensor1(Op):
else:
o = None
# I have read that using clip or wrap mode make it a lot faster when
# the output is provided
# If i.dtype is more precise than numpy.intc (int32 on 32-bit machines,
# int64 on 64-bit machines), numpy may raise the following error:
# TypeError: array cannot be safely cast to required type.
# Since we will probably not have an array with more than 2**31 items
# on a 32-bit arch, I suppose it is safe to cast i into intc.
i = theano._asarray(i, dtype=numpy.intc)
out[0] = x.take(i, axis=0, out=o)
def grad(self, inputs, grads):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论