提交 30c6b4bd authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Remove support for python scalars in idx_list.

上级 591fd120
...@@ -347,18 +347,10 @@ class Subtensor(Op): ...@@ -347,18 +347,10 @@ class Subtensor(Op):
slice_c = None slice_c = None
return slice(slice_a, slice_b, slice_c) return slice(slice_a, slice_b, slice_c)
# There is a bug in numpy that results in isinstance(x, int) returning elif isinstance(entry, (int, long, numpy.integer)):
# False for numpy integers. # Disallow the use of python scalars in idx_list
# See <http://projects.scipy.org/numpy/ticket/2235>. raise TypeError("Python scalar in idx_list."
elif isinstance(entry, numpy.integer): "Please report this error to theano-dev.")
return entry
# On Windows 64-bit, shapes are returned as Python long, as they can
# be bigger than what a Python int can hold.
# Shapes should always fit in a numpy.int64, and we support them better
# 2) In Python3, long replaced int. So we must assert it fit in int64.
elif isinstance(entry, (int, long)):
entry64 = numpy.int64(entry)
return entry64
else: else:
raise AdvancedIndexingError(Subtensor.e_indextype, entry) raise AdvancedIndexingError(Subtensor.e_indextype, entry)
...@@ -405,7 +397,6 @@ class Subtensor(Op): ...@@ -405,7 +397,6 @@ class Subtensor(Op):
def __init__(self, idx_list): def __init__(self, idx_list):
self.idx_list = tuple(map(self.convert, idx_list)) self.idx_list = tuple(map(self.convert, idx_list))
self.perform_cache_cdata = None
@staticmethod @staticmethod
def my_as_scalar(a): def my_as_scalar(a):
...@@ -471,18 +462,9 @@ class Subtensor(Op): ...@@ -471,18 +462,9 @@ class Subtensor(Op):
out, = out_ out, = out_
x = inputs[0] x = inputs[0]
# The subtensor (or idx_list) does not depend on the inputs.
# (and cdata was cached on initial call)
if self.perform_cache_cdata is not None:
out[0] = numpy.asarray(x.__getitem__(self.perform_cache_cdata))
return
cdata = get_idx_list(inputs, self.idx_list) cdata = get_idx_list(inputs, self.idx_list)
if len(cdata) == 1: if len(cdata) == 1:
cdata = cdata[0] cdata = cdata[0]
# (first call caches cdata here)
if len(inputs) == 1:
self.perform_cache_cdata = cdata
out[0] = numpy.asarray(x.__getitem__(cdata)) out[0] = numpy.asarray(x.__getitem__(cdata))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论