提交 abd00672 authored 作者: Sina Honari's avatar Sina Honari

issue #1914, changing the place of equal_slices method and making code adaptable…

issue #1914, changing the place of equal_slices method and making code adaptable with PEP8 formatting
上级 144a3b01
...@@ -1268,25 +1268,24 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -1268,25 +1268,24 @@ class TestAdvancedSubtensor(unittest.TestCase):
def test_advanced_indexing(self): def test_advanced_indexing(self):
# tests advanced indexing in Theano for 2D and 3D tensors # tests advanced indexing in Theano for 2D and 3D tensors
rng = numpy.random.RandomState(utt.seed_rng()) rng = numpy.random.RandomState(utt.seed_rng())
a = rng.uniform(size=(3,3)) a = rng.uniform(size=(3, 3))
b = theano.shared(a) b = theano.shared(a)
i = tensor.iscalar() i = tensor.iscalar()
j = tensor.iscalar() j = tensor.iscalar()
z = b[[i, j], :] z = b[[i, j], :]
f1 = theano.function([i,j],z) f1 = theano.function([i, j], z)
cmd = f1(0,1) == a[[0,1],:] cmd = f1(0, 1) == a[[0, 1], :]
self.assertTrue(numpy.all(cmp)) self.assertTrue(numpy.all(cmp))
aa = rng.uniform(size=(4,2,3)) aa = rng.uniform(size=(4, 2, 3))
bb = theano.shared(aa) bb = theano.shared(aa)
k = tensor.iscalar() k = tensor.iscalar()
z = bb[[i, j, k],:, i:k] z = bb[[i, j, k], :, i:k]
f2 = theano.function([i,j,k],z) f2 = theano.function([i, j, k], z)
cmd = f2(0,1,2) == aa[[0,1,2],:, 0:2] cmd = f2(0, 1, 2) == aa[[0, 1, 2], :, 0:2]
self.assertTrue(numpy.all(cmp)) self.assertTrue(numpy.all(cmp))
class TestInferShape(utt.InferShapeTester): class TestInferShape(utt.InferShapeTester):
def test_infer_shape(self): def test_infer_shape(self):
# IncSubtensor # IncSubtensor
......
...@@ -16,6 +16,12 @@ from theano.tensor.type import TensorType ...@@ -16,6 +16,12 @@ from theano.tensor.type import TensorType
from theano.configparser import config from theano.configparser import config
def equal_slices(s1, s2):
return (s1.start == s2.start and
s1.stop == s2.stop and
s1.step == s2.step)
class AsTensorError(TypeError): class AsTensorError(TypeError):
"""Raised when as_tensor_variable isn't able to create a """Raised when as_tensor_variable isn't able to create a
TensorVariable. TensorVariable.
...@@ -350,11 +356,6 @@ class _tensor_py_operators: ...@@ -350,11 +356,6 @@ class _tensor_py_operators:
# argument slice(1, None, None), which is much more desirable. # argument slice(1, None, None), which is much more desirable.
# __getslice__ is deprecated in python 2.6 anyway. # __getslice__ is deprecated in python 2.6 anyway.
def equal_slices(self, s1, s2):
return (s1.start == s2.start and
s1.stop == s2.stop and
s1.step == s2.step)
def __getitem__(self, args): def __getitem__(self, args):
if not isinstance(args, tuple): if not isinstance(args, tuple):
args = args, args = args,
...@@ -380,8 +381,10 @@ class _tensor_py_operators: ...@@ -380,8 +381,10 @@ class _tensor_py_operators:
if advanced: if advanced:
if (axis is not None if (axis is not None
and all(isinstance(a, slice) and self.equal_slices(a, slice(None)) for a in args[:axis]) and all(isinstance(a, slice) and
and all(isinstance(a, slice) and self.equal_slices(a, slice(None)) for a in args[axis + 1:]) equal_slices(a, slice(None)) for a in args[:axis])
and all(isinstance(a, slice) and
equal_slices(a, slice(None)) for a in args[axis + 1:])
and isinstance(args[axis], ( and isinstance(args[axis], (
numpy.ndarray, numpy.ndarray,
list, list,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论