提交 2c44ac61 authored 作者: Reyhane Askari's avatar Reyhane Askari

Made perform and c_code of Join work in place and added a test for it

上级 1c07fecd
......@@ -3888,10 +3888,12 @@ class Join(Op):
check_input = False
__props__ = ("view",)
def __init__(self, view=False):
def __init__(self, view=-1):
self.view = view
if view:
self.view_map = {0: [0]}
if view != -1:
# since the first input is always the axis, the tensors
# start from index 1.
self.view_map = {0: [1 + view]}
def make_node(self, *axis_and_tensors):
"""
......@@ -4003,12 +4005,13 @@ class Join(Op):
def perform(self, node, axis_and_tensors, out_):
out, = out_
view = self.view
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
# tailing tensors are all tensors except the first one
tailing_tensors_are_empty = numpy.all(
[tensor.shape[axis] == 0 for tensor in axis_and_tensors[2:]])
if tailing_tensors_are_empty:
out[0] = tensors[0]
# we check these tensors for being empty.
if (view != -1) and numpy.all(
[tensor.shape[axis] == 0 for tensor in
tensors[0:view] + tensors[view + 1:]]):
out[0] = tensors[view]
else:
ndim = tensors[0].ndim
......@@ -4020,18 +4023,32 @@ class Join(Op):
dtype=node.outputs[0].type.dtype)
def c_code_cache_version(self):
return
return (3,)
return (4,)
def c_code_(self, node, name, inputs, outputs, sub):
def c_code(self, node, name, inputs, outputs, sub):
axis, tensors = inputs[0], inputs[1:]
view = self.view
non_empty_tensor = tensors[view]
input_1 = tensors[0]
l = len(tensors)
out, = outputs
fail = sub['fail']
adtype = node.inputs[0].type.dtype_specs()[1]
code = """
PyObject* list = PyList_New(%(l)s);
int axis = ((%(adtype)s *)PyArray_DATA(%(axis)s))[0];
int tensors_lens_sum = 0""" % locals()
for i, inp in enumerate(tensors):
code += """ + PyArray_DIM(%(inp)s, axis) """ % locals()
code += """;\n
tensors_lens_sum -= PyArray_DIM(%(non_empty_tensor)s, axis);
if(%(view)s != -1 && tensors_lens_sum == 0){
Py_XDECREF(%(out)s);
Py_INCREF(%(non_empty_tensor)s);
%(out)s = %(non_empty_tensor)s;
}
else{
PyObject* list = PyList_New(%(l)s);
""" % locals()
for i, inp in enumerate(tensors):
code += """
......@@ -4039,21 +4056,19 @@ class Join(Op):
PyList_SetItem(list, %(i)s, (PyObject*)%(inp)s);
""" % locals()
code += """
//PyObject* PyArray_Concatenate(PyObject* obj, int axis)
int axis = ((%(adtype)s *)PyArray_DATA(%(axis)s))[0];
int ndim = PyArray_NDIM(%(input_1)s);
if( axis < -ndim ){
PyErr_Format(PyExc_IndexError,
"Join axis %%d out of bounds [0, %%d)", axis, ndim);
%(fail)s
}
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)PyArray_Concatenate(list, axis);
Py_DECREF(list);
if(!%(out)s){
%(fail)s
//PyObject* PyArray_Concatenate(PyObject* obj, int axis)
int ndim = PyArray_NDIM(%(input_1)s);
if( axis < -ndim ){
PyErr_Format(PyExc_IndexError,
"Join axis %%d out of bounds [0, %%d)", axis, ndim);
%(fail)s
}
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)PyArray_Concatenate(list, axis);
(list);
if(!%(out)s){
%(fail)s
}
}
""" % locals()
return code
......
......@@ -4251,6 +4251,29 @@ class T_Join_and_Split(unittest.TestCase):
for node in f.maker.fgraph.toposort()])
self.assertRaises(ValueError, f)
def test_join_inplace():
    """Test that Join works in place.

    When several tensors are passed to Join but all except one of them are
    empty, Join should work in place: the output should be a view of the
    non-empty tensor.
    """
    s = tensor.lscalar()
    x = tensor.vector('x')
    # The test calls the function with s == 0, which makes z an empty vector.
    z = tensor.zeros((s,))

    # view=0 asks Join to return the first joined tensor (x) itself when
    # every other joined tensor is empty along the join axis.
    join = Join(view=0)
    c = join(0, x, z, z)

    # NOTE(review): borrow=True on both the input and the output is
    # presumably what lets the identity check below hold, since otherwise
    # Theano may copy the array on the way in or out -- confirm against the
    # Theano memory-aliasing documentation.
    f = theano.function([theano.In(x, borrow=True), s],
                        theano.Out(c, borrow=True))

    data = numpy.array([3, 4, 5], dtype=theano.config.floatX)
    # Evaluate once and check both identity and contents on the same result.
    result = f(data, 0)
    assert result is data
    assert numpy.allclose(result, [3, 4, 5])
class test_comparison(unittest.TestCase):
"""Test <, >, <=, >=, == and !=
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论