提交 c4c73a6f authored 作者: James Bergstra's avatar James Bergstra

added tensor_from_scalar op

上级 950015e0
......@@ -1335,7 +1335,34 @@ class t_gemm(unittest.TestCase):
return
self.fail()
class T_tensorfromscalar(unittest.TestCase):
def test0(self):
s = scal.constant(56)
t = tensor_from_scalar(s)
self.failUnless(t.owner.__class__ is TensorFromScalar)
self.failUnless(t.broadcastable == (), t.broadcastable)
self.failUnless(t.ndim == 0, t.ndim)
self.failUnless(t.dtype == s.dtype)
v = eval_outputs([t])
self.failUnless(v == 56, v)
self.failUnless(isinstance(v, numpy.ndarray))
self.failUnless(v.shape == (), v.shape)
def test1(self):
s = scal.constant(56)
t = astensor(s)
self.failUnless(t.owner.__class__ is TensorFromScalar)
self.failUnless(t.broadcastable == (), t.broadcastable)
self.failUnless(t.ndim == 0, t.ndim)
self.failUnless(t.dtype == s.dtype)
v = eval_outputs([t])
self.failUnless(v == 56, v)
self.failUnless(isinstance(v, numpy.ndarray))
self.failUnless(v.shape == (), v.shape)
def _tensor(data, broadcastable=None, name=None):
......
......@@ -10,9 +10,11 @@ from gof.python25 import all
def astensor(data):
#This symbol is replaced when we import tensor.py, ask Olivier why.
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
def Tensor(*inputs, **kwargs):
#This symbol is replaced when we import tensor.py, ask Olivier why.
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
......
......@@ -318,8 +318,11 @@ def astensor(data, broadcastable=None, name=None):
if name is not None and name != data.name:
raise ValueError("Cannot rename an existing Tensor.")
return data
elif isinstance(data, scal.Scalar):
return tensor_from_scalar(data)
elif isinstance(data, Result):
raise TypeError("Cannot make a Tensor out of a result that is not an instance of Tensor: %s (%s)" % (data, data.__class__.__name__), data)
raise TypeError("Cannot make a Tensor out of a Result that is not an instance of Tensor: %s (%s)" % (data, data.__class__.__name__), data)
if data is None and broadcastable is None:
raise TypeError("Cannot make a Tensor out of None.")
......@@ -442,6 +445,21 @@ class _Op(Op):
return [rval.pop()] * self.nout
#########################
# Casting Operations
#########################
class TensorFromScalar(Op):
def __init__(self, s, **kwargs):
assert isinstance(s, scal.Scalar)
Op.__init__(self, **kwargs)
self.inputs = [s]
self.outputs = [Tensor(s.dtype, broadcastable=[])]
def perform(self):
self.outputs[0].data = self.inputs[0].data
def grad(self, (s,), (dt,)):
raise NotImplementedError('todo: ScalarFromTensor')
tensor_from_scalar = gof.op.constructor(TensorFromScalar)
##########################
# Unary Operations
......@@ -797,6 +815,7 @@ def horizontal_stack(x, y, **kwargs):
return transpose(vertical_stack(x.T, y.T, **kwargs))
#########################
# Linalg : Dot
#########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论