提交 575c7a97 authored 作者: Cesar Laurent's avatar Cesar Laurent

Global change to new stack interface.

上级 42dedfe7
...@@ -828,7 +828,7 @@ def local_gpu_careduce(node): ...@@ -828,7 +828,7 @@ def local_gpu_careduce(node):
new_in_shp.append(x_shape[i]) new_in_shp.append(x_shape[i])
new_greduce = GpuCAReduce(new_mask, scalar_op) new_greduce = GpuCAReduce(new_mask, scalar_op)
reshaped_x = x.reshape(tensor.stack(*new_in_shp)) reshaped_x = x.reshape(tensor.stack(new_in_shp))
gpu_reshaped_x = as_cuda_ndarray_variable(reshaped_x) gpu_reshaped_x = as_cuda_ndarray_variable(reshaped_x)
reshaped_gpu_inputs = [gpu_reshaped_x] reshaped_gpu_inputs = [gpu_reshaped_x]
if new_greduce.supports_c_code(reshaped_gpu_inputs): if new_greduce.supports_c_code(reshaped_gpu_inputs):
...@@ -837,7 +837,7 @@ def local_gpu_careduce(node): ...@@ -837,7 +837,7 @@ def local_gpu_careduce(node):
if reduce_reshaped_x.ndim != out.ndim: if reduce_reshaped_x.ndim != out.ndim:
rval = reduce_reshaped_x.reshape( rval = reduce_reshaped_x.reshape(
tensor.stack(*shape_of[out])) tensor.stack(shape_of[out]))
else: else:
rval = reduce_reshaped_x rval = reduce_reshaped_x
else: else:
......
...@@ -572,7 +572,7 @@ def local_gpua_careduce(node): ...@@ -572,7 +572,7 @@ def local_gpua_careduce(node):
dtype=getattr(node.op, 'dtype', None), dtype=getattr(node.op, 'dtype', None),
acc_dtype=getattr(node.op, 'acc_dtype', None)) acc_dtype=getattr(node.op, 'acc_dtype', None))
reshaped_x = x.reshape(tensor.stack(*new_in_shp)) reshaped_x = x.reshape(tensor.stack(new_in_shp))
gpu_reshaped_x = gpu_from_host(reshaped_x) gpu_reshaped_x = gpu_from_host(reshaped_x)
gvar = greduce(gpu_reshaped_x) gvar = greduce(gpu_reshaped_x)
# We need to have the make node called, otherwise the mask can # We need to have the make node called, otherwise the mask can
...@@ -584,7 +584,7 @@ def local_gpua_careduce(node): ...@@ -584,7 +584,7 @@ def local_gpua_careduce(node):
if reduce_reshaped_x.ndim != node.outputs[0].ndim: if reduce_reshaped_x.ndim != node.outputs[0].ndim:
unreshaped_reduce = reduce_reshaped_x.reshape( unreshaped_reduce = reduce_reshaped_x.reshape(
tensor.stack(*shape_of[node.outputs[0]])) tensor.stack(shape_of[node.outputs[0]]))
else: else:
unreshaped_reduce = reduce_reshaped_x unreshaped_reduce = reduce_reshaped_x
return [unreshaped_reduce] return [unreshaped_reduce]
......
...@@ -3013,8 +3013,8 @@ class HStack(gof.op.Op): ...@@ -3013,8 +3013,8 @@ class HStack(gof.op.Op):
split = tensor.Split(len(inputs))(gz, 1, split = tensor.Split(len(inputs))(gz, 1,
tensor.stack( tensor.stack(
*[x.shape[1] [x.shape[1]
for x in inputs])) for x in inputs]))
if not isinstance(split, list): if not isinstance(split, list):
split = [split] split = [split]
...@@ -3094,8 +3094,8 @@ class VStack(HStack): ...@@ -3094,8 +3094,8 @@ class VStack(HStack):
split = tensor.Split(len(inputs))(gz, 0, split = tensor.Split(len(inputs))(gz, 0,
tensor.stack( tensor.stack(
*[x.shape[0] [x.shape[0]
for x in inputs])) for x in inputs]))
if not isinstance(split, list): if not isinstance(split, list):
split = [split] split = [split]
......
...@@ -185,7 +185,7 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -185,7 +185,7 @@ def as_tensor_variable(x, name=None, ndim=None):
if isinstance(x, (tuple, list)) and python_any(isinstance(xi, Variable) if isinstance(x, (tuple, list)) and python_any(isinstance(xi, Variable)
for xi in x): for xi in x):
try: try:
return stack(*x) return stack(x)
except (TypeError, ValueError): except (TypeError, ValueError):
pass pass
...@@ -1672,7 +1672,7 @@ def smallest(*args): ...@@ -1672,7 +1672,7 @@ def smallest(*args):
a, b = args a, b = args
return switch(a < b, a, b) return switch(a < b, a, b)
else: else:
return min(stack(*args), axis=0) return min(stack(args), axis=0)
@constructor @constructor
...@@ -1687,7 +1687,7 @@ def largest(*args): ...@@ -1687,7 +1687,7 @@ def largest(*args):
a, b = args a, b = args
return switch(a > b, a, b) return switch(a > b, a, b)
else: else:
return max(stack(*args), axis=0) return max(stack(args), axis=0)
########################## ##########################
...@@ -3806,8 +3806,8 @@ class Join(Op): ...@@ -3806,8 +3806,8 @@ class Join(Op):
if 'float' in out_dtype or 'complex' in out_dtype: if 'float' in out_dtype or 'complex' in out_dtype:
# assume that this is differentiable # assume that this is differentiable
split = Split(len(tensors)) split = Split(len(tensors))
split_gz = split(gz, axis, stack(*[shape(x)[axis] split_gz = split(gz, axis, stack([shape(x)[axis]
for x in tensors])) for x in tensors]))
# If there is only one split, it might not be in a list. # If there is only one split, it might not be in a list.
if not isinstance(split_gz, list): if not isinstance(split_gz, list):
split_gz = [split_gz] split_gz = [split_gz]
...@@ -5691,7 +5691,7 @@ def stacklists(arg): ...@@ -5691,7 +5691,7 @@ def stacklists(arg):
""" """
if isinstance(arg, (tuple, list)): if isinstance(arg, (tuple, list)):
return stack(*list(map(stacklists, arg))) return stack(list(map(stacklists, arg)))
else: else:
return arg return arg
......
...@@ -83,7 +83,7 @@ class Fourier(gof.Op): ...@@ -83,7 +83,7 @@ class Fourier(gof.Op):
list(shape_a[axis.data + 1:])) list(shape_a[axis.data + 1:]))
else: else:
l = len(shape_a) l = len(shape_a)
shape_a = tensor.stack(*shape_a) shape_a = tensor.stack(shape_a)
out_shape = tensor.concatenate((shape_a[0: axis], [n], out_shape = tensor.concatenate((shape_a[0: axis], [n],
shape_a[axis + 1:])) shape_a[axis + 1:]))
n_splits = [1] * l n_splits = [1] * l
......
...@@ -365,7 +365,7 @@ def _infer_ndim_bcast(ndim, shape, *args): ...@@ -365,7 +365,7 @@ def _infer_ndim_bcast(ndim, shape, *args):
if len(pre_v_shape) == 0: if len(pre_v_shape) == 0:
v_shape = tensor.constant([], dtype='int32') v_shape = tensor.constant([], dtype='int32')
else: else:
v_shape = tensor.stack(*pre_v_shape) v_shape = tensor.stack(pre_v_shape)
elif shape is None: elif shape is None:
# The number of drawn samples will be determined automatically, # The number of drawn samples will be determined automatically,
......
...@@ -3465,6 +3465,7 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3465,6 +3465,7 @@ class T_Join_and_Split(unittest.TestCase):
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
s = stack([a, b]) s = stack([a, b])
s = stack([a, b], 1) s = stack([a, b], 1)
s = stack([a, b], axis=1)
s = stack(tensors=[a, b]) s = stack(tensors=[a, b])
s = stack(tensors=[a, b], axis=1) s = stack(tensors=[a, b], axis=1)
assert not w assert not w
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论