提交 7d637509 authored 作者: Frederic Bastien's avatar Frederic Bastien

make misc.safe_asarray._asarray understand floatX.

上级 11fd91c9
...@@ -6,6 +6,8 @@ __docformat__ = "restructuredtext en" ...@@ -6,6 +6,8 @@ __docformat__ = "restructuredtext en"
import numpy import numpy
import theano
def _asarray(a, dtype, order=None): def _asarray(a, dtype, order=None):
"""Convert the input to a Numpy array. """Convert the input to a Numpy array.
...@@ -26,6 +28,8 @@ def _asarray(a, dtype, order=None): ...@@ -26,6 +28,8 @@ def _asarray(a, dtype, order=None):
used internally. It is imported so as to be available directly through used internally. It is imported so as to be available directly through
theano._asarray theano._asarray
""" """
if dtype == 'floatX':
dtype = theano.config.floatX
dtype = numpy.dtype(dtype) # Convert into dtype object. dtype = numpy.dtype(dtype) # Convert into dtype object.
rval = numpy.asarray(a, dtype=dtype, order=order) rval = numpy.asarray(a, dtype=dtype, order=order)
# Note that dtype comparison must be done by comparing their `num` # Note that dtype comparison must be done by comparing their `num`
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论