提交 53987a5b authored 作者: James Bergstra's avatar James Bergstra

get_vector_length of a broadcastable vector is 1

上级 ed03732c
......@@ -4262,6 +4262,8 @@ def get_vector_length(v):
v = as_tensor_variable(v)
if v.ndim != 1:
raise TypeError('argument must be symbolic vector')
if v.type.broadcastable[0]:
return 1
if isinstance(v, gof.Constant) and v.type.ndim == 1:
return len(v.data)
if v.owner and isinstance(v.owner.op, Join):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论