提交 304a40a9 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

New Op to concatenate tensors

上级 a8242ddf
......@@ -936,6 +936,85 @@ class MakeVector(Op):
make_lvector = MakeVector(lscalar)
class Concatenate(Op):
"""
Concatenate two L{Tensor}s along the given axis.
These L{Tensor}s must have the same shape along all dimensions other than
this axis.
"""
def make_node(self, *axis_and_tensors):
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
as_tensor_args= [as_tensor(x) for x in tensors]
dtypes = [x.type.dtype for x in as_tensor_args]
if not all([dtypes[0] == dt for dt in dtypes[1:]]):
# Note that we could automatically find out the appropriate dtype
# able to store the concatenation of all tensors, but for now we
# just raise an error.
raise TypeError('All dtypes must match', tensors)
# When the axis may vary, no dimension can be guaranteed to be
# broadcastable.
bcastable = [False] * len(as_tensor_args[0].type.broadcastable)
# When the axis is fixed, the broadcastable dimensions remain, except
# for the axis dimension.
# All concatenated elements must also have the same broadcastable
# dimensions.
if isinstance(axis, int):
bcasts = [x.type.broadcastable[0:axis] + \
x.type.broadcastable[axis + 1:] for x in as_tensor_args]
if not all([bcasts[0] == bc for bc in bcasts[1:]]):
raise ValueError('Dimensions other than the given axis must'
' match', tensors)
bcastable[:] = as_tensor_args[0].type.broadcastable
bcastable[axis] = False
inputs = [scal.as_scalar(axis)] + as_tensor_args
outputs = [tensor(dtype = dtypes[0],
broadcastable = bcastable)]
return Apply(self, inputs, outputs)
def perform(self, node, axis_and_tensors, (out, )):
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
out[0] = numpy.concatenate(tensors, axis = axis)
def grad(self, axis_and_tensors, (gz,)):
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
n_dims = len(shape(tensors[0]))
sizes_along_axis = [shape(x)[axis] for x in tensors]
idx = [0]
for s in sizes_along_axis:
idx.append(idx[-1] + s)
# The gradient w.r.t. the k-th tensor is a slice of gz along the
# 'axis' dimension.
return [gz[[slice(None)] * axis + [slice(idx[k], idx[k + 1])] + \
[slice(None)] * (n_dims - axis - 1)] \
for k in range(len(sizes_along_axis))]
def concatenate(tensors, axis=0):
"""
Convenience function to concatenate `Tensor`s along the given axis.
The `axis` parameter may either be an integer or an object that can be
converted to a scalar using `as_scalar`(`axis`). In the former case,
the axis is fixed at construction, while in the latter it may vary over
time depending on the value of the `axis` variable.
"""
# Check someone did not make the common mistake to do something like:
# c = concatenate(x, y)
# instead of
# c = concatenate((x, y))
if not isinstance(tensors, (tuple, list)):
raise TypeError("The 'tensors' argument must be either a tuple "
"or a list, make sure you did not forget () or [] around "
"arguments of concatenate.", tensors)
# Ensure we only create one instance of 'Concatenate', to simplify the
# merging job.
if not hasattr(concatenate, 'obj'):
concatenate.obj = Concatenate()
return concatenate.obj(axis, *tensors)
class VerticalStack(Op):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论