提交 cd57f4e5 authored 作者: abergeron's avatar abergeron

Merge pull request #4470 from nouiz/fix_slice_symbol

crash fix gh-4460
......@@ -329,6 +329,14 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
x = numpy.arange(100).reshape((5, 5, 4))
numpy.allclose(res, x[[slice(1, -1)] * x.ndim])
def test_slice_symbol(self):
x = self.shared(numpy.random.rand(5, 4).astype(self.dtype))
y = self.shared(numpy.random.rand(1, 2, 3).astype(self.dtype))
o = x[:y.shape[0], None, :]
f = theano.function([], o, mode=self.mode)
ret = f()
assert ret.shape == (1, 1, 4)
def test_newaxis(self):
"""
newaxis support comes from logic in the __getitem__ of TensorType
......
......@@ -524,8 +524,17 @@ class _tensor_py_operators(object):
counter += 1
new_args.append(arg)
view = self.dimshuffle(pattern)
check_rval = [arg == slice(None, None, None) for arg in new_args]
if all(check_rval) == True:
full_slices = True
for arg in new_args:
# We can't do arg == slice(None, None, None) as in
# Python 2.7, this call __lt__ if we have a slice
# with some symbolic variable.
if not (isinstance(arg, slice) and
arg.start is None and
arg.stop is None and
arg.step is None):
full_slices = False
if full_slices:
return view
else:
return view.__getitem__(tuple(new_args))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论