提交 c4c73a6f authored 作者: James Bergstra's avatar James Bergstra

added tensor_from_scalar op

上级 950015e0
...@@ -1335,7 +1335,34 @@ class t_gemm(unittest.TestCase): ...@@ -1335,7 +1335,34 @@ class t_gemm(unittest.TestCase):
return return
self.fail() self.fail()
class T_tensorfromscalar(unittest.TestCase):
def test0(self):
s = scal.constant(56)
t = tensor_from_scalar(s)
self.failUnless(t.owner.__class__ is TensorFromScalar)
self.failUnless(t.broadcastable == (), t.broadcastable)
self.failUnless(t.ndim == 0, t.ndim)
self.failUnless(t.dtype == s.dtype)
v = eval_outputs([t])
self.failUnless(v == 56, v)
self.failUnless(isinstance(v, numpy.ndarray))
self.failUnless(v.shape == (), v.shape)
def test1(self):
s = scal.constant(56)
t = astensor(s)
self.failUnless(t.owner.__class__ is TensorFromScalar)
self.failUnless(t.broadcastable == (), t.broadcastable)
self.failUnless(t.ndim == 0, t.ndim)
self.failUnless(t.dtype == s.dtype)
v = eval_outputs([t])
self.failUnless(v == 56, v)
self.failUnless(isinstance(v, numpy.ndarray))
self.failUnless(v.shape == (), v.shape)
def _tensor(data, broadcastable=None, name=None): def _tensor(data, broadcastable=None, name=None):
......
...@@ -10,9 +10,11 @@ from gof.python25 import all ...@@ -10,9 +10,11 @@ from gof.python25 import all
def astensor(data): def astensor(data):
#This symbol is replaced when we import tensor.py, ask Olivier why.
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise") raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
def Tensor(*inputs, **kwargs): def Tensor(*inputs, **kwargs):
#This symbol is replaced when we import tensor.py, ask Olivier why.
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise") raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
......
...@@ -318,8 +318,11 @@ def astensor(data, broadcastable=None, name=None): ...@@ -318,8 +318,11 @@ def astensor(data, broadcastable=None, name=None):
if name is not None and name != data.name: if name is not None and name != data.name:
raise ValueError("Cannot rename an existing Tensor.") raise ValueError("Cannot rename an existing Tensor.")
return data return data
elif isinstance(data, scal.Scalar):
return tensor_from_scalar(data)
elif isinstance(data, Result): elif isinstance(data, Result):
raise TypeError("Cannot make a Tensor out of a result that is not an instance of Tensor: %s (%s)" % (data, data.__class__.__name__), data) raise TypeError("Cannot make a Tensor out of a Result that is not an instance of Tensor: %s (%s)" % (data, data.__class__.__name__), data)
if data is None and broadcastable is None: if data is None and broadcastable is None:
raise TypeError("Cannot make a Tensor out of None.") raise TypeError("Cannot make a Tensor out of None.")
...@@ -442,6 +445,21 @@ class _Op(Op): ...@@ -442,6 +445,21 @@ class _Op(Op):
return [rval.pop()] * self.nout return [rval.pop()] * self.nout
#########################
# Casting Operations
#########################
class TensorFromScalar(Op):
def __init__(self, s, **kwargs):
assert isinstance(s, scal.Scalar)
Op.__init__(self, **kwargs)
self.inputs = [s]
self.outputs = [Tensor(s.dtype, broadcastable=[])]
def perform(self):
self.outputs[0].data = self.inputs[0].data
def grad(self, (s,), (dt,)):
raise NotImplementedError('todo: ScalarFromTensor')
tensor_from_scalar = gof.op.constructor(TensorFromScalar)
########################## ##########################
# Unary Operations # Unary Operations
...@@ -797,6 +815,7 @@ def horizontal_stack(x, y, **kwargs): ...@@ -797,6 +815,7 @@ def horizontal_stack(x, y, **kwargs):
return transpose(vertical_stack(x.T, y.T, **kwargs)) return transpose(vertical_stack(x.T, y.T, **kwargs))
######################### #########################
# Linalg : Dot # Linalg : Dot
######################### #########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论