提交 58f2a073 authored 作者: David Warde-Farley's avatar David Warde-Farley

Merge pull request #47 from goodfeli/improve_indexing_exception

basic indexing with floating point raises better error
...@@ -1260,12 +1260,13 @@ class _tensor_py_operators: ...@@ -1260,12 +1260,13 @@ class _tensor_py_operators:
args = args, args = args,
# Determine if advanced indexing is needed or not # Determine if advanced indexing is needed or not
# The logic is already in Subtensor.convert: if it succeeds, # The logic is already in Subtensor.convert: if it succeeds,
# standard indexing is used, else, advanced indexing # standard indexing is used; if it fails with
# AdvancedIndexingError, advanced indexing
advanced = False advanced = False
for arg in args: for arg in args:
try: try:
Subtensor.convert(arg) Subtensor.convert(arg)
except TypeError: except AdvancedIndexingError:
advanced = True advanced = True
break break
...@@ -2957,6 +2958,12 @@ def transpose(x, **kwargs): ...@@ -2957,6 +2958,12 @@ def transpose(x, **kwargs):
return DimShuffle(x.broadcastable, dims, inplace=True)(tensor_copy(x)) return DimShuffle(x.broadcastable, dims, inplace=True)(tensor_copy(x))
class AdvancedIndexingError(TypeError):
"""A class raised as an exception when Subtensor
is asked to perform advanced indexing """
def __init__(self, *args):
super(AdvancedIndexingError, self).__init__(*args)
class Subtensor(Op): class Subtensor(Op):
"""Return a subtensor view """Return a subtensor view
...@@ -3002,8 +3009,13 @@ class Subtensor(Op): ...@@ -3002,8 +3009,13 @@ class Subtensor(Op):
@staticmethod @staticmethod
def convert(entry, slice_ok=True): def convert(entry, slice_ok=True):
invalid_scal_types = [scal.float64, scal.float32 ]
scal_types = [scal.int64, scal.int32, scal.int16, scal.int8] scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [lscalar, iscalar, wscalar, bscalar] tensor_types = [lscalar, iscalar, wscalar, bscalar]
invalid_tensor_types = [fscalar, dscalar, cscalar, zscalar ]
if isinstance(entry, gof.Variable) and (entry.type in invalid_scal_types \
or entry.type in invalid_tensor_types):
raise TypeError("Expected an integer")
if isinstance(entry, gof.Variable) and entry.type in scal_types: if isinstance(entry, gof.Variable) and entry.type in scal_types:
return entry.type return entry.type
elif isinstance(entry, gof.Type) and entry in scal_types: elif isinstance(entry, gof.Type) and entry in scal_types:
...@@ -3041,7 +3053,7 @@ class Subtensor(Op): ...@@ -3041,7 +3053,7 @@ class Subtensor(Op):
elif isinstance(entry, int): elif isinstance(entry, int):
return entry return entry
else: else:
raise TypeError(Subtensor.e_indextype, entry) raise AdvancedIndexingError(Subtensor.e_indextype, entry)
def __init__(self, idx_list): def __init__(self, idx_list):
self.idx_list = tuple(map(self.convert, idx_list)) self.idx_list = tuple(map(self.convert, idx_list))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论