提交 a25bde02 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 of theano/sparse/opt.py

上级 b8fad4ca
...@@ -12,6 +12,7 @@ from theano.sparse import (CSC, CSR, csm_properties, ...@@ -12,6 +12,7 @@ from theano.sparse import (CSC, CSR, csm_properties,
from theano.sparse import basic as sparse from theano.sparse import basic as sparse
_is_sparse_variable = sparse._is_sparse_variable _is_sparse_variable = sparse._is_sparse_variable
_is_dense = sparse._is_dense
# This is tested in tests/test_opt.py:test_local_csm_properties_csm # This is tested in tests/test_opt.py:test_local_csm_properties_csm
...@@ -47,10 +48,11 @@ def local_inplace_remove0(node): ...@@ -47,10 +48,11 @@ def local_inplace_remove0(node):
return [new_node] return [new_node]
return False return False
theano.compile.optdb.register('local_inplace_remove0', theano.compile.optdb.register(
gof.TopoOptimizer(local_inplace_remove0, 'local_inplace_remove0',
failure_callback=gof.TopoOptimizer.warn_inplace), gof.TopoOptimizer(local_inplace_remove0,
60, 'fast_run', 'inplace') failure_callback=gof.TopoOptimizer.warn_inplace),
60, 'fast_run', 'inplace')
class AddSD_ccode(gof.op.Op): class AddSD_ccode(gof.op.Op):
...@@ -174,10 +176,11 @@ def local_inplace_addsd_ccode(node): ...@@ -174,10 +176,11 @@ def local_inplace_addsd_ccode(node):
inplace=True)(*node.inputs) inplace=True)(*node.inputs)
return [new_node] return [new_node]
return False return False
theano.compile.optdb.register('local_inplace_addsd_ccode', theano.compile.optdb.register(
gof.TopoOptimizer(local_inplace_addsd_ccode, 'local_inplace_addsd_ccode',
failure_callback=gof.TopoOptimizer.warn_inplace), gof.TopoOptimizer(local_inplace_addsd_ccode,
60, 'fast_run', 'inplace') failure_callback=gof.TopoOptimizer.warn_inplace),
60, 'fast_run', 'inplace')
@register_canonicalize("fast_compile") @register_canonicalize("fast_compile")
...@@ -234,16 +237,17 @@ class StructuredDotCSC(gof.Op): ...@@ -234,16 +237,17 @@ class StructuredDotCSC(gof.Op):
def make_node(self, a_val, a_ind, a_ptr, a_nrows, b): def make_node(self, a_val, a_ind, a_ptr, a_nrows, b):
dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype) dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype)
r = gof.Apply(self, [a_val, a_ind, a_ptr, a_nrows, b], r = gof.Apply(self, [a_val, a_ind, a_ptr, a_nrows, b],
[tensor.tensor(dtype_out, (False, b.type.broadcastable[1]))]) [tensor.tensor(dtype_out,
(False, b.type.broadcastable[1]))])
return r return r
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(a_val, a_ind, a_ptr, a_nrows, b) = inputs (a_val, a_ind, a_ptr, a_nrows, b) = inputs
(out,) = outputs (out,) = outputs
a = scipy.sparse.csc_matrix((a_val, a_ind, a_ptr), a = scipy.sparse.csc_matrix((a_val, a_ind, a_ptr),
(a_nrows, b.shape[0]), (a_nrows, b.shape[0]),
copy=False) copy=False)
#out[0] = a.dot(b) # out[0] = a.dot(b)
out[0] = theano._asarray(a * b, dtype=node.outputs[0].type.dtype) out[0] = theano._asarray(a * b, dtype=node.outputs[0].type.dtype)
assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense
...@@ -427,17 +431,18 @@ class StructuredDotCSR(gof.Op): ...@@ -427,17 +431,18 @@ class StructuredDotCSR(gof.Op):
def make_node(self, a_val, a_ind, a_ptr, b): def make_node(self, a_val, a_ind, a_ptr, b):
self.dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype) self.dtype_out = scalar.upcast(a_val.type.dtype, b.type.dtype)
r = gof.Apply(self, [a_val, a_ind, a_ptr, b], r = gof.Apply(self, [a_val, a_ind, a_ptr, b],
[tensor.tensor(self.dtype_out, (False, [tensor.tensor(self.dtype_out,
b.type.broadcastable[1]))]) (False, b.type.broadcastable[1]))])
return r return r
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(a_val, a_ind, a_ptr, b) = inputs (a_val, a_ind, a_ptr, b) = inputs
(out,) = outputs (out,) = outputs
a = scipy.sparse.csr_matrix((a_val, a_ind, a_ptr), a = scipy.sparse.csr_matrix(
(len(a_ptr) - 1, b.shape[0]), (a_val, a_ind, a_ptr),
copy=True) # use view_map before setting this to False (len(a_ptr) - 1, b.shape[0]),
#out[0] = a.dot(b) copy=True) # use view_map before setting this to False
# out[0] = a.dot(b)
out[0] = a * b out[0] = a * b
# scipy 0.7 automatically converts to dense, but not .6 sometimes # scipy 0.7 automatically converts to dense, but not .6 sometimes
assert _is_dense(out[0]) assert _is_dense(out[0])
...@@ -634,7 +639,7 @@ class UsmmCscDense(gof.Op): ...@@ -634,7 +639,7 @@ class UsmmCscDense(gof.Op):
assert z.ndim == 2 assert z.ndim == 2
dtype_out = scalar.upcast(alpha.type.dtype, x_val.type.dtype, dtype_out = scalar.upcast(alpha.type.dtype, x_val.type.dtype,
y.type.dtype, z.type.dtype) y.type.dtype, z.type.dtype)
if dtype_out not in ('float32', 'float64'): if dtype_out not in ('float32', 'float64'):
raise NotImplementedError('only float types are supported in ' raise NotImplementedError('only float types are supported in '
...@@ -653,8 +658,9 @@ class UsmmCscDense(gof.Op): ...@@ -653,8 +658,9 @@ class UsmmCscDense(gof.Op):
if dtype_out != z.type.dtype: if dtype_out != z.type.dtype:
z = tensor.cast(z, dtype_out) z = tensor.cast(z, dtype_out)
r = gof.Apply(self, [alpha, x_val, x_ind, x_ptr, x_nrows, y, z], r = gof.Apply(
[tensor.tensor(dtype_out, (False, y.type.broadcastable[1]))]) self, [alpha, x_val, x_ind, x_ptr, x_nrows, y, z],
[tensor.tensor(dtype_out, (False, y.type.broadcastable[1]))])
return r return r
def c_support_code(self): def c_support_code(self):
...@@ -841,7 +847,7 @@ local_usmm = gof.opt.PatternSub( ...@@ -841,7 +847,7 @@ local_usmm = gof.opt.PatternSub(
{'pattern': 'alpha', {'pattern': 'alpha',
'constraint': lambda expr: (numpy.all(expr.type.broadcastable) and 'constraint': lambda expr: (numpy.all(expr.type.broadcastable) and
theano.config.blas.ldflags)}, theano.config.blas.ldflags)},
(sparse._dot, 'x', 'y'))), (sparse._dot, 'x', 'y'))),
(usmm, (theano.tensor.neg, 'alpha'), 'x', 'y', 'z')) (usmm, (theano.tensor.neg, 'alpha'), 'x', 'y', 'z'))
register_specialize(local_usmm, name="local_usmm") register_specialize(local_usmm, name="local_usmm")
...@@ -896,7 +902,7 @@ class CSMGradC(gof.Op): ...@@ -896,7 +902,7 @@ class CSMGradC(gof.Op):
def make_node(self, a_val, a_ind, a_ptr, a_dim, def make_node(self, a_val, a_ind, a_ptr, a_dim,
b_val, b_ind, b_ptr, b_dim): b_val, b_ind, b_ptr, b_dim):
return gof.Apply(self, [a_val, a_ind, a_ptr, a_dim, return gof.Apply(self, [a_val, a_ind, a_ptr, a_dim,
b_val, b_ind, b_ptr, b_dim], [b_val.type()]) b_val, b_ind, b_ptr, b_dim], [b_val.type()])
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
# retrieve dtype number # retrieve dtype number
...@@ -1019,7 +1025,7 @@ def local_csm_grad_c(node): ...@@ -1019,7 +1025,7 @@ def local_csm_grad_c(node):
return [csm_grad_c(*node.inputs)] return [csm_grad_c(*node.inputs)]
return False return False
# DISABLED AS IT IS BROKEN FOR UNSORTED INDICES! # DISABLED AS IT IS BROKEN FOR UNSORTED INDICES!
#register_specialize(local_csm_grad_c, 'cxx_only') # register_specialize(local_csm_grad_c, 'cxx_only')
class MulSDCSC(gof.Op): class MulSDCSC(gof.Op):
...@@ -1572,7 +1578,7 @@ def local_structured_add_s_v(node): ...@@ -1572,7 +1578,7 @@ def local_structured_add_s_v(node):
x, y = node.inputs x, y = node.inputs
x_is_sparse_variable = _is_sparse_variable(x) x_is_sparse_variable = _is_sparse_variable(x)
#y_is_sparse_variable = _is_sparse_variable(y) # y_is_sparse_variable = _is_sparse_variable(y)
if x_is_sparse_variable: if x_is_sparse_variable:
svar = x svar = x
...@@ -1840,7 +1846,7 @@ def local_sampling_dot_csr(node): ...@@ -1840,7 +1846,7 @@ def local_sampling_dot_csr(node):
p_data, p_ind, p_ptr, p_shape = sparse.csm_properties(p) p_data, p_ind, p_ptr, p_shape = sparse.csm_properties(p)
z_data, z_ind, z_ptr = sampling_dot_csr(x, y, p_data, z_data, z_ind, z_ptr = sampling_dot_csr(x, y, p_data,
p_ind, p_ptr, p_shape[1]) p_ind, p_ptr, p_shape[1])
return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)] return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)]
return False return False
......
...@@ -230,7 +230,6 @@ whitelist_flake8 = [ ...@@ -230,7 +230,6 @@ whitelist_flake8 = [
"misc/hooks/check_whitespace.py", "misc/hooks/check_whitespace.py",
"sparse/type.py", "sparse/type.py",
"sparse/__init__.py", "sparse/__init__.py",
"sparse/opt.py",
"sparse/tests/test_utils.py", "sparse/tests/test_utils.py",
"sparse/tests/test_opt.py", "sparse/tests/test_opt.py",
"sparse/tests/test_basic.py", "sparse/tests/test_basic.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论