提交 83a9cc0b authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: John Salvatier

Skip tests if increment is not available

上级 ef3bba72
...@@ -7367,6 +7367,7 @@ class AdvancedSubtensor(Op): ...@@ -7367,6 +7367,7 @@ class AdvancedSubtensor(Op):
*rest)] + \ *rest)] + \
[DisconnectedType()()] * len(rest) [DisconnectedType()()] * len(rest)
class AdvancedIncSubtensor(Op): class AdvancedIncSubtensor(Op):
"""Increments a subtensor using advanced indexing. """Increments a subtensor using advanced indexing.
...@@ -7476,9 +7477,9 @@ class AdvancedIncSubtensor(Op): ...@@ -7476,9 +7477,9 @@ class AdvancedIncSubtensor(Op):
out, = out_ out, = out_
if not self.inplace: if not self.inplace:
out[0] = inputs[0].copy() out[0] = inputs[0].copy()
else: else:
out[0] = inputs[0] out[0] = inputs[0]
if self.set_instead_of_inc: if self.set_instead_of_inc:
out[0][inputs[2:]] = inputs[1] out[0][inputs[2:]] = inputs[1]
elif self.increment_available: elif self.increment_available:
......
...@@ -3675,6 +3675,7 @@ class TestIncSubtensor1(unittest.TestCase): ...@@ -3675,6 +3675,7 @@ class TestIncSubtensor1(unittest.TestCase):
# also tests set_subtensor # also tests set_subtensor
def setUp(self): def setUp(self):
AdvancedIncSubtensor.check_increment_available()
self.s = iscalar() self.s = iscalar()
self.v = fvector() self.v = fvector()
self.m = dmatrix() self.m = dmatrix()
...@@ -3746,6 +3747,11 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -3746,6 +3747,11 @@ class TestAdvancedSubtensor(unittest.TestCase):
a = self.v[self.ix2] a = self.v[self.ix2]
def test_inc_adv_selection(self): def test_inc_adv_selection(self):
if not AdvancedIncSubtensor.increment_available:
raise SkipTest("inc_subtensor with advanced indexing not enabled. "
"Installing NumPy 1.8 or the latest development version "
"should make that feature available.")
a = inc_subtensor(self.v[self.ix2], self.v[self.ix2]) a = inc_subtensor(self.v[self.ix2], self.v[self.ix2])
assert a.type == self.v.type, (a.type,self.v.type) assert a.type == self.v.type, (a.type,self.v.type)
...@@ -3755,6 +3761,10 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -3755,6 +3761,10 @@ class TestAdvancedSubtensor(unittest.TestCase):
assert numpy.allclose(aval, [.4, .9*3, .1 * 3]) assert numpy.allclose(aval, [.4, .9*3, .1 * 3])
def test_inc_adv_selection2(self): def test_inc_adv_selection2(self):
if not AdvancedIncSubtensor.increment_available:
raise SkipTest("inc_subtensor with advanced indexing not enabled. "
"Installing NumPy 1.8 or the latest development version "
"should make that feature available.")
subt = self.m[self.ix1,self.ix12] subt = self.m[self.ix1,self.ix12]
a = inc_subtensor(subt, subt) a = inc_subtensor(subt, subt)
...@@ -3771,6 +3781,10 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -3771,6 +3781,10 @@ class TestAdvancedSubtensor(unittest.TestCase):
[.5, .3*2, .15]]), aval [.5, .3*2, .15]]), aval
def test_inc_adv_selection_with_broadcasting(self): def test_inc_adv_selection_with_broadcasting(self):
if not AdvancedIncSubtensor.increment_available:
raise SkipTest("inc_subtensor with advanced indexing not enabled. "
"Installing NumPy 1.8 or the latest development version "
"should make that feature available.")
a = inc_subtensor(self.m[self.ix1,self.ix12], 2.1) a = inc_subtensor(self.m[self.ix1,self.ix12], 2.1)
assert a.type == self.m.type, (a.type, self.m.type) assert a.type == self.m.type, (a.type, self.m.type)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论