提交 cd57f4e5 authored 作者: abergeron's avatar abergeron

Merge pull request #4470 from nouiz/fix_slice_symbol

crash fix gh-4460
...@@ -329,6 +329,14 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -329,6 +329,14 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
x = numpy.arange(100).reshape((5, 5, 4)) x = numpy.arange(100).reshape((5, 5, 4))
numpy.allclose(res, x[[slice(1, -1)] * x.ndim]) numpy.allclose(res, x[[slice(1, -1)] * x.ndim])
def test_slice_symbol(self):
x = self.shared(numpy.random.rand(5, 4).astype(self.dtype))
y = self.shared(numpy.random.rand(1, 2, 3).astype(self.dtype))
o = x[:y.shape[0], None, :]
f = theano.function([], o, mode=self.mode)
ret = f()
assert ret.shape == (1, 1, 4)
def test_newaxis(self): def test_newaxis(self):
""" """
newaxis support comes from logic in the __getitem__ of TensorType newaxis support comes from logic in the __getitem__ of TensorType
......
...@@ -524,8 +524,17 @@ class _tensor_py_operators(object): ...@@ -524,8 +524,17 @@ class _tensor_py_operators(object):
counter += 1 counter += 1
new_args.append(arg) new_args.append(arg)
view = self.dimshuffle(pattern) view = self.dimshuffle(pattern)
check_rval = [arg == slice(None, None, None) for arg in new_args] full_slices = True
if all(check_rval) == True: for arg in new_args:
# We can't do arg == slice(None, None, None) as in
# Python 2.7, this call __lt__ if we have a slice
# with some symbolic variable.
if not (isinstance(arg, slice) and
arg.start is None and
arg.stop is None and
arg.step is None):
full_slices = False
if full_slices:
return view return view
else: else:
return view.__getitem__(tuple(new_args)) return view.__getitem__(tuple(new_args))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论