提交 be65be8c authored 作者: David Warde-Farley's avatar David Warde-Farley

Misc. PEP8 fixes.

上级 9ca2180e
...@@ -1524,14 +1524,17 @@ class Dot(gof.op.Op): ...@@ -1524,14 +1524,17 @@ class Dot(gof.op.Op):
return rval return rval
_dot = Dot() _dot = Dot()
def dot(x, y): def dot(x, y):
""" """
Operation for efficiently calculating the dot product when Operation for efficiently calculating the dot product when
one or all operands is sparse. Supported format are CSC and CSR. one or all operands is sparse. Supported format are CSC and CSR.
The output of the operation is dense. The output of the operation is dense.
""" """
if hasattr(x, 'getnnz'): x = as_sparse_variable(x) if hasattr(x, 'getnnz'):
if hasattr(y, 'getnnz'): y = as_sparse_variable(y) x = as_sparse_variable(x)
if hasattr(y, 'getnnz'):
y = as_sparse_variable(y)
x_is_sparse_variable = _is_sparse_variable(x) x_is_sparse_variable = _is_sparse_variable(x)
y_is_sparse_variable = _is_sparse_variable(y) y_is_sparse_variable = _is_sparse_variable(y)
...@@ -1581,12 +1584,13 @@ class Usmm(gof.op.Op): ...@@ -1581,12 +1584,13 @@ class Usmm(gof.op.Op):
# We should use Dot22 and Gemm in that case. # We should use Dot22 and Gemm in that case.
raise TypeError(x) raise TypeError(x)
dtype_out = scalar.upcast(alpha.type.dtype, x.type.dtype, y.type.dtype, z.type.dtype) dtype_out = scalar.upcast(alpha.type.dtype, x.type.dtype,
y.type.dtype, z.type.dtype)
alpha = tensor.as_tensor_variable(alpha) alpha = tensor.as_tensor_variable(alpha)
z = tensor.as_tensor_variable(z) z = tensor.as_tensor_variable(z)
assert z.ndim == 2 assert z.ndim == 2
assert alpha.type.broadcastable == (True,)* alpha.ndim assert alpha.type.broadcastable == (True,) * alpha.ndim
if not _is_sparse_variable(x): if not _is_sparse_variable(x):
x = tensor.as_tensor_variable(x) x = tensor.as_tensor_variable(x)
assert x.ndim == 2 assert x.ndim == 2
...@@ -1594,8 +1598,10 @@ class Usmm(gof.op.Op): ...@@ -1594,8 +1598,10 @@ class Usmm(gof.op.Op):
y = tensor.as_tensor_variable(y) y = tensor.as_tensor_variable(y)
assert y.ndim == 2 assert y.ndim == 2
return gof.Apply(self, [alpha, x, y, z], [tensor.tensor(dtype=dtype_out, broadcastable=(False, False))]) return gof.Apply(self, [alpha, x, y, z],
[tensor.tensor(dtype=dtype_out,
broadcastable=(False, False))])
def perform(self, node, (alpha, x, y, z), (out, )): def perform(self, node, (alpha, x, y, z), (out, )):
x_is_sparse = _is_sparse(x) x_is_sparse = _is_sparse(x)
y_is_sparse = _is_sparse(y) y_is_sparse = _is_sparse(y)
...@@ -1607,17 +1613,18 @@ class Usmm(gof.op.Op): ...@@ -1607,17 +1613,18 @@ class Usmm(gof.op.Op):
if isinstance(rval, scipy.sparse.spmatrix): if isinstance(rval, scipy.sparse.spmatrix):
rval = rval.toarray() rval = rval.toarray()
if rval.dtype == alpha.dtype: if rval.dtype == alpha.dtype:
rval *= alpha # Faster because operation is inplace rval *= alpha # Faster because operation is inplace
else: else:
rval = rval * alpha rval = rval * alpha
if rval.dtype == z.dtype: if rval.dtype == z.dtype:
rval += z # Faster because operation is inplace rval += z # Faster because operation is inplace
else: else:
rval = rval + z rval = rval + z
out[0] = rval out[0] = rval
usmm = Usmm() usmm = Usmm()
class UsmmCscDense(gof.Op): class UsmmCscDense(gof.Op):
""" """
Performs the expression is alpha * x y + z Performs the expression is alpha * x y + z
...@@ -1630,16 +1637,20 @@ class UsmmCscDense(gof.Op): ...@@ -1630,16 +1637,20 @@ class UsmmCscDense(gof.Op):
def __init__(self, inplace): def __init__(self, inplace):
self.inplace = inplace self.inplace = inplace
if inplace: if inplace:
self.destroy_map={ 0 : [6] } self.destroy_map = {0: [6]}
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
return 'UsmmCscDense{inplace}' return 'UsmmCscDense{inplace}'
else: else:
return 'UsmmCscDense{no_inplace}' return 'UsmmCscDense{no_inplace}'
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) and self.inplace == other.inplace return (type(self) == type(other)) and self.inplace == other.inplace
def __hash__(self): def __hash__(self):
return hash(type(self)) ^ self.inplace return hash(type(self)) ^ self.inplace
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
xshp, yshp = shapes xshp, yshp = shapes
x, y = node.inputs x, y = node.inputs
...@@ -1652,6 +1663,7 @@ class UsmmCscDense(gof.Op): ...@@ -1652,6 +1663,7 @@ class UsmmCscDense(gof.Op):
if x.ndim == 1 and y.ndim == 1: if x.ndim == 1 and y.ndim == 1:
return [()] return [()]
raise NotImplementedError() raise NotImplementedError()
def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z): def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z):
alpha = tensor.as_tensor_variable(alpha) alpha = tensor.as_tensor_variable(alpha)
x_val = tensor.as_tensor_variable(x_val) x_val = tensor.as_tensor_variable(x_val)
...@@ -1685,11 +1697,12 @@ class UsmmCscDense(gof.Op): ...@@ -1685,11 +1697,12 @@ class UsmmCscDense(gof.Op):
z = tensor.cast(z, dtype_out) z = tensor.cast(z, dtype_out)
if node.inputs[1].type.dtype in ('complex64', 'complex128'): if node.inputs[1].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for x_val') raise NotImplementedError('Complex types are not supported '
'for x_val')
if node.inputs[5].type.dtype in ('complex64', 'complex128'): if node.inputs[5].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for y') raise NotImplementedError('Complex types are not supported for y')
r = gof.Apply(self, [alpha, x_val, x_ind, x_ptr, x_nrows, y, z], r = gof.Apply(self, [alpha, x_val, x_ind, x_ptr, x_nrows, y, z],
[tensor.tensor(dtype_out, (False, y.type.broadcastable[1]))]) [tensor.tensor(dtype_out, (False, y.type.broadcastable[1]))])
return r return r
...@@ -1715,7 +1728,8 @@ class UsmmCscDense(gof.Op): ...@@ -1715,7 +1728,8 @@ class UsmmCscDense(gof.Op):
alpha, x_val, x_ind, x_ptr, x_nrows, y, z = inputs alpha, x_val, x_ind, x_ptr, x_nrows, y, z = inputs
zn = outputs[0] zn = outputs[0]
if node.inputs[1].type.dtype in ('complex64', 'complex128'): if node.inputs[1].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for x_val') raise NotImplementedError('Complex types are not supported for '
'x_val')
if node.inputs[5].type.dtype in ('complex64', 'complex128'): if node.inputs[5].type.dtype in ('complex64', 'complex128'):
raise NotImplementedError('Complex types are not supported for y') raise NotImplementedError('Complex types are not supported for y')
if node.inputs[6].type.dtype != node.outputs[0].type.dtype: if node.inputs[6].type.dtype != node.outputs[0].type.dtype:
...@@ -1727,13 +1741,13 @@ class UsmmCscDense(gof.Op): ...@@ -1727,13 +1741,13 @@ class UsmmCscDense(gof.Op):
else: else:
conv_type = "double" conv_type = "double"
axpy = "daxpy_" axpy = "daxpy_"
# retrieve dtype numbers
typenum_alpha = node.inputs[0].type.dtype_specs()[-1] # retrieve dtype number typenum_alpha = node.inputs[0].type.dtype_specs()[-1]
typenum_x_val = node.inputs[1].type.dtype_specs()[-1] # retrieve dtype number typenum_x_val = node.inputs[1].type.dtype_specs()[-1]
typenum_y = node.inputs[5].type.dtype_specs()[-1] # retrieve dtype number typenum_y = node.inputs[5].type.dtype_specs()[-1]
typenum_z = node.inputs[6].type.dtype_specs()[-1] # retrieve dtype number typenum_z = node.inputs[6].type.dtype_specs()[-1]
typenum_zn = node.outputs[0].type.dtype_specs()[-1] # retrieve dtype number typenum_zn = node.outputs[0].type.dtype_specs()[-1]
inplace = int(self.inplace) inplace = int(self.inplace)
rval = """ rval = """
...@@ -1852,15 +1866,24 @@ class UsmmCscDense(gof.Op): ...@@ -1852,15 +1866,24 @@ class UsmmCscDense(gof.Op):
} }
} }
} }
"""% dict(locals(), **sub) """ % dict(locals(), **sub)
return rval return rval
usmm_csc_dense = UsmmCscDense(inplace=False) usmm_csc_dense = UsmmCscDense(inplace=False)
usmm_csc_dense_inplace = UsmmCscDense(inplace=True) usmm_csc_dense_inplace = UsmmCscDense(inplace=True)
local_usmm = gof.opt.PatternSub((tensor.sub, 'z', (tensor.mul, {'pattern' : 'alpha', 'constraint' : lambda expr: numpy.all(expr.type.broadcastable) },
(_dot, 'x', 'y'))), local_usmm = gof.opt.PatternSub(
(usmm, (tensor.neg, 'alpha'), 'x', 'y', 'z')) (tensor.sub, 'z',
(tensor.mul,
{'pattern': 'alpha',
'constraint': lambda expr: numpy.all(expr.type.broadcastable)},
(_dot, 'x', 'y'))),
(usmm, (tensor.neg, 'alpha'), 'x', 'y', 'z'))
register_specialize(local_usmm, name="local_usmm") register_specialize(local_usmm, name="local_usmm")
...@@ -1876,15 +1899,18 @@ def local_usmm_csx(node): ...@@ -1876,15 +1899,18 @@ def local_usmm_csx(node):
if x.type.format == 'csc': if x.type.format == 'csc':
x_val, x_ind, x_ptr, x_shape = csm_properties(x) x_val, x_ind, x_ptr, x_shape = csm_properties(x)
x_nsparse = x_shape[0] x_nsparse = x_shape[0]
dtype_out = scalar.upcast(alpha.type.dtype, x.type.dtype, y.type.dtype, z.type.dtype) dtype_out = scalar.upcast(alpha.type.dtype, x.type.dtype,
y.type.dtype, z.type.dtype)
# Sparse cast is not implemented. # Sparse cast is not implemented.
if y.type.dtype != dtype_out: if y.type.dtype != dtype_out:
return False return False
return [usmm_csc_dense(alpha, x_val, x_ind, x_ptr, x_nsparse, y, z)] return [usmm_csc_dense(alpha, x_val, x_ind, x_ptr,
x_nsparse, y, z)]
return False return False
register_specialize(local_usmm_csx) register_specialize(local_usmm_csx)
@gof.local_optimizer([usmm_csc_dense]) @gof.local_optimizer([usmm_csc_dense])
def local_usmm_csc_dense_inplace(node): def local_usmm_csc_dense_inplace(node):
if node.op == usmm_csc_dense: if node.op == usmm_csc_dense:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论