提交 3ec221fe authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Minor modifications to issue #2613.

上级 c58e0892
......@@ -3044,21 +3044,25 @@ class GpuJoin(tensor.Join, GpuOp):
as_tensor_variable_args = [as_cuda_ndarray_variable(x)
for x in tensors]
# If axis is negative, we need to increment it with the number
# of dimensions of the smallest tensor variable
axis_int = int(axis.eval())
# Get joining axis as int
axis_int = 0
if not isinstance(axis, int):
try:
# Note : `get_scalar_constant_value` returns a ndarray not
# an int
axis_int = int(tensor.get_scalar_constant_value(axis))
except tensor.basic.NotScalarConstantError:
pass
else:
axis_int = axis
if (axis_int < 0):
# Find tensor with smallest number of dimensions
min_dim = -1
for cnda in tensors:
if min_dim < 0 or min_dim > len(list(cnda.shape)):
min_dim = len(list(cnda.shape))
# Throw error if it's not safe to increment axis with the
# minimum dim. Normally, this error will be caught in the
# Join op class, but just to be sure we double check it here.
if axis_int + min_dim < 0:
raise ValueError("Cannot join list of tensors at axis (%s) when tensor with smallest dim (%s) is smaller." % (axis_int, min_dim))
if min_dim < 0 or min_dim > cnda.ndim:
min_dim = cnda.ndim
axis = axis + min_dim
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论