提交 86842b83 authored 作者: Frederic's avatar Frederic 提交者: Pascal Lamblin

In op CSM.make_node(), cast to int32 only if this is safe.

上级 f2b475db
...@@ -251,7 +251,7 @@ def sp_zeros_like(x): ...@@ -251,7 +251,7 @@ def sp_zeros_like(x):
# TODO: don't restrict to CSM formats # TODO: don't restrict to CSM formats
_, _, indptr, shape = csm_properties(x) _, _, indptr, shape = csm_properties(x)
return CSM(format=x.format)(data=numpy.array([], dtype=x.type.dtype), return CSM(format=x.format)(data=numpy.array([], dtype=x.type.dtype),
indices=numpy.array([]), indices=numpy.array([], dtype='int32'),
indptr=tensor.zeros_like(indptr), indptr=tensor.zeros_like(indptr),
shape=shape) shape=shape)
...@@ -621,12 +621,22 @@ class CSM(gof.Op): ...@@ -621,12 +621,22 @@ class CSM(gof.Op):
def make_node(self, data, indices, indptr, shape): def make_node(self, data, indices, indptr, shape):
data = tensor.as_tensor_variable(data) data = tensor.as_tensor_variable(data)
if not isinstance(indices, tensor.TensorVariable): if not isinstance(indices, gof.Variable):
indices = theano._asarray(indices, dtype='int32') indices_ = numpy.asarray(indices)
if not isinstance(indptr, tensor.TensorVariable): indices_32 = theano._asarray(indices, dtype='int32')
indptr = theano._asarray(indptr, dtype='int32') assert (indices_ == indices_32).all()
if not isinstance(shape, tensor.TensorVariable): indices = indices_32
shape = theano._asarray(shape, dtype='int32') if not isinstance(indptr, gof.Variable):
indptr_ = numpy.asarray(indptr)
indptr_32 = theano._asarray(indptr, dtype='int32')
assert (indptr_ == indptr_32).all()
indptr = indptr_32
if not isinstance(shape, gof.Variable):
shape_ = numpy.asarray(shape)
shape_32 = theano._asarray(shape, dtype='int32')
assert (shape_ == shape_32).all()
shape = shape_32
indices = tensor.as_tensor_variable(indices) indices = tensor.as_tensor_variable(indices)
indptr = tensor.as_tensor_variable(indptr) indptr = tensor.as_tensor_variable(indptr)
shape = tensor.as_tensor_variable(shape) shape = tensor.as_tensor_variable(shape)
......
...@@ -1249,7 +1249,7 @@ class DotTests(utt.InferShapeTester): ...@@ -1249,7 +1249,7 @@ class DotTests(utt.InferShapeTester):
fI = I.flatten() fI = I.flatten()
data = tensor.ones_like(fI) data = tensor.ones_like(fI)
indptr = tensor.arange(data.shape[0] + 1) indptr = tensor.arange(data.shape[0] + 1, dtype='int32')
m1 = sparse.CSR(data, fI, indptr, (8, size)) m1 = sparse.CSR(data, fI, indptr, (8, size))
m2 = sparse.dot(m1, C) m2 = sparse.dot(m1, C)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论