提交 ad1310c8 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5276 from ReyhaneAskari/fix_5253

Make join work inplace
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import theano import theano
from theano.tensor.basic import Join
def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[], def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
...@@ -114,10 +115,12 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[], ...@@ -114,10 +115,12 @@ def scan_checkpoints(fn, sequences=[], outputs_info=None, non_sequences=[],
# Pad the sequences if needed # Pad the sequences if needed
if padding: if padding:
# Since padding could be an empty tensor, Join returns a view of s.
join = Join(view=0)
for i, s in enumerate(sequences): for i, s in enumerate(sequences):
n = s.shape[0] % save_every_N n = s.shape[0] % save_every_N
z = theano.tensor.zeros((n, s.shape[1:]), dtype=s.dtype) z = theano.tensor.zeros((n, s.shape[1:]), dtype=s.dtype)
sequences[i] = theano.tensor.concatenate([s, z], axis=0) sequences[i] = join(0, [s, z])
# Establish the input variables of the outer scan # Establish the input variables of the outer scan
o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] + o_sequences = [s.reshape([s.shape[0] / save_every_N, save_every_N] +
......
...@@ -3886,7 +3886,25 @@ class Join(Op): ...@@ -3886,7 +3886,25 @@ class Join(Op):
""" """
check_input = False check_input = False
__props__ = () __props__ = ("view",)
def __init__(self, view=-1):
self.view = view
if view != -1:
# since the first input is always the axis, the tensors
# start from index 1.
self.view_map = {0: [1 + view]}
def __str__(self):
if self.view == -1:
return "Join"
else:
return super(Join, self).__str__()
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, "view"):
self.view = -1
def make_node(self, *axis_and_tensors): def make_node(self, *axis_and_tensors):
""" """
...@@ -3998,7 +4016,15 @@ class Join(Op): ...@@ -3998,7 +4016,15 @@ class Join(Op):
def perform(self, node, axis_and_tensors, out_): def perform(self, node, axis_and_tensors, out_):
out, = out_ out, = out_
view = self.view
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:] axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
# we check these tensors for being empty.
if (view != -1) and numpy.all(
[tensor.shape[axis] == 0 for tensor in
tensors[0:view] + tensors[view + 1:]]):
out[0] = tensors[view]
else:
ndim = tensors[0].ndim ndim = tensors[0].ndim
if axis < -ndim: if axis < -ndim:
raise IndexError("Join axis %d out of bounds [0, %d)" % raise IndexError("Join axis %d out of bounds [0, %d)" %
...@@ -4008,16 +4034,31 @@ class Join(Op): ...@@ -4008,16 +4034,31 @@ class Join(Op):
dtype=node.outputs[0].type.dtype) dtype=node.outputs[0].type.dtype)
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (4,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
axis, tensors = inputs[0], inputs[1:] axis, tensors = inputs[0], inputs[1:]
view = self.view
non_empty_tensor = tensors[view]
input_1 = tensors[0] input_1 = tensors[0]
l = len(tensors) l = len(tensors)
out, = outputs out, = outputs
fail = sub['fail'] fail = sub['fail']
adtype = node.inputs[0].type.dtype_specs()[1] adtype = node.inputs[0].type.dtype_specs()[1]
code = """ code = """
int axis = ((%(adtype)s *)PyArray_DATA(%(axis)s))[0];
int tensors_lens_sum = 0""" % locals()
for i, inp in enumerate(tensors):
code += """ + PyArray_DIM(%(inp)s, axis) """ % locals()
code += """;\n
tensors_lens_sum -= PyArray_DIM(%(non_empty_tensor)s, axis);
if(%(view)s != -1 && tensors_lens_sum == 0){
Py_XDECREF(%(out)s);
Py_INCREF(%(non_empty_tensor)s);
%(out)s = %(non_empty_tensor)s;
}
else{
PyObject* list = PyList_New(%(l)s); PyObject* list = PyList_New(%(l)s);
""" % locals() """ % locals()
for i, inp in enumerate(tensors): for i, inp in enumerate(tensors):
...@@ -4027,21 +4068,19 @@ class Join(Op): ...@@ -4027,21 +4068,19 @@ class Join(Op):
""" % locals() """ % locals()
code += """ code += """
//PyObject* PyArray_Concatenate(PyObject* obj, int axis) //PyObject* PyArray_Concatenate(PyObject* obj, int axis)
int axis = ((%(adtype)s *)PyArray_DATA(%(axis)s))[0];
int ndim = PyArray_NDIM(%(input_1)s); int ndim = PyArray_NDIM(%(input_1)s);
if( axis < -ndim ){ if( axis < -ndim ){
PyErr_Format(PyExc_IndexError, PyErr_Format(PyExc_IndexError,
"Join axis %%d out of bounds [0, %%d)", axis, ndim); "Join axis %%d out of bounds [0, %%d)", axis, ndim);
%(fail)s %(fail)s
} }
Py_XDECREF(%(out)s); Py_XDECREF(%(out)s);
%(out)s = (PyArrayObject *)PyArray_Concatenate(list, axis); %(out)s = (PyArrayObject *)PyArray_Concatenate(list, axis);
Py_DECREF(list); Py_DECREF(list);
if(!%(out)s){ if(!%(out)s){
%(fail)s %(fail)s
} }
}
""" % locals() """ % locals()
return code return code
......
...@@ -4251,6 +4251,29 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -4251,6 +4251,29 @@ class T_Join_and_Split(unittest.TestCase):
for node in f.maker.fgraph.toposort()]) for node in f.maker.fgraph.toposort()])
self.assertRaises(ValueError, f) self.assertRaises(ValueError, f)
def test_join_inplace():
"""Test join to work inplace.
This function tests the case when several elements are passed to the
join function but all except one of them are empty. In this case join
should work inplace and the output should be the view of the non-empty
element.
"""
s = tensor.lscalar()
x = tensor.vector('x')
z = tensor.zeros((s,))
join = Join(view=0)
c = join(0, x, z, z)
f = theano.function([theano.In(x, borrow=True), s], theano.Out(c, borrow=True))
data = numpy.array([3, 4, 5], dtype=theano.config.floatX)
print (f(data, 0))
assert f(data, 0) is data
assert numpy.allclose(f(data, 0), [3, 4, 5])
class test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
"""Test <, >, <=, >=, == and != """Test <, >, <=, >=, == and !=
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论