提交 2c44ac61 authored 作者: Reyhane Askari's avatar Reyhane Askari

Made perform and c_code of Join work in place, and added a test for it.

上级 1c07fecd
...@@ -3888,10 +3888,12 @@ class Join(Op): ...@@ -3888,10 +3888,12 @@ class Join(Op):
check_input = False check_input = False
__props__ = ("view",) __props__ = ("view",)
def __init__(self, view=False): def __init__(self, view=-1):
self.view = view self.view = view
if view: if view != -1:
self.view_map = {0: [0]} # since the first input is always the axis, the tensors
# start from index 1.
self.view_map = {0: [1 + view]}
def make_node(self, *axis_and_tensors): def make_node(self, *axis_and_tensors):
""" """
...@@ -4003,12 +4005,13 @@ class Join(Op): ...@@ -4003,12 +4005,13 @@ class Join(Op):
def perform(self, node, axis_and_tensors, out_): def perform(self, node, axis_and_tensors, out_):
out, = out_ out, = out_
view = self.view
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:] axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
# tailing tensors are all tensors except the first one # we check these tensors for being empty.
tailing_tensors_are_empty = numpy.all( if (view != -1) and numpy.all(
[tensor.shape[axis] == 0 for tensor in axis_and_tensors[2:]]) [tensor.shape[axis] == 0 for tensor in
if tailing_tensors_are_empty: tensors[0:view] + tensors[view + 1:]]):
out[0] = tensors[0] out[0] = tensors[view]
else: else:
ndim = tensors[0].ndim ndim = tensors[0].ndim
...@@ -4020,18 +4023,32 @@ class Join(Op): ...@@ -4020,18 +4023,32 @@ class Join(Op):
dtype=node.outputs[0].type.dtype) dtype=node.outputs[0].type.dtype)
def c_code_cache_version(self): def c_code_cache_version(self):
return return (4,)
return (3,)
def c_code_(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
axis, tensors = inputs[0], inputs[1:] axis, tensors = inputs[0], inputs[1:]
view = self.view
non_empty_tensor = tensors[view]
input_1 = tensors[0] input_1 = tensors[0]
l = len(tensors) l = len(tensors)
out, = outputs out, = outputs
fail = sub['fail'] fail = sub['fail']
adtype = node.inputs[0].type.dtype_specs()[1] adtype = node.inputs[0].type.dtype_specs()[1]
code = """ code = """
PyObject* list = PyList_New(%(l)s); int axis = ((%(adtype)s *)PyArray_DATA(%(axis)s))[0];
int tensors_lens_sum = 0""" % locals()
for i, inp in enumerate(tensors):
code += """ + PyArray_DIM(%(inp)s, axis) """ % locals()
code += """;\n
tensors_lens_sum -= PyArray_DIM(%(non_empty_tensor)s, axis);
if(%(view)s != -1 && tensors_lens_sum == 0){
Py_XDECREF(%(out)s);
Py_INCREF(%(non_empty_tensor)s);
%(out)s = %(non_empty_tensor)s;
}
else{
PyObject* list = PyList_New(%(l)s);
""" % locals() """ % locals()
for i, inp in enumerate(tensors): for i, inp in enumerate(tensors):
code += """ code += """
...@@ -4039,21 +4056,19 @@ class Join(Op): ...@@ -4039,21 +4056,19 @@ class Join(Op):
PyList_SetItem(list, %(i)s, (PyObject*)%(inp)s); PyList_SetItem(list, %(i)s, (PyObject*)%(inp)s);
""" % locals() """ % locals()
code += """ code += """
//PyObject* PyArray_Concatenate(PyObject* obj, int axis) //PyObject* PyArray_Concatenate(PyObject* obj, int axis)
int axis = ((%(adtype)s *)PyArray_DATA(%(axis)s))[0]; int ndim = PyArray_NDIM(%(input_1)s);
int ndim = PyArray_NDIM(%(input_1)s); if( axis < -ndim ){
if( axis < -ndim ){ PyErr_Format(PyExc_IndexError,
PyErr_Format(PyExc_IndexError, "Join axis %%d out of bounds [0, %%d)", axis, ndim);
"Join axis %%d out of bounds [0, %%d)", axis, ndim); %(fail)s
%(fail)s }
} Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)PyArray_Concatenate(list, axis);
Py_XDECREF(%(out)s); (list);
%(out)s = (PyArrayObject *)PyArray_Concatenate(list, axis); if(!%(out)s){
%(fail)s
Py_DECREF(list); }
if(!%(out)s){
%(fail)s
} }
""" % locals() """ % locals()
return code return code
......
...@@ -4251,6 +4251,29 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -4251,6 +4251,29 @@ class T_Join_and_Split(unittest.TestCase):
for node in f.maker.fgraph.toposort()]) for node in f.maker.fgraph.toposort()])
self.assertRaises(ValueError, f) self.assertRaises(ValueError, f)
def test_join_inplace():
    """Test that Join works in place.

    Several tensors are passed to Join, but all except one of them are
    empty along the join axis. In that case Join should work in place and
    the output should be a view of the non-empty element.
    """
    s = tensor.lscalar()
    x = tensor.vector('x')
    # z has length s, so calling the function with s == 0 makes it empty.
    z = tensor.zeros((s,))

    # view=0 designates the first joined tensor (x) as the one the output
    # may alias.
    join = Join(view=0)
    c = join(0, x, z, z)

    # NOTE(review): borrow=True on both the input and the output is presumably
    # what lets the caller observe the aliasing -- confirm against the
    # theano.In / theano.Out documentation.
    f = theano.function([theano.In(x, borrow=True), s],
                        theano.Out(c, borrow=True))

    data = numpy.array([3, 4, 5], dtype=theano.config.floatX)
    # Evaluate once and reuse the result for both checks.
    result = f(data, 0)

    # Identity, not just equality: the output must be the very input array.
    assert result is data
    assert numpy.allclose(result, [3, 4, 5])
class test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
"""Test <, >, <=, >=, == and != """Test <, >, <=, >=, == and !=
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论