提交 63d5a55a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Allow data for sparse matrices to be broadcastable

A problem occurred when the data (and indices) were constant vectors of length one, since they had a broadcastable pattern of (True,), not (False,) as was expected. That happened in the buildbot in debug mode with seed 24856.
上级 bd202537
...@@ -5,7 +5,7 @@ To read about different sparse formats, see U{http://www-users.cs.umn.edu/~saad/ ...@@ -5,7 +5,7 @@ To read about different sparse formats, see U{http://www-users.cs.umn.edu/~saad/
@todo: Automatic methods for determining best sparse format? @todo: Automatic methods for determining best sparse format?
""" """
from itertools import izip
import sys import sys
import numpy, theano import numpy, theano
...@@ -473,13 +473,13 @@ class CSM(gof.Op): ...@@ -473,13 +473,13 @@ class CSM(gof.Op):
shape = tensor.as_tensor_variable(shape) shape = tensor.as_tensor_variable(shape)
if data.type.ndim != 1: if data.type.ndim != 1:
raise TypeError('data argument must be a vector', data.type) raise TypeError('data argument must be a vector', data.type, data.type.ndim)
if indices.type != tensor.ivector: if indices.type.ndim != 1 or indices.type.dtype != 'int32':
raise TypeError('indices must be vector of integers', indices) raise TypeError('indices must be vector of integers', indices, indices.type)
if indptr.type != tensor.ivector: if indptr.type.ndim != 1 or indptr.type.dtype != 'int32':
raise TypeError('indices must be vector of integers', indptr) raise TypeError('indices must be vector of integers', indptr, indptr.type)
if shape.type != tensor.ivector: if shape.type.ndim != 1 or shape.type.dtype != 'int32':
raise TypeError('n_rows must be integer type', shape) raise TypeError('n_rows must be integer type', shape, shape.type)
return gof.Apply(self, return gof.Apply(self,
[data, indices, indptr, shape], [data, indices, indptr, shape],
...@@ -553,7 +553,12 @@ def skip_pack_csc01(node): ...@@ -553,7 +553,12 @@ def skip_pack_csc01(node):
if node.op == csm_properties: if node.op == csm_properties:
csm, = node.inputs csm, = node.inputs
if csm.owner and (csm.owner.op == CSC or csm.owner.op == CSR): if csm.owner and (csm.owner.op == CSC or csm.owner.op == CSR):
return csm.owner.inputs # csm.owner.inputs could be broadcastable. In that case, we have
# to adjust the broadcasting flag here.
ret_var = [tensor.patternbroadcast(i, o.broadcastable)
for i, o in izip(csm.owner.inputs, node.outputs)]
return ret_var
return False return False
register_specialize(skip_pack_csc01) register_specialize(skip_pack_csc01)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论