提交 c5a0b307 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed the test that failed because of the fix to join

上级 6ee3d69f
......@@ -1943,6 +1943,7 @@ class Join(Op):
# for the axis dimension.
# All concatenated elements must also have the same broadcastable
# dimensions.
orig = as_tensor_variable_args
if isinstance(axis, int):
bcasts = [x.type.broadcastable[0:axis] + \
x.type.broadcastable[axis + 1:] for x in as_tensor_variable_args]
......@@ -1964,7 +1965,9 @@ class Join(Op):
outputs = [tensor(dtype = out_dtype,
broadcastable = bcastable)]
return Apply(self, inputs, outputs)
node = Apply(self, inputs, outputs)
node.tag.shape_zero = None if any(not x.type.broadcastable[0] for x in orig) else len(orig)
return node
def perform(self, node, axis_and_tensors, (out, )):
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
......@@ -2004,13 +2007,10 @@ class Join(Op):
assert isinstance(node.owner.op, Join)
if node.ndim != 1:
raise TypeError('argument must be symbolic vector')
inputs = node.owner.inputs
axis, tensors = inputs[0], inputs[1:]
# if v is a vector, then axis must be 0
# the question is whether all the inputs are broadcastable.
if all(i.broadcastable[0] for i in tensors):
return len(tensors)
raise ValueError("could not determine vector length")
if node.owner.tag.shape_zero is None:
raise ValueError("could not determine vector length")
else:
return node.owner.tag.shape_zero
@_redefine_asRoutine(Join())
def join(axis, *tensors):
......@@ -2110,7 +2110,7 @@ def get_vector_length(v):
if v.owner and isinstance(v.owner.op, Join):
try:
return join.vec_length(v)
except:
except ValueError:
pass
if v.owner and v.owner.op == shape:
return v.owner.inputs[0].type.ndim
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论