提交 10082356 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Clean up Subtensor Op and its exceptions

上级 7857774b
......@@ -432,40 +432,8 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
class Subtensor(Op):
"""
Return a subtensor view.
The inputs array is the tensor x, followed by scalar integer types.
TODO: WRITEME: how are the scalar integer variables formatted?
This class uses a relatively complex internal representation of the inputs
to remember how the input tensor x should be sliced.
idx_list: instance variable TODO: WRITEME: is this a list or a tuple?
(old docstring gives two conflicting
descriptions)
elements are either integers, theano scalar types, or slices.
one element per "explicitly named dimension"
TODO: WRITEME: what is an "explicitly named dimension" ?
if integer:
indexes into the inputs array
if slice:
start/stop/step members of each slice are integer indices
into the inputs array or None
integer indices be actual integers or theano scalar types
Note that the idx_list defines the Op, so two Subtensor instances are
considered to be different Ops if they have different idx_list fields.
This means that the entries in it are theano Types, not theano Variables.
@todo: add support for advanced tensor indexing (in Subtensor_dx too).
"""Basic NumPy indexing operator."""
"""
e_subslice = "nested slicing is not supported"
e_indextype = "Invalid index type or slice for Subtensor"
debug = 0
check_input = False
view_map = {0: [0]}
_f16_ok = True
......@@ -513,27 +481,29 @@ class Subtensor(Op):
when would that happen?
"""
invalid_scal_types = [scal.float64, scal.float32, scal.float16]
scal_types = [scal.int64, scal.int32, scal.int16, scal.int8]
tensor_types = [
invalid_scal_types = (scal.float64, scal.float32, scal.float16)
scal_types = (scal.int64, scal.int32, scal.int16, scal.int8)
tensor_types = (
theano.tensor.lscalar,
theano.tensor.iscalar,
theano.tensor.wscalar,
theano.tensor.bscalar,
]
invalid_tensor_types = [
)
invalid_tensor_types = (
theano.tensor.fscalar,
theano.tensor.dscalar,
theano.tensor.cscalar,
theano.tensor.zscalar,
]
)
if (
isinstance(entry, (np.ndarray, theano.tensor.Variable))
and hasattr(entry, "dtype")
and entry.dtype == "bool"
):
raise AdvancedBooleanIndexingError(Subtensor.e_indextype, entry)
raise AdvancedBooleanIndexingError(
"Invalid index type or slice for Subtensor"
)
if isinstance(entry, gof.Variable) and (
entry.type in invalid_scal_types or entry.type in invalid_tensor_types
......@@ -587,7 +557,7 @@ class Subtensor(Op):
"Python scalar in idx_list." "Please report this error to theano-dev."
)
else:
raise AdvancedIndexingError(Subtensor.e_indextype, entry)
raise AdvancedIndexingError("Invalid index type or slice for Subtensor")
def get_constant_idx(
self, inputs, allow_partial=False, only_process_constants=False, elemwise=True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论