提交 bdf1617d authored 作者: David Warde-Farley's avatar David Warde-Farley

Remove trailing whitespace.

上级 fc199c2f
...@@ -1460,13 +1460,13 @@ class Dot(gof.op.Op): ...@@ -1460,13 +1460,13 @@ class Dot(gof.op.Op):
""" """
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
xshp, yshp = shapes xshp, yshp = shapes
x, y = node.inputs x, y = node.inputs
...@@ -1479,10 +1479,10 @@ class Dot(gof.op.Op): ...@@ -1479,10 +1479,10 @@ class Dot(gof.op.Op):
if x.ndim == 1 and y.ndim == 1: if x.ndim == 1 and y.ndim == 1:
return [()] return [()]
raise NotImplementedError() raise NotImplementedError()
def make_node(self, x, y): def make_node(self, x, y):
dtype_out = scalar.upcast(x.type.dtype, y.type.dtype) dtype_out = scalar.upcast(x.type.dtype, y.type.dtype)
if not _is_sparse_variable(x) and not _is_sparse_variable(y): if not _is_sparse_variable(x) and not _is_sparse_variable(y):
raise TypeError(x) raise TypeError(x)
...@@ -1492,17 +1492,17 @@ class Dot(gof.op.Op): ...@@ -1492,17 +1492,17 @@ class Dot(gof.op.Op):
def perform(self, node, (x, y), (out, )): def perform(self, node, (x, y), (out, )):
x_is_sparse = _is_sparse(x) x_is_sparse = _is_sparse(x)
y_is_sparse = _is_sparse(y) y_is_sparse = _is_sparse(y)
if not x_is_sparse and not y_is_sparse: if not x_is_sparse and not y_is_sparse:
raise TypeError(x) raise TypeError(x)
rval = x * y rval = x * y
if x_is_sparse and y_is_sparse: if x_is_sparse and y_is_sparse:
rval = rval.todense() rval = rval.todense()
out[0] = rval out[0] = rval
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
assert _is_sparse_variable(x) or _is_sparse_variable(y) assert _is_sparse_variable(x) or _is_sparse_variable(y)
...@@ -1525,10 +1525,10 @@ def dot(x, y): ...@@ -1525,10 +1525,10 @@ def dot(x, y):
x_is_sparse_variable = _is_sparse_variable(x) x_is_sparse_variable = _is_sparse_variable(x)
y_is_sparse_variable = _is_sparse_variable(y) y_is_sparse_variable = _is_sparse_variable(y)
if not x_is_sparse_variable and not y_is_sparse_variable: if not x_is_sparse_variable and not y_is_sparse_variable:
raise TypeError() raise TypeError()
return _dot(x, y) return _dot(x, y)
...@@ -1542,16 +1542,16 @@ class Usmm(gof.op.Op): ...@@ -1542,16 +1542,16 @@ class Usmm(gof.op.Op):
""" """
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
def __str__(self): def __str__(self):
return 'Usmm{no_inplace}' return 'Usmm{no_inplace}'
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
xshp, yshp = shapes xshp, yshp = shapes
x, y = node.inputs x, y = node.inputs
...@@ -1564,7 +1564,7 @@ class Usmm(gof.op.Op): ...@@ -1564,7 +1564,7 @@ class Usmm(gof.op.Op):
if x.ndim == 1 and y.ndim == 1: if x.ndim == 1 and y.ndim == 1:
return [()] return [()]
raise NotImplementedError() raise NotImplementedError()
def make_node(self, alpha, x, y, z): def make_node(self, alpha, x, y, z):
if not _is_sparse_variable(x) and not _is_sparse_variable(y): if not _is_sparse_variable(x) and not _is_sparse_variable(y):
# If x and y are tensor, we don't want to use this class # If x and y are tensor, we don't want to use this class
...@@ -1589,10 +1589,10 @@ class Usmm(gof.op.Op): ...@@ -1589,10 +1589,10 @@ class Usmm(gof.op.Op):
def perform(self, node, (alpha, x, y, z), (out, )): def perform(self, node, (alpha, x, y, z), (out, )):
x_is_sparse = _is_sparse(x) x_is_sparse = _is_sparse(x)
y_is_sparse = _is_sparse(y) y_is_sparse = _is_sparse(y)
if not x_is_sparse and not y_is_sparse: if not x_is_sparse and not y_is_sparse:
raise TypeError(x) raise TypeError(x)
rval = x * y rval = x * y
if isinstance(rval, scipy.sparse.spmatrix): if isinstance(rval, scipy.sparse.spmatrix):
rval = rval.toarray() rval = rval.toarray()
...@@ -1604,7 +1604,7 @@ class Usmm(gof.op.Op): ...@@ -1604,7 +1604,7 @@ class Usmm(gof.op.Op):
rval += z # Faster because operation is inplace rval += z # Faster because operation is inplace
else: else:
rval = rval + z rval = rval + z
out[0] = rval out[0] = rval
usmm = Usmm() usmm = Usmm()
...@@ -1612,7 +1612,7 @@ class UsmmCscDense(gof.Op): ...@@ -1612,7 +1612,7 @@ class UsmmCscDense(gof.Op):
""" """
Performs the expression is alpha * x y + z Performs the expression is alpha * x y + z
This is an optimized operation for the case when x is in CSC format. This is an optimized operation for the case when x is in CSC format.
x are sparse matrix x are sparse matrix
y, z is a dense matrix y, z is a dense matrix
alpha is a scalar alpha is a scalar
...@@ -1673,7 +1673,7 @@ class UsmmCscDense(gof.Op): ...@@ -1673,7 +1673,7 @@ class UsmmCscDense(gof.Op):
y = tensor.cast(y, dtype_out) y = tensor.cast(y, dtype_out)
if dtype_out != z.type.dtype: if dtype_out != z.type.dtype:
z = tensor.cast(z, dtype_out) z = tensor.cast(z, dtype_out)
if node.inputs[1].type.dtype in ('complex64', 'complex128'): if node.inputs[1].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for x_val') raise NotImplementedError('Complex types are not supported for x_val')
if node.inputs[5].type.dtype in ('complex64', 'complex128'): if node.inputs[5].type.dtype in ('complex64', 'complex128'):
...@@ -1688,7 +1688,7 @@ class UsmmCscDense(gof.Op): ...@@ -1688,7 +1688,7 @@ class UsmmCscDense(gof.Op):
def c_support_code(self): def c_support_code(self):
return blas.blas_header_text() return blas.blas_header_text()
def c_libraries(self): def c_libraries(self):
return blas.ldflags() return blas.ldflags()
...@@ -1697,7 +1697,7 @@ class UsmmCscDense(gof.Op): ...@@ -1697,7 +1697,7 @@ class UsmmCscDense(gof.Op):
def c_lib_dirs(self): def c_lib_dirs(self):
return blas.ldflags(libs=False, libs_dir=True) return blas.ldflags(libs=False, libs_dir=True)
def c_header_dirs(self): def c_header_dirs(self):
return blas.ldflags(libs=False, include_dir=True) return blas.ldflags(libs=False, include_dir=True)
...@@ -1708,7 +1708,7 @@ class UsmmCscDense(gof.Op): ...@@ -1708,7 +1708,7 @@ class UsmmCscDense(gof.Op):
raise NotImplementedError('Complex types are not supported for y') raise NotImplementedError('Complex types are not supported for y')
if node.inputs[6].type.dtype != node.outputs[0].type.dtype: if node.inputs[6].type.dtype != node.outputs[0].type.dtype:
raise NotImplementedError('z and output must have same type') raise NotImplementedError('z and output must have same type')
if node.inputs[1].type.dtype == "float32": if node.inputs[1].type.dtype == "float32":
conv_type = "float" conv_type = "float"
axpy = "saxpy_" axpy = "saxpy_"
...@@ -1723,7 +1723,7 @@ class UsmmCscDense(gof.Op): ...@@ -1723,7 +1723,7 @@ class UsmmCscDense(gof.Op):
typenum_zn = node.outputs[0].type.dtype_specs()[-1] # retrieve dtype number typenum_zn = node.outputs[0].type.dtype_specs()[-1] # retrieve dtype number
inplace = int(self.inplace) inplace = int(self.inplace)
rval = """ rval = """
if (%(x_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(x_val) != 1"); %(fail)s;} if (%(x_val)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(x_val) != 1"); %(fail)s;}
if (%(x_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(x_ind) != 1"); %(fail)s;} if (%(x_ind)s->nd != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(x_ind) != 1"); %(fail)s;}
...@@ -1757,10 +1757,10 @@ class UsmmCscDense(gof.Op): ...@@ -1757,10 +1757,10 @@ class UsmmCscDense(gof.Op):
if (%(x_ptr)s->dimensions[0] != %(y)s->dimensions[0]+1) if (%(x_ptr)s->dimensions[0] != %(y)s->dimensions[0]+1)
{PyErr_SetString(PyExc_NotImplementedError, "x's number of columns doesn't match y's rows"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "x's number of columns doesn't match y's rows"); %(fail)s;}
if (%(z)s->dimensions[0] != ((npy_int32 *)%(x_nrows)s->data)[0] || %(z)s->dimensions[1] != %(y)s->dimensions[1]) if (%(z)s->dimensions[0] != ((npy_int32 *)%(x_nrows)s->data)[0] || %(z)s->dimensions[1] != %(y)s->dimensions[1])
{PyErr_SetString(PyExc_NotImplementedError, "The dimension of the allocated output doesn't match the correct output size."); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "The dimension of the allocated output doesn't match the correct output size."); %(fail)s;}
if (PyArray_SIZE(%(alpha)s) != 1) if (PyArray_SIZE(%(alpha)s) != 1)
{PyErr_SetString(PyExc_NotImplementedError, "The number of element in alpha must be 1"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "The number of element in alpha must be 1"); %(fail)s;}
...@@ -1784,7 +1784,7 @@ class UsmmCscDense(gof.Op): ...@@ -1784,7 +1784,7 @@ class UsmmCscDense(gof.Op):
Py_XDECREF(%(zn)s); Py_XDECREF(%(zn)s);
%(zn)s = %(z)s; %(zn)s = %(z)s;
Py_INCREF(%(zn)s); Py_INCREF(%(zn)s);
} }
else if (!%(zn)s else if (!%(zn)s
|| (%(zn)s->dimensions[0] != ((npy_int32 *)%(x_nrows)s->data)[0]) || (%(zn)s->dimensions[0] != ((npy_int32 *)%(x_nrows)s->data)[0])
|| (%(zn)s->dimensions[1] != %(y)s->dimensions[1]) || (%(zn)s->dimensions[1] != %(y)s->dimensions[1])
...@@ -1796,13 +1796,13 @@ class UsmmCscDense(gof.Op): ...@@ -1796,13 +1796,13 @@ class UsmmCscDense(gof.Op):
dims[1] = %(y)s->dimensions[1]; dims[1] = %(y)s->dimensions[1];
%(zn)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(typenum_zn)s); %(zn)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(typenum_zn)s);
} }
{ {
// sparse array has size MxK, dense KxN, output MxN // sparse array has size MxK, dense KxN, output MxN
npy_intp M = %(zn)s->dimensions[0]; npy_intp M = %(zn)s->dimensions[0];
npy_intp N = %(zn)s->dimensions[1]; npy_intp N = %(zn)s->dimensions[1];
npy_intp K = %(y)s->dimensions[0]; npy_intp K = %(y)s->dimensions[0];
// pointers to access actual data in the arrays passed as params. // pointers to access actual data in the arrays passed as params.
dtype_%(z)s* __restrict__ Dz = (dtype_%(z)s*)%(z)s->data; dtype_%(z)s* __restrict__ Dz = (dtype_%(z)s*)%(z)s->data;
dtype_%(zn)s* __restrict__ Dzn = (dtype_%(zn)s*)%(zn)s->data; dtype_%(zn)s* __restrict__ Dzn = (dtype_%(zn)s*)%(zn)s->data;
...@@ -1810,32 +1810,32 @@ class UsmmCscDense(gof.Op): ...@@ -1810,32 +1810,32 @@ class UsmmCscDense(gof.Op):
const npy_int32 * __restrict__ Dind = (npy_int32*)%(x_ind)s->data; const npy_int32 * __restrict__ Dind = (npy_int32*)%(x_ind)s->data;
const npy_int32 * __restrict__ Dptr = (npy_int32*)%(x_ptr)s->data; const npy_int32 * __restrict__ Dptr = (npy_int32*)%(x_ptr)s->data;
const dtype_%(alpha)s alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0]; const dtype_%(alpha)s alpha = ((dtype_%(alpha)s*)%(alpha)s->data)[0];
npy_intp Sz = %(z)s->strides[1] / %(z)s->descr->elsize; npy_intp Sz = %(z)s->strides[1] / %(z)s->descr->elsize;
npy_intp Szn = %(zn)s->strides[1] / %(zn)s->descr->elsize; npy_intp Szn = %(zn)s->strides[1] / %(zn)s->descr->elsize;
npy_intp Sval = %(x_val)s->strides[0] / %(x_val)s->descr->elsize; npy_intp Sval = %(x_val)s->strides[0] / %(x_val)s->descr->elsize;
npy_intp Sind = %(x_ind)s->strides[0] / %(x_ind)s->descr->elsize; npy_intp Sind = %(x_ind)s->strides[0] / %(x_ind)s->descr->elsize;
npy_intp Sptr = %(x_ptr)s->strides[0] / %(x_ptr)s->descr->elsize; npy_intp Sptr = %(x_ptr)s->strides[0] / %(x_ptr)s->descr->elsize;
npy_intp Sy = %(y)s->strides[1] / %(y)s->descr->elsize; npy_intp Sy = %(y)s->strides[1] / %(y)s->descr->elsize;
if (!(%(inplace)s)) if (!(%(inplace)s))
{ {
memcpy(Dzn, Dz, M*N*sizeof(dtype_%(zn)s)); memcpy(Dzn, Dz, M*N*sizeof(dtype_%(zn)s));
} }
for (npy_int32 k = 0; k < K; ++k) for (npy_int32 k = 0; k < K; ++k)
{ {
for (npy_int32 m_idx = Dptr[k * Sptr]; m_idx < Dptr[(k+1)*Sptr]; ++m_idx) for (npy_int32 m_idx = Dptr[k * Sptr]; m_idx < Dptr[(k+1)*Sptr]; ++m_idx)
{ {
const npy_int32 m = Dind[m_idx * Sind]; // row index of non-null value for column K const npy_int32 m = Dind[m_idx * Sind]; // row index of non-null value for column K
const dtype_%(x_val)s Amk = alpha * Dval[m_idx * Sval]; // actual value at that location const dtype_%(x_val)s Amk = alpha * Dval[m_idx * Sval]; // actual value at that location
const dtype_%(y)s* y_row = (dtype_%(y)s*)(%(y)s->data + %(y)s->strides[0] * k); const dtype_%(y)s* y_row = (dtype_%(y)s*)(%(y)s->data + %(y)s->strides[0] * k);
const dtype_%(zn)s* z_row = (dtype_%(zn)s*)(%(zn)s->data + %(zn)s->strides[0] * m); const dtype_%(zn)s* z_row = (dtype_%(zn)s*)(%(zn)s->data + %(zn)s->strides[0] * m);
%(axpy)s((int*)&N, (%(conv_type)s*)&Amk, (%(conv_type)s*)y_row, (int*)&Sy, (%(conv_type)s*)z_row, (int*)&Szn); %(axpy)s((int*)&N, (%(conv_type)s*)&Amk, (%(conv_type)s*)y_row, (int*)&Sy, (%(conv_type)s*)z_row, (int*)&Szn);
} }
} }
...@@ -1856,10 +1856,10 @@ register_specialize(local_usmm, name="local_usmm") ...@@ -1856,10 +1856,10 @@ register_specialize(local_usmm, name="local_usmm")
def local_usmm_csx(node): def local_usmm_csx(node):
if node.op == usmm: if node.op == usmm:
alpha, x, y, z = node.inputs alpha, x, y, z = node.inputs
x_is_sparse_variable = _is_sparse_variable(x) x_is_sparse_variable = _is_sparse_variable(x)
y_is_sparse_variable = _is_sparse_variable(y) y_is_sparse_variable = _is_sparse_variable(y)
if x_is_sparse_variable and not y_is_sparse_variable: if x_is_sparse_variable and not y_is_sparse_variable:
if x.type.format == 'csc': if x.type.format == 'csc':
x_val, x_ind, x_ptr, x_shape = csm_properties(x) x_val, x_ind, x_ptr, x_shape = csm_properties(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论