提交 488246bf authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged

...@@ -17,10 +17,10 @@ def _asarray(a, dtype=None, order=None): ...@@ -17,10 +17,10 @@ def _asarray(a, dtype=None, order=None):
http://projects.scipy.org/numpy/ticket/870. http://projects.scipy.org/numpy/ticket/870.
Currently, this issue has only been causing trouble when the target Currently, this issue has only been causing trouble when the target
data type is 'int32', on some computers. As a result, this is the only data type is 'int32' or 'int64', on some computers. As a result, we
situation where we may do more than a simple call to ``numpy.asarray``. If silently fix it only in this situation: if a type mismatch is detected
it turns out that a similar problem can occur for more data type, this with another data type, an exception is raised (if that happens, then this
function should be updated accordingly. function may need to be modified to also handle this other data type).
This function's name starts with a '_' to indicate that it is meant to be This function's name starts with a '_' to indicate that it is meant to be
used internally. It is imported so as to be available directly through used internally. It is imported so as to be available directly through
...@@ -28,12 +28,20 @@ def _asarray(a, dtype=None, order=None): ...@@ -28,12 +28,20 @@ def _asarray(a, dtype=None, order=None):
""" """
dtype = numpy.dtype(dtype) # Convert into dtype object. dtype = numpy.dtype(dtype) # Convert into dtype object.
rval = numpy.asarray(a, dtype=dtype, order=order) rval = numpy.asarray(a, dtype=dtype, order=order)
numpy_int32 = numpy.dtype(numpy.int32) if rval.dtype is not dtype:
if (dtype is numpy_int32 and rval.dtype is not numpy_int32): # Type mismatch between the data type we asked for, and the one
# Enfore the numpy.int32 dtype. # returned by numpy.asarray.
return rval.view(dtype=numpy_int32) if (dtype is numpy.dtype(numpy.int32) or
dtype is numpy.dtype(numpy.int64)):
# Silent fix.
return rval.view(dtype=dtype)
else:
# Unexpected mismatch: better know what is going on!
raise TypeError('numpy.array did not return the data type we '
'asked for (%s #%s), instead it returned type %s #%s: function '
'theano._asarray may need to be extended to handle this '
'data type as well.' %
(dtype, dtype.num, rval.dtype, rval.dtype.num))
else: else:
# Using ``numpy.asarray`` should work just fine.
# Debug assert if we want to detect other failure cases (untested):
# assert rval.dtype is dtype
return rval return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论