提交 2c8bb9d9 authored 作者: James Bergstra's avatar James Bergstra

Added some theano._asarray casts to perform()s in the sparse module.

上级 ccdc071b
......@@ -286,9 +286,11 @@ class CSMProperties(gof.Op):
def perform(self, node, (csm,), out):
if self.kmap is None:
out[0][0] = csm.data
out[0][0] = csm.data
else:
out[0][0] = csm.data[self.kmap]
out[0][0] = csm.data[self.kmap]
if str(csm.data.dtype) == 'int32':
out[0][0] = theano._asarray(out[0][0], dtype='int32')
#backport
#out[0][0] = csm.data if self.kmap is None else csm.data[self.kmap]
out[1][0] = theano._asarray(csm.indices, dtype='int32')
......@@ -562,7 +564,6 @@ class AddSS(gof.op.Op):
if x.type.dtype != y.type.dtype:
raise NotImplementedError()
if x.type.format != y.type.format:
print x.type.format, y.type.format
raise NotImplementedError()
return gof.Apply(self,
[x, y],
......@@ -815,7 +816,7 @@ class StructuredDotCSC(gof.Op):
(a_nrows, b.shape[0]),
copy = False)
#out[0] = a.dot(b)
out[0] = a * b
out[0] = theano._asarray(a * b, dtype=node.outputs[0].type.dtype)
assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense
def c_code(self, node, name, (a_val, a_ind, a_ptr, a_nrows, b), (z,), sub):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论