提交 0520f913 authored 作者: Frederic's avatar Frederic

small code cleanup.

上级 c225b7a2
...@@ -25,12 +25,6 @@ from theano.sparse.type import SparseType, _is_sparse ...@@ -25,12 +25,6 @@ from theano.sparse.type import SparseType, _is_sparse
sparse_formats = ['csc', 'csr'] sparse_formats = ['csc', 'csr']
# TODO: move this decorator to the compile submodule
def register_specialize(lopt, *tags, **kwargs):
compile.optdb['specialize'].register((kwargs and kwargs.pop('name')) or
lopt.__name__, lopt, 'fast_run',
*tags)
""" Types of sparse matrices to use for testing """ """ Types of sparse matrices to use for testing """
_mtypes = [scipy.sparse.csc_matrix, scipy.sparse.csr_matrix] _mtypes = [scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]
#_mtypes = [sparse.csc_matrix, sparse.csr_matrix, sparse.dok_matrix, #_mtypes = [sparse.csc_matrix, sparse.csr_matrix, sparse.dok_matrix,
......
...@@ -6,8 +6,8 @@ import scipy ...@@ -6,8 +6,8 @@ import scipy
import theano import theano
from theano import gof, scalar, tensor from theano import gof, scalar, tensor
from theano.tensor import blas from theano.tensor import blas
from theano.tensor.opt import register_specialize
from theano.sparse import (CSC, CSR, csm_properties, from theano.sparse import (CSC, CSR, csm_properties,
register_specialize,
csm_grad, usmm, csm_indices, csm_indptr, csm_grad, usmm, csm_indices, csm_indptr,
csm_data) csm_data)
from theano.sparse import basic as sparse from theano.sparse import basic as sparse
...@@ -29,7 +29,7 @@ def local_csm_properties_csm(node): ...@@ -29,7 +29,7 @@ def local_csm_properties_csm(node):
return ret_var return ret_var
return False return False
sparse.register_specialize(local_csm_properties_csm) register_specialize(local_csm_properties_csm)
# This is tested in tests/test_basic.py:test_remove0 # This is tested in tests/test_basic.py:test_remove0
...@@ -861,7 +861,7 @@ def local_usmm_csx(node): ...@@ -861,7 +861,7 @@ def local_usmm_csx(node):
return [usmm_csc_dense(alpha, x_val, x_ind, x_ptr, return [usmm_csc_dense(alpha, x_val, x_ind, x_ptr,
x_nsparse, y, z)] x_nsparse, y, z)]
return False return False
sparse.register_specialize(local_usmm_csx, 'cxx_only') register_specialize(local_usmm_csx, 'cxx_only')
class CSMGradC(gof.Op): class CSMGradC(gof.Op):
...@@ -1272,7 +1272,7 @@ def local_mul_s_d(node): ...@@ -1272,7 +1272,7 @@ def local_mul_s_d(node):
sparse.csm_shape(svar))] sparse.csm_shape(svar))]
return False return False
sparse.register_specialize(local_mul_s_d, 'cxx_only') register_specialize(local_mul_s_d, 'cxx_only')
class MulSVCSR(gof.Op): class MulSVCSR(gof.Op):
...@@ -1414,7 +1414,7 @@ def local_mul_s_v(node): ...@@ -1414,7 +1414,7 @@ def local_mul_s_v(node):
return [CSx(c_data, s_ind, s_ptr, s_shape)] return [CSx(c_data, s_ind, s_ptr, s_shape)]
return False return False
sparse.register_specialize(local_mul_s_v, 'cxx_only') register_specialize(local_mul_s_v, 'cxx_only')
class StructuredAddSVCSR(gof.Op): class StructuredAddSVCSR(gof.Op):
...@@ -1573,7 +1573,7 @@ def local_structured_add_s_v(node): ...@@ -1573,7 +1573,7 @@ def local_structured_add_s_v(node):
return [CSx(c_data, s_ind, s_ptr, s_shape)] return [CSx(c_data, s_ind, s_ptr, s_shape)]
return False return False
sparse.register_specialize(local_structured_add_s_v, 'cxx_only') register_specialize(local_structured_add_s_v, 'cxx_only')
class SamplingDotCSR(gof.Op): class SamplingDotCSR(gof.Op):
...@@ -1822,6 +1822,6 @@ def local_sampling_dot_csr(node): ...@@ -1822,6 +1822,6 @@ def local_sampling_dot_csr(node):
return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)] return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)]
return False return False
sparse.register_specialize(local_sampling_dot_csr, register_specialize(local_sampling_dot_csr,
'cxx_only', 'cxx_only',
name='local_sampling_dot_csr') name='local_sampling_dot_csr')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论