提交 3166a3d5 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use compat.python2x.all instead of numpy.all on generator

Otherwise, the generator evaluates to "True", because numpy.all does not actually iterate through it. Also add test case to check that m[0, idx_list] triggers full advanced indexing, not "take".
上级 d2947cae
...@@ -1048,6 +1048,22 @@ inplace_increment_missing = SkipTest( ...@@ -1048,6 +1048,22 @@ inplace_increment_missing = SkipTest(
class TestAdvancedSubtensor(unittest.TestCase): class TestAdvancedSubtensor(unittest.TestCase):
# test inc_subtensor # test inc_subtensor
# also tests set_subtensor # also tests set_subtensor
def __init__(self, name,
shared=tensor._shared,
sub=tensor.AdvancedSubtensor,
inc_sub=tensor.AdvancedIncSubtensor,
mode=None,
dtype=theano.config.floatX,
ignore_topo=DeepCopyOp):
self.shared = shared
self.sub = sub
self.inc_sub = inc_sub
if mode is None:
mode = theano.compile.mode.get_default_mode()
self.mode = mode
self.dtype = dtype
self.ignore_topo = ignore_topo
return super(TestAdvancedSubtensor, self).__init__(name)
def setUp(self): def setUp(self):
self.s = iscalar() self.s = iscalar()
...@@ -1059,6 +1075,16 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -1059,6 +1075,16 @@ class TestAdvancedSubtensor(unittest.TestCase):
self.ix12 = lvector() self.ix12 = lvector()
self.ix2 = lmatrix() self.ix2 = lmatrix()
def eval_output_and_check(self, t):
f = inplace_func([], t, mode=self.mode)
topo = f.maker.fgraph.toposort()
topo_ = [node for node in topo if not isinstance(node.op,
self.ignore_topo)]
assert len(topo_) == 1
assert isinstance(topo_[0].op, self.sub)
tval = f()
return tval
def test_cant_adv_idx_into_scalar(self): def test_cant_adv_idx_into_scalar(self):
self.assertRaises(TypeError, lambda: self.s[self.ix1]) self.assertRaises(TypeError, lambda: self.s[self.ix1])
...@@ -1072,6 +1098,37 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -1072,6 +1098,37 @@ class TestAdvancedSubtensor(unittest.TestCase):
assert a.broadcastable == self.ix2.broadcastable, ( assert a.broadcastable == self.ix2.broadcastable, (
a.broadcastable, self.ix2.broadcastable) a.broadcastable, self.ix2.broadcastable)
def test_index_w_int_and_vec(self):
# like test_ok_list, but with a single index on the first one
# data has to have at least 2 dimensions
for data, idx in [(rand(4, 5), [2, 3]),
(rand(2, 4, 3), [0, 3]),
(rand(2, 4, 3), [3, 3, 1, 1, 2, 2, 0, 0]),
(rand(2, 4, 3), [3, 3, 1, 1, 2, 2, 0, 0,
-1, -2, -3, -4]),
# Test 4 dims as gpu code use another algo
# in that case This new algo is not as much
# optimized for that case.
(rand(4, 4, 2, 3), [3,
3, 1, 1, 2, 2, 0, 0, -1, -2, -3, -4]),
# Test with TensorConstant index.
(rand(2, 4, 3),
theano.tensor.constant([3, 3, 1, 1, 2, 2, 0, 0])),
]:
data = numpy.asarray(data, dtype=self.dtype)
n = self.shared(data)
t = n[0, idx]
self.assertTrue(isinstance(t.owner.op, tensor.AdvancedSubtensor))
val = self.eval_output_and_check(t)
if isinstance(idx, list):
good = data[0, idx]
else:
good = data[0, idx.data]
self.assertTrue(val.ndim == data.ndim - 1)
self.assertTrue(numpy.allclose(val, good), (val, good))
def test_inc_adv_subtensor_w_matrix(self): def test_inc_adv_subtensor_w_matrix(self):
subt = self.v[self.ix2] subt = self.v[self.ix2]
a = inc_subtensor(subt, subt) a = inc_subtensor(subt, subt)
......
...@@ -2,6 +2,7 @@ import numpy ...@@ -2,6 +2,7 @@ import numpy
import theano import theano
from theano.compat import PY3 from theano.compat import PY3
from theano.compat.python2x import all
from theano.scalar import ComplexError, IntegerDivisionError from theano.scalar import ComplexError, IntegerDivisionError
from theano.gof import Constant, Variable from theano.gof import Constant, Variable
from theano.gof.utils import hashtype from theano.gof.utils import hashtype
...@@ -366,8 +367,8 @@ class _tensor_py_operators: ...@@ -366,8 +367,8 @@ class _tensor_py_operators:
if advanced: if advanced:
if (axis is not None if (axis is not None
and numpy.all(a == slice(None) for a in args[:axis]) and all(a == slice(None) for a in args[:axis])
and numpy.all(a == slice(None) for a in args[axis + 1:]) and all(a == slice(None) for a in args[axis + 1:])
and isinstance(args[axis], ( and isinstance(args[axis], (
numpy.ndarray, numpy.ndarray,
list, list,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论