提交 503bc88b authored 作者: Frederic's avatar Frederic

Work around the dtype bug in scipy.sparse.vstack as for hstack()

上级 d7148396
...@@ -2325,6 +2325,10 @@ class VStack(HStack): ...@@ -2325,6 +2325,10 @@ class VStack(HStack):
assert _is_sparse(b) assert _is_sparse(b)
out[0] = scipy.sparse.vstack(block, format=self.format, out[0] = scipy.sparse.vstack(block, format=self.format,
dtype=self.dtype) dtype=self.dtype)
# Some version of scipy (at least 0.14.0.dev-c4314b0)
# Do not cast to the wanted dtype.
if out[0].dtype != self.dtype:
out[0] = out[0].astype(self.dtype)
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
is_continuous = [(inputs[i].dtype in tensor.continuous_dtypes) is_continuous = [(inputs[i].dtype in tensor.continuous_dtypes)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论