提交 edcbac8e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up Join.make_node

上级 59f09d09
...@@ -2240,47 +2240,32 @@ class Join(COp): ...@@ -2240,47 +2240,32 @@ class Join(COp):
if not hasattr(self, "view"): if not hasattr(self, "view"):
self.view = -1 self.view = -1
def make_node(self, *axis_and_tensors): def make_node(self, axis, *tensors):
""" """
Parameters Parameters
---------- ----------
axis: an Int or integer-valued Variable axis
The axis upon which to join `tensors`.
tensors tensors
A variable number (but not zero) of tensors to A variable number of tensors to join along the specified axis.
concatenate along the specified axis. These tensors must have These tensors must have the same shape along all dimensions other
the same shape along all dimensions other than this axis. than `axis`.
Returns
-------
A symbolic Variable
It has the same ndim as the input tensors, and the most inclusive
dtype.
""" """
axis, tens = axis_and_tensors[0], axis_and_tensors[1:] if not tensors:
if not tens:
raise ValueError("Cannot join an empty list of tensors") raise ValueError("Cannot join an empty list of tensors")
as_tensor_variable_args = [as_tensor_variable(x) for x in tens]
dtypes = [x.type.dtype for x in as_tensor_variable_args]
out_dtype = aes.upcast(*dtypes)
def output_maker(bcastable):
return tensor(dtype=out_dtype, broadcastable=bcastable)
return self._make_node_internal( tensors = [as_tensor_variable(x) for x in tensors]
axis, tens, as_tensor_variable_args, output_maker out_dtype = aes.upcast(*[x.type.dtype for x in tensors])
)
def _make_node_internal(self, axis, tens, as_tensor_variable_args, output_maker): if not builtins.all(targs.type.ndim for targs in tensors):
if not builtins.all(targs.type.ndim for targs in as_tensor_variable_args):
raise TypeError( raise TypeError(
"Join cannot handle arguments of dimension 0." "Join cannot handle arguments of dimension 0."
" For joining scalar values, see @stack" " Use `stack` to join scalar values."
) )
# Handle single-tensor joins immediately. # Handle single-tensor joins immediately.
if len(as_tensor_variable_args) == 1: if len(tensors) == 1:
bcastable = list(as_tensor_variable_args[0].type.broadcastable) bcastable = list(tensors[0].type.broadcastable)
else: else:
# When the axis is fixed, a dimension should be # When the axis is fixed, a dimension should be
# broadcastable if at least one of the inputs is # broadcastable if at least one of the inputs is
...@@ -2288,17 +2273,15 @@ class Join(COp): ...@@ -2288,17 +2273,15 @@ class Join(COp):
# except for the axis dimension. # except for the axis dimension.
# Initialize bcastable all false, and then fill in some trues with # Initialize bcastable all false, and then fill in some trues with
# the loops. # the loops.
bcastable = [False] * len(as_tensor_variable_args[0].type.broadcastable) bcastable = [False] * len(tensors[0].type.broadcastable)
ndim = len(bcastable) ndim = len(bcastable)
# Axis can also be a constant
if not isinstance(axis, int): if not isinstance(axis, int):
try: try:
# Note : `get_scalar_constant_value` returns a ndarray not
# an int
axis = int(get_scalar_constant_value(axis)) axis = int(get_scalar_constant_value(axis))
except NotScalarConstantError: except NotScalarConstantError:
pass pass
if isinstance(axis, int): if isinstance(axis, int):
# Basically, broadcastable -> length 1, but the # Basically, broadcastable -> length 1, but the
# converse does not hold. So we permit e.g. T/F/T # converse does not hold. So we permit e.g. T/F/T
...@@ -2310,12 +2293,12 @@ class Join(COp): ...@@ -2310,12 +2293,12 @@ class Join(COp):
if axis < -ndim: if axis < -ndim:
raise IndexError( raise IndexError(
f"Join axis {int(axis)} out of bounds [0, {int(ndim)})" f"Axis value {axis} is out of range for the given input dimensions"
) )
if axis < 0: if axis < 0:
axis += ndim axis += ndim
for x in as_tensor_variable_args: for x in tensors:
for current_axis, bflag in enumerate(x.type.broadcastable): for current_axis, bflag in enumerate(x.type.broadcastable):
# Constant negative axis can no longer be negative at # Constant negative axis can no longer be negative at
# this point. It safe to compare this way. # this point. It safe to compare this way.
...@@ -2327,34 +2310,24 @@ class Join(COp): ...@@ -2327,34 +2310,24 @@ class Join(COp):
bcastable[axis] = False bcastable[axis] = False
except IndexError: except IndexError:
raise ValueError( raise ValueError(
'Join argument "axis" is out of range' f"Axis value {axis} is out of range for the given input dimensions"
" (given input dimensions)"
) )
else: else:
# When the axis may vary, no dimension can be guaranteed to be # When the axis may vary, no dimension can be guaranteed to be
# broadcastable. # broadcastable.
bcastable = [False] * len(as_tensor_variable_args[0].type.broadcastable) bcastable = [False] * len(tensors[0].type.broadcastable)
if not builtins.all( if not builtins.all([x.ndim == len(bcastable) for x in tensors]):
[x.ndim == len(bcastable) for x in as_tensor_variable_args[1:]]
):
raise TypeError( raise TypeError(
"Join() can only join tensors with the same " "number of dimensions." "Only tensors with the same number of dimensions can be joined"
) )
inputs = [as_tensor_variable(axis)] + list(as_tensor_variable_args) inputs = [as_tensor_variable(axis)] + list(tensors)
if inputs[0].type not in int_types:
raise TypeError(
"Axis could not be cast to an integer type",
axis,
inputs[0].type,
int_types,
)
outputs = [output_maker(bcastable)] if inputs[0].type.dtype not in int_dtypes:
raise TypeError(f"Axis value {inputs[0]} must be an integer type")
node = Apply(self, inputs, outputs) return Apply(self, inputs, [tensor(dtype=out_dtype, broadcastable=bcastable)])
return node
def perform(self, node, axis_and_tensors, out_): def perform(self, node, axis_and_tensors, out_):
(out,) = out_ (out,) = out_
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论