提交 85019057 authored 作者: amrithasuresh's avatar amrithasuresh

Updated numpy as np

上级 4a98f1e1
...@@ -4,7 +4,7 @@ from textwrap import dedent ...@@ -4,7 +4,7 @@ from textwrap import dedent
import warnings import warnings
import logging import logging
import numpy import numpy as np
from six import integer_types from six import integer_types
from six.moves import xrange from six.moves import xrange
...@@ -58,7 +58,7 @@ def make_constant(args): ...@@ -58,7 +58,7 @@ def make_constant(args):
return slice(conv(a.start), return slice(conv(a.start),
conv(a.stop), conv(a.stop),
conv(a.step)) conv(a.step))
elif isinstance(a, (integer_types, numpy.integer)): elif isinstance(a, (integer_types, np.integer)):
return scal.ScalarConstant(scal.int64, a) return scal.ScalarConstant(scal.int64, a)
else: else:
return a return a
...@@ -355,11 +355,11 @@ class Subtensor(Op): ...@@ -355,11 +355,11 @@ class Subtensor(Op):
if (isinstance(entry, gof.Variable) and if (isinstance(entry, gof.Variable) and
entry.type in tensor_types and entry.type in tensor_types and
numpy.all(entry.type.broadcastable)): np.all(entry.type.broadcastable)):
return scal.get_scalar_type(entry.type.dtype) return scal.get_scalar_type(entry.type.dtype)
elif (isinstance(entry, gof.Type) and elif (isinstance(entry, gof.Type) and
entry in tensor_types and entry in tensor_types and
numpy.all(entry.broadcastable)): np.all(entry.broadcastable)):
return scal.get_scalar_type(entry.dtype) return scal.get_scalar_type(entry.dtype)
elif slice_ok and isinstance(entry, slice): elif slice_ok and isinstance(entry, slice):
a = entry.start a = entry.start
...@@ -385,7 +385,7 @@ class Subtensor(Op): ...@@ -385,7 +385,7 @@ class Subtensor(Op):
slice_c = None slice_c = None
return slice(slice_a, slice_b, slice_c) return slice(slice_a, slice_b, slice_c)
elif isinstance(entry, (integer_types, numpy.integer)): elif isinstance(entry, (integer_types, np.integer)):
# Disallow the use of python scalars in idx_list # Disallow the use of python scalars in idx_list
raise TypeError("Python scalar in idx_list." raise TypeError("Python scalar in idx_list."
"Please report this error to theano-dev.") "Please report this error to theano-dev.")
...@@ -510,8 +510,8 @@ class Subtensor(Op): ...@@ -510,8 +510,8 @@ class Subtensor(Op):
if start is None: if start is None:
start = 0 start = 0
if (p.stop is None or if (p.stop is None or
(isinstance(p.stop, (integer_types, numpy.integer, (isinstance(p.stop, (integer_types, np.integer,
numpy.ndarray)) and np.ndarray)) and
p.stop > start)): p.stop > start)):
broadcastable.append(True) broadcastable.append(True)
continue continue
...@@ -531,7 +531,7 @@ class Subtensor(Op): ...@@ -531,7 +531,7 @@ class Subtensor(Op):
if len(cdata) == 1: if len(cdata) == 1:
cdata = cdata[0] cdata = cdata[0]
out[0] = numpy.asarray(x.__getitem__(cdata)) out[0] = np.asarray(x.__getitem__(cdata))
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
xshp = shapes[0] xshp = shapes[0]
...@@ -681,7 +681,7 @@ class Subtensor(Op): ...@@ -681,7 +681,7 @@ class Subtensor(Op):
return pos[1] return pos[1]
def init_entry(entry, depth=0): def init_entry(entry, depth=0):
if isinstance(entry, (numpy.integer, integer_types)): if isinstance(entry, (np.integer, integer_types)):
init_cmds.append( init_cmds.append(
"subtensor_spec[%i] = %i;" % (spec_pos(), "subtensor_spec[%i] = %i;" % (spec_pos(),
entry)) entry))
...@@ -1390,7 +1390,7 @@ class IncSubtensor(Op): ...@@ -1390,7 +1390,7 @@ class IncSubtensor(Op):
op_is_set = 0 op_is_set = 0
fail = sub['fail'] fail = sub['fail']
view_ndim = (node.inputs[0].ndim - view_ndim = (node.inputs[0].ndim -
numpy.sum([not isinstance(idx, slice) np.sum([not isinstance(idx, slice)
for idx in self.idx_list])) for idx in self.idx_list]))
copy_of_x = self.copy_of_x(x) copy_of_x = self.copy_of_x(x)
...@@ -1712,11 +1712,11 @@ class AdvancedSubtensor1(Op): ...@@ -1712,11 +1712,11 @@ class AdvancedSubtensor1(Op):
# We need to check if values in i can fit in numpy.intp, because # We need to check if values in i can fit in numpy.intp, because
# if they don't, that should be an error (no array can have that # if they don't, that should be an error (no array can have that
# many elements on a 32-bit arch). # many elements on a 32-bit arch).
if i.dtype != numpy.intp: if i.dtype != np.intp:
i_ = theano._asarray(i, dtype=numpy.intp) i_ = theano._asarray(i, dtype=np.intp)
if not numpy.can_cast(i.dtype, numpy.intp): if not np.can_cast(i.dtype, np.intp):
# Check if there was actually an incorrect conversion # Check if there was actually an incorrect conversion
if numpy.any(i != i_): if np.any(i != i_):
raise IndexError( raise IndexError(
'index contains values that are bigger ' 'index contains values that are bigger '
'than the maximum array size on this system.', i) 'than the maximum array size on this system.', i)
...@@ -1946,7 +1946,7 @@ class AdvancedIncSubtensor1(Op): ...@@ -1946,7 +1946,7 @@ class AdvancedIncSubtensor1(Op):
return compile_cutils_code() return compile_cutils_code()
def c_code(self, node, name, input_names, output_names, sub): def c_code(self, node, name, input_names, output_names, sub):
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]] numpy_ver = [int(n) for n in np.__version__.split('.')[:2]]
if bool(numpy_ver < [1, 8]): if bool(numpy_ver < [1, 8]):
raise NotImplementedError raise NotImplementedError
x, y, idx = input_names x, y, idx = input_names
...@@ -2113,13 +2113,13 @@ def adv_index_broadcastable_pattern(a, idx): ...@@ -2113,13 +2113,13 @@ def adv_index_broadcastable_pattern(a, idx):
if isinstance(v.type, SliceType): if isinstance(v.type, SliceType):
return slice(None, None) return slice(None, None)
return numpy.zeros((2,) * v.ndim, int) return np.zeros((2,) * v.ndim, int)
newidx = tuple(map(replace_slice, idx)) newidx = tuple(map(replace_slice, idx))
# 2 - True = 1; 2 - False = 2 # 2 - True = 1; 2 - False = 2
fakeshape = [2 - bc for bc in a.broadcastable] fakeshape = [2 - bc for bc in a.broadcastable]
retshape = numpy.empty(fakeshape)[newidx].shape retshape = np.empty(fakeshape)[newidx].shape
return tuple([dim == 1 for dim in retshape]) return tuple([dim == 1 for dim in retshape])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论