提交 ad1310c8 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5276 from ReyhaneAskari/fix_5253

Make join work inplace
from __future__ import absolute_import, print_function, division
import theano
from theano.tensor.basic import Join
def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
......@@ -114,10 +115,12 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
# Pad the sequences if needed
if padding:
# Since padding could be an empty tensor, Join returns a view of s.
join = Join(view=0)
for i, s in enumerate(sequences):
n = s.shape[0] % save_every_N
z = theano.tensor.zeros((n, s.shape[1:]), dtype=s.dtype)
sequences[i] = theano.tensor.concatenate([s, z], axis=0)
sequences[i] = join(0, [s, z])
# Establish the input variables of the outer scan
o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] +
......
......@@ -3886,7 +3886,25 @@ class Join(Op):
"""
check_input = False
__props__ = ()
__props__ = ("view",)
def __init__(self, view=-1):
self.view = view
if view != -1:
# since the first input is always the axis, the tensors
# start from index 1.
self.view_map = {0: [1 + view]}
def __str__(self):
if self.view == -1:
return "Join"
else:
return super(Join, self).__str__()
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, "view"):
self.view = -1
def make_node(self, *axis_and_tensors):
"""
......@@ -3998,27 +4016,50 @@ class Join(Op):
def perform(self, node, axis_and_tensors, out_):
out, = out_
view = self.view
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
ndim = tensors[0].ndim
if axis < -ndim:
raise IndexError("Join axis %d out of bounds [0, %d)" %
(axis, ndim))
# we check these tensors for being empty.
if (view != -1) and numpy.all(
[tensor.shape[axis] == 0 for tensor in
tensors[0:view] + tensors[view + 1:]]):
out[0] = tensors[view]
out[0] = theano._asarray(numpy.concatenate(tensors, axis=axis),
dtype=node.outputs[0].type.dtype)
else:
ndim = tensors[0].ndim
if axis < -ndim:
raise IndexError("Join axis %d out of bounds [0, %d)" %
(axis, ndim))
out[0] = theano._asarray(numpy.concatenate(tensors, axis=axis),
dtype=node.outputs[0].type.dtype)
def c_code_cache_version(self):
return (3,)
return (4,)
def c_code(self, node, name, inputs, outputs, sub):
axis, tensors = inputs[0], inputs[1:]
view = self.view
non_empty_tensor = tensors[view]
input_1 = tensors[0]
l = len(tensors)
out, = outputs
fail = sub['fail']
adtype = node.inputs[0].type.dtype_specs()[1]
code = """
PyObject* list = PyList_New(%(l)s);
int axis = ((%(adtype)s *)PyArray_DATA(%(axis)s))[0];
int tensors_lens_sum = 0""" % locals()
for i, inp in enumerate(tensors):
code += """ + PyArray_DIM(%(inp)s, axis) """ % locals()
code += """;\n
tensors_lens_sum -= PyArray_DIM(%(non_empty_tensor)s, axis);
if(%(view)s != -1 && tensors_lens_sum == 0){
Py_XDECREF(%(out)s);
Py_INCREF(%(non_empty_tensor)s);
%(out)s = %(non_empty_tensor)s;
}
else{
PyObject* list = PyList_New(%(l)s);
""" % locals()
for i, inp in enumerate(tensors):
code += """
......@@ -4026,21 +4067,19 @@ class Join(Op):
PyList_SetItem(list, %(i)s, (PyObject*)%(inp)s);
""" % locals()
code += """
//PyObject* PyArray_Concatenate(PyObject* obj, int axis)
int axis = ((%(adtype)s *)PyArray_DATA(%(axis)s))[0];
int ndim = PyArray_NDIM(%(input_1)s);
if( axis < -ndim ){
PyErr_Format(PyExc_IndexError,
"Join axis %%d out of bounds [0, %%d)", axis, ndim);
%(fail)s
}
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)PyArray_Concatenate(list, axis);
Py_DECREF(list);
if(!%(out)s){
%(fail)s
//PyObject* PyArray_Concatenate(PyObject* obj, int axis)
int ndim = PyArray_NDIM(%(input_1)s);
if( axis < -ndim ){
PyErr_Format(PyExc_IndexError,
"Join axis %%d out of bounds [0, %%d)", axis, ndim);
%(fail)s
}
Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)PyArray_Concatenate(list, axis);
Py_DECREF(list);
if(!%(out)s){
%(fail)s
}
}
""" % locals()
return code
......
......@@ -4251,6 +4251,29 @@ class T_Join_and_Split(unittest.TestCase):
for node in f.maker.fgraph.toposort()])
self.assertRaises(ValueError, f)
def test_join_inplace():
"""Test join to work inplace.
This function tests the case when several elements are passed to the
join function but all except one of them are empty. In this case join
should work inplace and the output should be the view of the non-empty
element.
"""
s = tensor.lscalar()
x = tensor.vector('x')
z = tensor.zeros((s,))
join = Join(view=0)
c = join(0, x, z, z)
f = theano.function([theano.In(x, borrow=True), s], theano.Out(c, borrow=True))
data = numpy.array([3, 4, 5], dtype=theano.config.floatX)
print (f(data, 0))
assert f(data, 0) is data
assert numpy.allclose(f(data, 0), [3, 4, 5])
class test_comparison(unittest.TestCase):
"""Test <, >, <=, >=, == and !=
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论