提交 1e01cf23 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

improvements to as_scalar and astensor

上级 84d1859f
......@@ -11,6 +11,11 @@ from gof import Result, GuardedOp, Env, utils
def as_scalar(x, name = None):
if isinstance(x, gof.Op):
if len(x.outputs) != 1:
raise ValueError("It is ambiguous which output of a multi-output Op has to be fetched.", x)
else:
x = x.outputs[0]
if isinstance(x, float):
s = Scalar('float64', name = name)
s.data = x
......@@ -195,7 +200,7 @@ def _multi(*fns):
else:
return [f(name) for name in names]
if len(fns) == 1:
return partial(f2, fns)
return partial(f2, fns[0])
else:
return [partial(f2, f) for f in fns]
......
......@@ -304,6 +304,12 @@ s2t.Tensor = Tensor
# alternate Tensor constructor
def astensor(data, broadcastable=None, name=None):
"""Return a L{Tensor} containing given data"""
if isinstance(data, Op):
if len(data.outputs) != 1:
raise ValueError("It is ambiguous which output of a multi-output Op has to be fetched.", data)
else:
data = data.outputs[0]
if isinstance(data, Tensor):
if broadcastable is not None and list(data.broadcastable) != list(broadcastable):
raise TypeError("The data to wrap as a Tensor has the wrong broadcastable pattern. Expected %s, got %s." % (broadcastable, data.broadcastable))
......@@ -325,7 +331,7 @@ def astensor(data, broadcastable=None, name=None):
try:
rval = Tensor(data.dtype, broadcastable, name = name)
except TypeError:
raise TypeError("Cannot convert %s to Tensor." % _data)
raise TypeError("Cannot convert %s to Tensor." % repr(_data))
rval.data = data # will raise if broadcastable was mis-specified
return rval
s2t.astensor = astensor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论