提交 3098fe8a authored 作者: goodfeli's avatar goodfeli

Merge pull request #157 from nouiz/fix_sparse_dot

Fix sparse dot
......@@ -136,7 +136,9 @@ class BadCLinkerOutput(DebugModeError):
sio = StringIO()
print >> sio, "BadCLinkerOutput"
print >> sio, " variable:", self.r
print >> sio, " Type :", self.r.type
print >> sio, " Outputs Type :", self.r.type
print >> sio, " Inputs Type:", [i.type for i in self.r.owner.inputs]
print >> sio, " Apply :", self.r.owner
print >> sio, " val_py :", self.val_py
print >> sio, " val_c :", self.val_c
print >> sio, " op :", self.offending_op()
......
......@@ -327,6 +327,16 @@ class SparseType(gof.Type):
return scipy.sparse.issparse(a) and (a.format == self.format)
# for more dtypes, call SparseType(format, dtype)
def matrix(format, name=None, dtype=None):
if dtype is None:
dtype = config.floatX
type = SparseType(format=format, dtype=dtype)
return type(name)
def csc_matrix(name=None, dtype=None):
return matrix('csc', name, dtype)
def csr_matrix(name=None, dtype=None):
return matrix('csr', name, dtype)
# for more dtypes, call SparseType(format, dtype)
csc_matrix = SparseType(format='csc', dtype=config.floatX)
csr_matrix = SparseType(format='csr', dtype=config.floatX)
csc_dmatrix = SparseType(format='csc', dtype='float64')
......@@ -1505,7 +1515,7 @@ class Dot(gof.op.Op):
rval = x * y
if x_is_sparse and y_is_sparse:
rval = rval.todense()
rval = rval.toarray()
out[0] = rval
......@@ -1553,6 +1563,8 @@ class Usmm(gof.op.Op):
x or y are sparse matrix(the other can be sparse or dense)
z is a dense matrix
alpha is a scalar
:note: We don't implement the infer_shape as it is inserted by optimization only
"""
def __eq__(self, other):
return type(self) == type(other)
......@@ -1566,19 +1578,6 @@ class Usmm(gof.op.Op):
def __str__(self):
return 'Usmm{no_inplace}'
def infer_shape(self, node, shapes):
xshp, yshp = shapes
x, y = node.inputs
if x.ndim == 2 and y.ndim == 2:
return [(xshp[0], yshp[1])]
if x.ndim == 1 and y.ndim == 2:
return [(yshp[1],)]
if x.ndim == 2 and y.ndim == 1:
return [(xshp[0],)]
if x.ndim == 1 and y.ndim == 1:
return [()]
raise NotImplementedError()
def make_node(self, alpha, x, y, z):
if not _is_sparse_variable(x) and not _is_sparse_variable(y):
# If x and y are tensor, we don't want to use this class
......@@ -1634,6 +1633,8 @@ class UsmmCscDense(gof.Op):
x are sparse matrix
y, z is a dense matrix
alpha is a scalar
:note: We don't implement the infer_shape as it is inserted by optimization only
"""
def __init__(self, inplace):
self.inplace = inplace
......@@ -1652,19 +1653,6 @@ class UsmmCscDense(gof.Op):
def __hash__(self):
return hash(type(self)) ^ self.inplace
def infer_shape(self, node, shapes):
xshp, yshp = shapes
x, y = node.inputs
if x.ndim == 2 and y.ndim == 2:
return [(xshp[0], yshp[1])]
if x.ndim == 1 and y.ndim == 2:
return [(yshp[1],)]
if x.ndim == 2 and y.ndim == 1:
return [(xshp[0],)]
if x.ndim == 1 and y.ndim == 1:
return [()]
raise NotImplementedError()
def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z):
alpha = tensor.as_tensor_variable(alpha)
x_val = tensor.as_tensor_variable(x_val)
......@@ -1884,6 +1872,7 @@ register_specialize(local_usmm, name="local_usmm")
@gof.local_optimizer([usmm])
def local_usmm_csx(node):
""" usmm -> usmm_csc_dense """
if node.op == usmm:
alpha, x, y, z = node.inputs
......@@ -1896,6 +1885,8 @@ def local_usmm_csx(node):
x_nsparse = x_shape[0]
dtype_out = scalar.upcast(alpha.type.dtype, x.type.dtype,
y.type.dtype, z.type.dtype)
if dtype_out not in ('float32', 'float64'):
return False
# Sparse cast is not implemented.
if y.type.dtype != dtype_out:
return False
......
......@@ -390,20 +390,24 @@ else:
#more strict. Atleast float32 precision.
float64_rtol = 1.0000000000000001e-06
def _allclose(a, b):
def _allclose(a, b, rtol=None, atol=None):
narrow = 'float32', 'complex64'
if (str(a.dtype) in narrow) or (str(b.dtype) in narrow):
atol = float32_atol
rtol = float32_rtol
atol_ = float32_atol
rtol_ = float32_rtol
else:
atol = float64_atol
rtol = float64_rtol
atol_ = float64_atol
rtol_ = float64_rtol
if rtol is not None:
rtol_ = rtol
if atol is not None:
atol_ = atol
# Work around bug in Numpy, see http://projects.scipy.org/numpy/ticket/1684
if str(b.dtype) in int_dtypes and (numpy.absolute(b) < 0).any():
b = theano._asarray(b, dtype='float64')
return numpy.allclose(a,b, atol=atol, rtol=rtol)
return numpy.allclose(a, b, atol=atol_, rtol=rtol_)
def get_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v`
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论