提交 d7148396 authored 作者: Frederic's avatar Frederic

Work around a bug in scipy.sparse.hstack

上级 468fa2b9
...@@ -2246,6 +2246,10 @@ class HStack(gof.op.Op): ...@@ -2246,6 +2246,10 @@ class HStack(gof.op.Op):
assert _is_sparse(b) assert _is_sparse(b)
out[0] = scipy.sparse.hstack(block, format=self.format, out[0] = scipy.sparse.hstack(block, format=self.format,
dtype=self.dtype) dtype=self.dtype)
# Some version of scipy (at least 0.14.0.dev-c4314b0)
# Do not cast to the wanted dtype.
if out[0].dtype != self.dtype:
out[0] = out[0].astype(self.dtype)
def grad(self, inputs, (gz, )): def grad(self, inputs, (gz, )):
is_continuous = [(inputs[i].dtype in tensor.continuous_dtypes) is_continuous = [(inputs[i].dtype in tensor.continuous_dtypes)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论