提交 bb356c26 authored 作者: Frederic's avatar Frederic

Make subtensor with list work as numpy

上级 b0a20106
...@@ -315,6 +315,12 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -315,6 +315,12 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
n = self.shared(numpy.arange(12, dtype=self.dtype).reshape((4, 3))) n = self.shared(numpy.arange(12, dtype=self.dtype).reshape((4, 3)))
self.assertRaises(Exception, lambda: n[:(2L ** 63)]) self.assertRaises(Exception, lambda: n[:(2L ** 63)])
def test_list_slice(self):
x = theano.tensor.arange(100).reshape((5, 5, 4))
res = x[[slice(1, -1)] * x.ndim].eval()
x = numpy.arange(100).reshape((5, 5, 4))
numpy.allclose(res, x[[slice(1, -1)] * x.ndim])
def test_newaxis(self): def test_newaxis(self):
""" """
newaxis support comes from logic in the __getitem__ of TensorType newaxis support comes from logic in the __getitem__ of TensorType
......
...@@ -355,7 +355,10 @@ class _tensor_py_operators: ...@@ -355,7 +355,10 @@ class _tensor_py_operators:
# SLICING/INDEXING # SLICING/INDEXING
def __getitem__(self, args): def __getitem__(self, args):
if not isinstance(args, tuple): if (isinstance(args, list) and
any([isinstance(a, slice) for a in args])):
pass
elif not isinstance(args, tuple):
args = args, args = args,
# Convert python literals to theano constants # Convert python literals to theano constants
args = theano.tensor.subtensor.make_constant(args) args = theano.tensor.subtensor.make_constant(args)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论