提交 4b6d0d75 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix ScalarFromTensor.grad and add test for ScalarFromTensor and TensorFromScalar.grad

上级 298f6be0
...@@ -1181,7 +1181,7 @@ class ScalarFromTensor(Op): ...@@ -1181,7 +1181,7 @@ class ScalarFromTensor(Op):
def perform(self, node, (s, ), (out, )): def perform(self, node, (s, ), (out, )):
out[0] = s.flatten()[0] out[0] = s.flatten()[0]
def grad(self, (s,), (dt,)): def grad(self, (s,), (dt,)):
return [TensorFromScalar(dt)] return [tensor_from_scalar(dt)]
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
scalar_from_tensor = ScalarFromTensor() scalar_from_tensor = ScalarFromTensor()
......
...@@ -1501,6 +1501,29 @@ class T_tensorfromscalar(unittest.TestCase): ...@@ -1501,6 +1501,29 @@ class T_tensorfromscalar(unittest.TestCase):
self.failUnless(isinstance(v, numpy.ndarray)) self.failUnless(isinstance(v, numpy.ndarray))
self.failUnless(v.shape == (), v.shape) self.failUnless(v.shape == (), v.shape)
g = grad(t, s)
self.failUnless(eval_outputs([g])==1)
class T_scalarfromtensor(unittest.TestCase):
def test0(self):
tt = constant(56)#scal.constant(56)
ss = scalar_from_tensor(tt)
self.failUnless(ss.owner.op is scalar_from_tensor)
self.failUnless(ss.type.dtype == tt.type.dtype)
v = eval_outputs([ss])
self.failUnless(v == 56, v)
self.failUnless(isinstance(v, numpy.int8))
self.failUnless(v.shape == (), v.shape)
tt = lscalar()
ss = scalar_from_tensor(tt)
g = ss.owner.op.grad([tt],[ss])
fff=function([tt],ss)
v = fff(numpy.asarray(5))
self.failUnless(v == 5, v)
self.failUnless(isinstance(v, numpy.int64))
self.failUnless(v.shape == (),v.shape)
# def _tensor(data, broadcastable=None, name=None): # def _tensor(data, broadcastable=None, name=None):
# """Return a TensorType containing given data""" # """Return a TensorType containing given data"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论