提交 6bf0d06d authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Add TensorSharedVariable support to get_vector_length

上级 cb2e0340
......@@ -5002,6 +5002,8 @@ def get_vector_length(v):
raise TypeError("argument must be symbolic vector, got '%s'" % v)
if v.type.broadcastable[0]:
return 1
if isinstance(v, theano.tensor.sharedvar.TensorSharedVariable) and v.type.ndim == 1:
return len(v.get_value())
if isinstance(v, gof.Constant) and v.type.ndim == 1:
return len(v.data)
if v.owner and isinstance(v.owner.op, theano.tensor.opt.MakeVector):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论