提交 c5a0b307 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed the test that failed because of the fix to join

上级 6ee3d69f
...@@ -1943,6 +1943,7 @@ class Join(Op): ...@@ -1943,6 +1943,7 @@ class Join(Op):
# for the axis dimension. # for the axis dimension.
# All concatenated elements must also have the same broadcastable # All concatenated elements must also have the same broadcastable
# dimensions. # dimensions.
orig = as_tensor_variable_args
if isinstance(axis, int): if isinstance(axis, int):
bcasts = [x.type.broadcastable[0:axis] + \ bcasts = [x.type.broadcastable[0:axis] + \
x.type.broadcastable[axis + 1:] for x in as_tensor_variable_args] x.type.broadcastable[axis + 1:] for x in as_tensor_variable_args]
...@@ -1964,7 +1965,9 @@ class Join(Op): ...@@ -1964,7 +1965,9 @@ class Join(Op):
outputs = [tensor(dtype = out_dtype, outputs = [tensor(dtype = out_dtype,
broadcastable = bcastable)] broadcastable = bcastable)]
return Apply(self, inputs, outputs) node = Apply(self, inputs, outputs)
node.tag.shape_zero = None if any(not x.type.broadcastable[0] for x in orig) else len(orig)
return node
def perform(self, node, axis_and_tensors, (out, )): def perform(self, node, axis_and_tensors, (out, )):
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:] axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
...@@ -2004,13 +2007,10 @@ class Join(Op): ...@@ -2004,13 +2007,10 @@ class Join(Op):
assert isinstance(node.owner.op, Join) assert isinstance(node.owner.op, Join)
if node.ndim != 1: if node.ndim != 1:
raise TypeError('argument must be symbolic vector') raise TypeError('argument must be symbolic vector')
inputs = node.owner.inputs if node.owner.tag.shape_zero is None:
axis, tensors = inputs[0], inputs[1:] raise ValueError("could not determine vector length")
# if v is a vector, then axis must be 0 else:
# the question is whether all the inputs are broadcastable. return node.owner.tag.shape_zero
if all(i.broadcastable[0] for i in tensors):
return len(tensors)
raise ValueError("could not determine vector length")
@_redefine_asRoutine(Join()) @_redefine_asRoutine(Join())
def join(axis, *tensors): def join(axis, *tensors):
...@@ -2110,7 +2110,7 @@ def get_vector_length(v): ...@@ -2110,7 +2110,7 @@ def get_vector_length(v):
if v.owner and isinstance(v.owner.op, Join): if v.owner and isinstance(v.owner.op, Join):
try: try:
return join.vec_length(v) return join.vec_length(v)
except: except ValueError:
pass pass
if v.owner and v.owner.op == shape: if v.owner and v.owner.op == shape:
return v.owner.inputs[0].type.ndim return v.owner.inputs[0].type.ndim
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论