提交 efe63023 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make as_scalar() work with tensors that have ndim == 0.

上级 a9acfba5
...@@ -82,6 +82,7 @@ get_scalar_type.cache = {} ...@@ -82,6 +82,7 @@ get_scalar_type.cache = {}
def as_scalar(x, name=None): def as_scalar(x, name=None):
from ..tensor import TensorType, scalar_from_tensor
if isinstance(x, gof.Apply): if isinstance(x, gof.Apply):
if len(x.outputs) != 1: if len(x.outputs) != 1:
raise ValueError("It is ambiguous which output of a multi-output" raise ValueError("It is ambiguous which output of a multi-output"
...@@ -89,9 +90,12 @@ def as_scalar(x, name=None): ...@@ -89,9 +90,12 @@ def as_scalar(x, name=None):
else: else:
x = x.outputs[0] x = x.outputs[0]
if isinstance(x, Variable): if isinstance(x, Variable):
if not isinstance(x.type, Scalar): if isinstance(x.type, Scalar):
return x
elif isinstance(x.type, TensorType) and x.ndim == 0:
return scalar_from_tensor(x)
else:
raise TypeError("Variable type field must be a Scalar.", x, x.type) raise TypeError("Variable type field must be a Scalar.", x, x.type)
return x
try: try:
return constant(x) return constant(x)
except TypeError: except TypeError:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论