提交 97132fdd authored 作者: Joseph Turian's avatar Joseph Turian

Cleaned up error checking in sparse.StructuredDot

上级 d273d0f4
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
# The list of objects to document. Objects can be named using # The list of objects to document. Objects can be named using
# dotted names, module filenames, or package directory names. # dotted names, module filenames, or package directory names.
# Alases for this option include "objects" and "values". # Alases for this option include "objects" and "values".
modules: theano modules: theano scipy.sparse
# The type of output that should be generated. Should be one # The type of output that should be generated. Should be one
# of: html, text, latex, dvi, ps, pdf. # of: html, text, latex, dvi, ps, pdf.
......
...@@ -673,16 +673,14 @@ class StructuredDot(gof.Op): ...@@ -673,16 +673,14 @@ class StructuredDot(gof.Op):
""" """
def make_node(self, a, b): def make_node(self, a, b):
assert a.type.dtype == b.type.dtype assert a.type.dtype == b.type.dtype
if type(a) is not SparseResult: if type(a) is not SparseResult and type(a) is not SparseConstant:
raise TypeError('First argument must be of type SparseResult'); raise TypeError('First argument must be of type SparseResult or SparseConstant');
return gof.Apply(self, [a,b], [tensor.tensor(a.type.dtype, (False, False))]) return gof.Apply(self, [a,b], [tensor.tensor(a.type.dtype, (False, False))])
def perform(self, node, (a,b), (out,)): def perform(self, node, (a,b), (out,)):
if a.shape[1] != b.shape[0]: if a.shape[1] != b.shape[0]:
raise ValueError('shape mismatch in StructuredDot.perform', (a.shape, b.shape)) raise ValueError('shape mismatch in StructuredDot.perform', (a.shape, b.shape))
if b.shape[0] == 1:
raise NotImplementedError('ERROR: scipy.csc_matrix dot has bug with singleton dimensions')
result = a.dot(b) result = a.dot(b)
...@@ -699,6 +697,12 @@ class StructuredDot(gof.Op): ...@@ -699,6 +697,12 @@ class StructuredDot(gof.Op):
assert result.ndim == 2 assert result.ndim == 2
if result.shape != (a.shape[0], b.shape[1]):
if b.shape[0] == 1:
raise Exception("a.shape=%s, b.shape=%s, result.shape=%s ??? This is probably because scipy.csc_matrix dot has a bug with singleton dimensions (i.e. b.shape[0]=1), for scipy 0.6. Use scipy 0.7" % (a.shape, b.shape, result.shape))
else:
raise Exception("a.shape=%s, b.shape=%s, result.shape=%s ??? I have no idea why")
## Commenting this out because result should be a numpy.ndarray since the assert above ## Commenting this out because result should be a numpy.ndarray since the assert above
## (JB 20090109) ## (JB 20090109)
#out[0] = numpy.asarray(result) #TODO: fix this really bad implementation #out[0] = numpy.asarray(result) #TODO: fix this really bad implementation
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论