提交 03329f62 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix tensor.Eye to return the dtype that make_node told it would return.

上级 cd6af578
...@@ -1761,7 +1761,7 @@ class Eye(gof.Op): ...@@ -1761,7 +1761,7 @@ class Eye(gof.Op):
return gof.Apply(self, [n,m,k], [TensorType(dtype = self.dtype, broadcastable = (False,False))()]) return gof.Apply(self, [n,m,k], [TensorType(dtype = self.dtype, broadcastable = (False,False))()])
def perform(self, node, (n,m,k), (out,)): def perform(self, node, (n,m,k), (out,)):
out[0] = numpy.eye(n,m,k) out[0] = numpy.eye(n,m,k,dtype=self.dtype)
def grad(self, (n,m,k),(gout,)): def grad(self, (n,m,k),(gout,)):
return [None, None, None] return [None, None, None]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论