提交 2c8bb9d9 authored 作者: James Bergstra's avatar James Bergstra

Added some theano._asarray casts to perform()s in the sparse module.

上级 ccdc071b
...@@ -289,6 +289,8 @@ class CSMProperties(gof.Op): ...@@ -289,6 +289,8 @@ class CSMProperties(gof.Op):
out[0][0] = csm.data out[0][0] = csm.data
else: else:
out[0][0] = csm.data[self.kmap] out[0][0] = csm.data[self.kmap]
if str(csm.data.dtype) == 'int32':
out[0][0] = theano._asarray(out[0][0], dtype='int32')
#backport #backport
#out[0][0] = csm.data if self.kmap is None else csm.data[self.kmap] #out[0][0] = csm.data if self.kmap is None else csm.data[self.kmap]
out[1][0] = theano._asarray(csm.indices, dtype='int32') out[1][0] = theano._asarray(csm.indices, dtype='int32')
...@@ -562,7 +564,6 @@ class AddSS(gof.op.Op): ...@@ -562,7 +564,6 @@ class AddSS(gof.op.Op):
if x.type.dtype != y.type.dtype: if x.type.dtype != y.type.dtype:
raise NotImplementedError() raise NotImplementedError()
if x.type.format != y.type.format: if x.type.format != y.type.format:
print x.type.format, y.type.format
raise NotImplementedError() raise NotImplementedError()
return gof.Apply(self, return gof.Apply(self,
[x, y], [x, y],
...@@ -815,7 +816,7 @@ class StructuredDotCSC(gof.Op): ...@@ -815,7 +816,7 @@ class StructuredDotCSC(gof.Op):
(a_nrows, b.shape[0]), (a_nrows, b.shape[0]),
copy = False) copy = False)
#out[0] = a.dot(b) #out[0] = a.dot(b)
out[0] = a * b out[0] = theano._asarray(a * b, dtype=node.outputs[0].type.dtype)
assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense
def c_code(self, node, name, (a_val, a_ind, a_ptr, a_nrows, b), (z,), sub): def c_code(self, node, name, (a_val, a_ind, a_ptr, a_nrows, b), (z,), sub):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论