提交 6bf0d06d authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Add TensorSharedVariable support to get_vector_length

上级 cb2e0340
...@@ -5002,6 +5002,8 @@ def get_vector_length(v): ...@@ -5002,6 +5002,8 @@ def get_vector_length(v):
raise TypeError("argument must be symbolic vector, got '%s'" % v) raise TypeError("argument must be symbolic vector, got '%s'" % v)
if v.type.broadcastable[0]: if v.type.broadcastable[0]:
return 1 return 1
if isinstance(v, theano.tensor.sharedvar.TensorSharedVariable) and v.type.ndim == 1:
return len(v.get_value())
if isinstance(v, gof.Constant) and v.type.ndim == 1: if isinstance(v, gof.Constant) and v.type.ndim == 1:
return len(v.data) return len(v.data)
if v.owner and isinstance(v.owner.op, theano.tensor.opt.MakeVector): if v.owner and isinstance(v.owner.op, theano.tensor.opt.MakeVector):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论