提交 f57da9d4 authored 作者: abalkin's avatar abalkin

Implemented c_code for AdvancedSubtensor1

上级 0421c6b0
...@@ -6549,8 +6549,44 @@ class AdvancedSubtensor1(Op): ...@@ -6549,8 +6549,44 @@ class AdvancedSubtensor1(Op):
x, ilist = ishapes x, ilist = ishapes
return [ilist + x[1:]] return [ilist + x[1:]]
advanced_subtensor1 = AdvancedSubtensor1() def c_code(self, node, name, input_names, output_names, sub):
a_name, i_name = input_names[0], input_names[1]
output_name = output_names[0]
fail = sub['fail']
return """
if (%(output_name)s != NULL) {
npy_intp nd, i, *shape;
nd = PyArray_NDIM(%(a_name)s) + PyArray_NDIM(%(i_name)s) - 1;
if (PyArray_NDIM(%(output_name)s) != nd) {
Py_CLEAR(%(output_name)s);
}
else {
shape = PyArray_DIMS(%(output_name)s);
for (i = 0; i < PyArray_NDIM(%(i_name)s); i++) {
if (shape[i] != PyArray_DIMS(%(i_name)s)[i]) {
Py_CLEAR(%(output_name)s);
break;
}
}
if (%(output_name)s != NULL) {
for (; i < nd; i++) {
if (shape[i] != PyArray_DIMS(%(a_name)s)[i-PyArray_NDIM(%(i_name)s)+1]) {
Py_CLEAR(%(output_name)s);
break;
}
}
}
}
}
%(output_name)s = (PyArrayObject*)PyArray_TakeFrom(%(a_name)s, (PyObject*)%(i_name)s, 0,
%(output_name)s, NPY_RAISE);
if (%(output_name)s == NULL) %(fail)s;
""" % locals()
def c_code_cache_version(self):
return (0, 0, 2)
advanced_subtensor1 = AdvancedSubtensor1()
class AdvancedIncSubtensor1(Op): class AdvancedIncSubtensor1(Op):
"""Increments a subtensor using advanced slicing (list of index)""" """Increments a subtensor using advanced slicing (list of index)"""
......
...@@ -6967,7 +6967,7 @@ class TestInferShape(utt.InferShapeTester): ...@@ -6967,7 +6967,7 @@ class TestInferShape(utt.InferShapeTester):
class TestTensorInstanceMethods(unittest.TestCase): class TestTensorInstanceMethods(unittest.TestCase):
def setUp(self): def setUp(self):
self.vars = matrices('X', 'Y') self.vars = matrices('X', 'Y')
self.vals = [rand(2,2),rand(2,2)] self.vals = [m.astype(floatX) for m in [rand(2,2),rand(2,2)]]
def test_argmin(self): def test_argmin(self):
X, _ = self.vars X, _ = self.vars
...@@ -7060,7 +7060,7 @@ class TestTensorInstanceMethods(unittest.TestCase): ...@@ -7060,7 +7060,7 @@ class TestTensorInstanceMethods(unittest.TestCase):
X, _ = self.vars X, _ = self.vars
x, _ = self.vals x, _ = self.vals
indices = [1,0,3] indices = [1,0,3]
assert_array_equal(X.take(indices).eval({X: x}), x.take(indices)) assert_array_equal(X.take(indices).sum().eval({X: x}), x.take(indices).sum())
indices = [1,0,1] indices = [1,0,1]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1)) assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
indices = [-10,5,12] indices = [-10,5,12]
...@@ -7072,6 +7072,11 @@ class TestTensorInstanceMethods(unittest.TestCase): ...@@ -7072,6 +7072,11 @@ class TestTensorInstanceMethods(unittest.TestCase):
x.take(indices, 1, mode='clip')) x.take(indices, 1, mode='clip'))
assert_array_equal(X.take(indices, -1, mode='clip').eval({X: x}), assert_array_equal(X.take(indices, -1, mode='clip').eval({X: x}),
x.take(indices, -1, mode='clip')) x.take(indices, -1, mode='clip'))
# Test error handling
self.assertRaises(IndexError, X.take(indices).eval, {X: x})
self.assertRaises(IndexError, (2 * X.take(indices)).eval, {X: x})
self.assertRaises(TypeError, X.take, [0.0])
indices = [[1,0,1], [0,1,1]] indices = [[1,0,1], [0,1,1]]
assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1)) assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1))
# Test equivalent advanced indexing # Test equivalent advanced indexing
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论