提交 de366285 authored 作者: Benjamin Scellier's avatar Benjamin Scellier

file theano/gpuarray/type.py

上级 8d809317
from __future__ import absolute_import, print_function, division
import numpy
import numpy as np
import six.moves.copyreg as copyreg
from six import iteritems
import warnings
......@@ -226,7 +226,7 @@ class GpuArrayType(Type):
converted_data = theano._asarray(data, self.dtype)
# We use the `values_eq` static function from TensorType
# to handle NaN values.
if TensorType.values_eq(numpy.asarray(data),
if TensorType.values_eq(np.asarray(data),
converted_data,
force_same_dtype=False):
data = converted_data
......@@ -293,18 +293,18 @@ class GpuArrayType(Type):
return False
if force_same_dtype and a.typecode != b.typecode:
return False
a_eq_b = numpy.asarray(compare(a, '==', b))
a_eq_b = np.asarray(compare(a, '==', b))
if a_eq_b.all():
return True
# maybe the trouble is that there are NaNs
a = numpy.asarray(a)
b = numpy.asarray(b)
a = np.asarray(a)
b = np.asarray(b)
a_missing = numpy.isnan(a)
a_missing = np.isnan(a)
if a_missing.any():
b_missing = numpy.isnan(b)
return numpy.all(a_eq_b + (a_missing == b_missing))
b_missing = np.isnan(b)
return np.all(a_eq_b + (a_missing == b_missing))
else:
return False
......@@ -326,16 +326,16 @@ class GpuArrayType(Type):
rtol_ = rtol
if atol is not None:
atol_ = atol
res = elemwise2(a, '', b, a, odtype=numpy.dtype('bool'),
res = elemwise2(a, '', b, a, odtype=np.dtype('bool'),
op_tmpl="res = (fabs(a - b) <"
"(%(atol_)s + %(rtol_)s * fabs(b)))" %
locals())
ret = numpy.asarray(res).all()
ret = np.asarray(res).all()
if ret:
return True
# maybe the trouble is that there are NaNs
an = numpy.asarray(a)
bn = numpy.asarray(b)
an = np.asarray(a)
bn = np.asarray(b)
return tensor.TensorType.values_eq_approx(
an, bn, allow_remove_inf=allow_remove_inf,
allow_remove_nan=allow_remove_nan, rtol=rtol, atol=atol)
......@@ -408,9 +408,9 @@ class GpuArrayType(Type):
def get_size(self, shape_info):
if shape_info:
return numpy.prod(shape_info) * numpy.dtype(self.dtype).itemsize
return np.prod(shape_info) * np.dtype(self.dtype).itemsize
else:
return numpy.dtype(self.dtype).itemsize
return np.dtype(self.dtype).itemsize
def c_declare(self, name, sub, check_input=True):
return """
......@@ -470,7 +470,7 @@ class GpuArrayType(Type):
'<gpuarray_api.h>']
def c_header_dirs(self):
return [pygpu.get_include(), numpy.get_include()]
return [pygpu.get_include(), np.get_include()]
def c_libraries(self):
return ['gpuarray']
......@@ -509,7 +509,7 @@ class GpuArrayVariable(_operators, Variable):
# override the default
def __repr_test_value__(self):
return repr(numpy.array(theano.gof.op.get_test_value(self)))
return repr(np.array(theano.gof.op.get_test_value(self)))
GpuArrayType.Variable = GpuArrayVariable
......@@ -534,13 +534,13 @@ class GpuArrayConstant(_operators, Constant):
"""
def signature(self):
return GpuArraySignature((self.type, numpy.asarray(self.data)))
return GpuArraySignature((self.type, np.asarray(self.data)))
def __str__(self):
if self.name is not None:
return self.name
try:
np_data = numpy.asarray(self.data)
np_data = np.asarray(self.data)
except gpuarray.GpuArrayException:
np_data = self.data
return "GpuArrayConstant{%s}" % np_data
......@@ -568,7 +568,7 @@ class GpuArraySharedVariable(_operators, SharedVariable):
else:
return self.container.value.copy()
else:
return numpy.asarray(self.container.value)
return np.asarray(self.container.value)
def set_value(self, value, borrow=False):
if isinstance(value, pygpu.gpuarray.GpuArray):
......@@ -601,7 +601,7 @@ def gpuarray_shared_constructor(value, name=None, strict=False,
if target == 'gpu' or target == 'cpu':
raise TypeError('not for me')
if not isinstance(value, (numpy.ndarray, pygpu.gpuarray.GpuArray)):
if not isinstance(value, (np.ndarray, pygpu.gpuarray.GpuArray)):
raise TypeError('ndarray or GpuArray required')
if target is notset:
......@@ -809,7 +809,7 @@ copyreg.constructor(GpuArray_unpickler)
def GpuArray_pickler(cnda):
ctx_name = _name_for_ctx(cnda.context)
return (GpuArray_unpickler, (numpy.asarray(cnda), ctx_name))
return (GpuArray_unpickler, (np.asarray(cnda), ctx_name))
# In case pygpu is not imported.
if pygpu is not None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论