提交 be65be8c authored 作者: David Warde-Farley's avatar David Warde-Farley

Misc. PEP8 fixes.

上级 9ca2180e
......@@ -1524,14 +1524,17 @@ class Dot(gof.op.Op):
return rval
_dot = Dot()
def dot(x, y):
"""
Operation for efficiently calculating the dot product when
one or all operands is sparse. Supported format are CSC and CSR.
The output of the operation is dense.
"""
if hasattr(x, 'getnnz'): x = as_sparse_variable(x)
if hasattr(y, 'getnnz'): y = as_sparse_variable(y)
if hasattr(x, 'getnnz'):
x = as_sparse_variable(x)
if hasattr(y, 'getnnz'):
y = as_sparse_variable(y)
x_is_sparse_variable = _is_sparse_variable(x)
y_is_sparse_variable = _is_sparse_variable(y)
......@@ -1581,12 +1584,13 @@ class Usmm(gof.op.Op):
# We should use Dot22 and Gemm in that case.
raise TypeError(x)
dtype_out = scalar.upcast(alpha.type.dtype, x.type.dtype, y.type.dtype, z.type.dtype)
dtype_out = scalar.upcast(alpha.type.dtype, x.type.dtype,
y.type.dtype, z.type.dtype)
alpha = tensor.as_tensor_variable(alpha)
z = tensor.as_tensor_variable(z)
assert z.ndim == 2
assert alpha.type.broadcastable == (True,)* alpha.ndim
assert alpha.type.broadcastable == (True,) * alpha.ndim
if not _is_sparse_variable(x):
x = tensor.as_tensor_variable(x)
assert x.ndim == 2
......@@ -1594,8 +1598,10 @@ class Usmm(gof.op.Op):
y = tensor.as_tensor_variable(y)
assert y.ndim == 2
return gof.Apply(self, [alpha, x, y, z], [tensor.tensor(dtype=dtype_out, broadcastable=(False, False))])
return gof.Apply(self, [alpha, x, y, z],
[tensor.tensor(dtype=dtype_out,
broadcastable=(False, False))])
def perform(self, node, (alpha, x, y, z), (out, )):
x_is_sparse = _is_sparse(x)
y_is_sparse = _is_sparse(y)
......@@ -1607,17 +1613,18 @@ class Usmm(gof.op.Op):
if isinstance(rval, scipy.sparse.spmatrix):
rval = rval.toarray()
if rval.dtype == alpha.dtype:
rval *= alpha # Faster because operation is inplace
rval *= alpha # Faster because operation is inplace
else:
rval = rval * alpha
if rval.dtype == z.dtype:
rval += z # Faster because operation is inplace
rval += z # Faster because operation is inplace
else:
rval = rval + z
out[0] = rval
usmm = Usmm()
class UsmmCscDense(gof.Op):
"""
Performs the expression is alpha * x y + z
......@@ -1630,16 +1637,20 @@ class UsmmCscDense(gof.Op):
def __init__(self, inplace):
self.inplace = inplace
if inplace:
self.destroy_map={ 0 : [6] }
self.destroy_map = {0: [6]}
def __str__(self):
if self.inplace:
return 'UsmmCscDense{inplace}'
else:
return 'UsmmCscDense{no_inplace}'
def __eq__(self, other):
return (type(self) == type(other)) and self.inplace == other.inplace
def __hash__(self):
return hash(type(self)) ^ self.inplace
def infer_shape(self, node, shapes):
xshp, yshp = shapes
x, y = node.inputs
......@@ -1652,6 +1663,7 @@ class UsmmCscDense(gof.Op):
if x.ndim == 1 and y.ndim == 1:
return [()]
raise NotImplementedError()
def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z):
alpha = tensor.as_tensor_variable(alpha)
x_val = tensor.as_tensor_variable(x_val)
......@@ -1685,11 +1697,12 @@ class UsmmCscDense(gof.Op):
z = tensor.cast(z, dtype_out)
if node.inputs[1].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for x_val')
raise NotImplementedError('Complex types are not supported '
'for x_val')
if node.inputs[5].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for y')
r = gof.Apply(self, [alpha, x_val, x_ind, x_ptr, x_nrows, y, z],
r = gof.Apply(self, [alpha, x_val, x_ind, x_ptr, x_nrows, y, z],
[tensor.tensor(dtype_out, (False, y.type.broadcastable[1]))])
return r
......@@ -1715,7 +1728,8 @@ class UsmmCscDense(gof.Op):
alpha, x_val, x_ind, x_ptr, x_nrows, y, z = inputs
zn = outputs[0]
if node.inputs[1].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for x_val')
raise NotImplementedError('Complex types are not supported for '
'x_val')
if node.inputs[5].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for y')
if node.inputs[6].type.dtype != node.outputs[0].type.dtype:
......@@ -1727,13 +1741,13 @@ class UsmmCscDense(gof.Op):
else:
conv_type = "double"
axpy = "daxpy_"
typenum_alpha = node.inputs[0].type.dtype_specs()[-1] # retrieve dtype number
typenum_x_val = node.inputs[1].type.dtype_specs()[-1] # retrieve dtype number
typenum_y = node.inputs[5].type.dtype_specs()[-1] # retrieve dtype number
typenum_z = node.inputs[6].type.dtype_specs()[-1] # retrieve dtype number
typenum_zn = node.outputs[0].type.dtype_specs()[-1] # retrieve dtype number
# retrieve dtype numbers
typenum_alpha = node.inputs[0].type.dtype_specs()[-1]
typenum_x_val = node.inputs[1].type.dtype_specs()[-1]
typenum_y = node.inputs[5].type.dtype_specs()[-1]
typenum_z = node.inputs[6].type.dtype_specs()[-1]
typenum_zn = node.outputs[0].type.dtype_specs()[-1]
inplace = int(self.inplace)
rval = """
......@@ -1852,15 +1866,24 @@ class UsmmCscDense(gof.Op):
}
}
}
"""% dict(locals(), **sub)
""" % dict(locals(), **sub)
return rval
usmm_csc_dense = UsmmCscDense(inplace=False)
usmm_csc_dense_inplace = UsmmCscDense(inplace=True)
local_usmm = gof.opt.PatternSub((tensor.sub, 'z', (tensor.mul, {'pattern' : 'alpha', 'constraint' : lambda expr: numpy.all(expr.type.broadcastable) },
(_dot, 'x', 'y'))),
(usmm, (tensor.neg, 'alpha'), 'x', 'y', 'z'))
local_usmm = gof.opt.PatternSub(
(tensor.sub, 'z',
(tensor.mul,
{'pattern': 'alpha',
'constraint': lambda expr: numpy.all(expr.type.broadcastable)},
(_dot, 'x', 'y'))),
(usmm, (tensor.neg, 'alpha'), 'x', 'y', 'z'))
register_specialize(local_usmm, name="local_usmm")
......@@ -1876,15 +1899,18 @@ def local_usmm_csx(node):
if x.type.format == 'csc':
x_val, x_ind, x_ptr, x_shape = csm_properties(x)
x_nsparse = x_shape[0]
dtype_out = scalar.upcast(alpha.type.dtype, x.type.dtype, y.type.dtype, z.type.dtype)
dtype_out = scalar.upcast(alpha.type.dtype, x.type.dtype,
y.type.dtype, z.type.dtype)
# Sparse cast is not implemented.
if y.type.dtype != dtype_out:
return False
return [usmm_csc_dense(alpha, x_val, x_ind, x_ptr, x_nsparse, y, z)]
return [usmm_csc_dense(alpha, x_val, x_ind, x_ptr,
x_nsparse, y, z)]
return False
register_specialize(local_usmm_csx)
@gof.local_optimizer([usmm_csc_dense])
def local_usmm_csc_dense_inplace(node):
if node.op == usmm_csc_dense:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论