提交 0786f2af authored 作者: james@X40's avatar james@X40

Added doc to Join, and fixed the mixed-input-type bug.

上级 39916a7e
......@@ -1574,21 +1574,25 @@ class Join(Op):
def make_node(self, *axis_and_tensors):
"""
WRITEME
:param axis: an Int or integer-valued Result
:param tensors: a variable number (but not zero) of tensors to concatenate along the
specified axis. These tensors must have the same shape along all dimensions other than this axis.
:returns: a symbolic Result. It has the same ndim as the input tensors, and the most
inclusive dtype.
"""
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
if not tensors:
raise ValueError('Cannot join an empty list of tensors')
as_tensor_args= [as_tensor(x) for x in tensors]
dtypes = [x.type.dtype for x in as_tensor_args]
out_dtype = scal.upcast(*dtypes)
if not all(targs.type.ndim for targs in as_tensor_args):
raise TypeError('Join cannot handle arguments of dimension 0. For joining scalar values, see @stack');
if not all([dtypes[0] == dt for dt in dtypes[1:]]):
# Note that we could automatically find out the appropriate dtype
# able to store the concatenation of all tensors, but for now we
# just raise an error.
raise TypeError('All dtypes must match', tensors)
# When the axis may vary, no dimension can be guaranteed to be
# broadcastable.
bcastable = [False] * len(as_tensor_args[0].type.broadcastable)
......@@ -1613,16 +1617,14 @@ class Join(Op):
if inputs[0].type not in int_types:
raise TypeError('Axis could not be cast to an integer type', axis, inputs[0].type, int_types)
outputs = [tensor(dtype = dtypes[0],
outputs = [tensor(dtype = out_dtype,
broadcastable = bcastable)]
return Apply(self, inputs, outputs)
def perform(self, node, axis_and_tensors, (out, )):
"""
WRITEME
"""
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
out[0] = numpy.concatenate(tensors, axis = axis)
out[0] = numpy.asarray(numpy.concatenate(tensors, axis = axis),
dtype=node.outputs[0].type.dtype)
def grad(self, axis_and_tensors, (gz,)):
""" The gradient wrt a join op is a `Split`, used to partition the gradient along the
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论