提交 d22c9272 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Made an error message more explicit, and implemented MakeVector.grad

上级 c5412870
...@@ -317,8 +317,10 @@ class MakeVector(T.Op): ...@@ -317,8 +317,10 @@ class MakeVector(T.Op):
inputs = map(T.as_tensor_variable, inputs) inputs = map(T.as_tensor_variable, inputs)
if not all(a.type == inputs[0].type for a in inputs) or (len(inputs)>0 and inputs[0].dtype != self.dtype): if not all(a.type == inputs[0].type for a in inputs) or (len(inputs)>0 and inputs[0].dtype != self.dtype):
dtype=theano.scalar.upcast(self.dtype,*[i.dtype for i in inputs]) dtype=theano.scalar.upcast(self.dtype,*[i.dtype for i in inputs])
#upcast the input to the determined dtype, but don't upcast downcast anything #upcast the input to the determined dtype, but don't downcast anything
assert dtype==self.dtype, "Upcast the input of MakeVector to dtype gived in init without precissino loss only." assert dtype==self.dtype, (
"The upcast of the inputs to MakeVector should match the "
"dtype given in __init__.")
if not all(self.dtype == T.cast(i,dtype=dtype).dtype for a in inputs): if not all(self.dtype == T.cast(i,dtype=dtype).dtype for a in inputs):
raise TypeError("MakeVector.make_node expected inputs upcastable to %s. got %s"%( raise TypeError("MakeVector.make_node expected inputs upcastable to %s. got %s"%(
self.dtype, self.dtype,
...@@ -348,6 +350,9 @@ class MakeVector(T.Op): ...@@ -348,6 +350,9 @@ class MakeVector(T.Op):
# assume that out has correct dtype. there is no cheap way to check # assume that out has correct dtype. there is no cheap way to check
out[0][...] = inputs out[0][...] = inputs
def grad(self, inputs, output_gradients):
return [output_gradients[0][i] for i in xrange(len(inputs))]
make_vector = MakeVector() make_vector = MakeVector()
class MakeVectorPrinter: class MakeVectorPrinter:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论