提交 a8be57e6 authored 作者: Frederic's avatar Frederic

Add the optimization tag cxx_only.

This tag tell that an optimization will insert Op with c code only. We don't want to run that optimization then there is no c compiler.
上级 7c8e8312
...@@ -19,7 +19,11 @@ In this section we will define a couple optimizations on doubles. ...@@ -19,7 +19,11 @@ In this section we will define a couple optimizations on doubles.
Later, the rest is more useful for when that decorator syntax type thing Later, the rest is more useful for when that decorator syntax type thing
doesn't work. (There are optimizations that don't fit that model). doesn't work. (There are optimizations that don't fit that model).
.. note::
There is the optimization tag `cxx_only` that tell this
optimization will insert Op that only have c code. So we should not
run them when we don't have a c++ compiler.
Global and local optimizations Global and local optimizations
============================== ==============================
......
...@@ -83,10 +83,13 @@ def register_linker(name, linker): ...@@ -83,10 +83,13 @@ def register_linker(name, linker):
# If a string is passed as the optimizer argument in the constructor # If a string is passed as the optimizer argument in the constructor
# for Mode, it will be used as the key to retrieve the real optimizer # for Mode, it will be used as the key to retrieve the real optimizer
# in this dictionary # in this dictionary
OPT_FAST_RUN = gof.Query(include=['fast_run']) exclude=[]
if not theano.config.cxx:
exclude = ['cxx_only']
OPT_FAST_RUN = gof.Query(include=['fast_run'], exclude=exclude)
OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring('stable') OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring('stable')
OPT_FAST_COMPILE = gof.Query(include=['fast_compile']) OPT_FAST_COMPILE = gof.Query(include=['fast_compile'], exclude=exclude)
OPT_STABILIZE = gof.Query(include=['fast_run']) OPT_STABILIZE = gof.Query(include=['fast_run'], exclude=exclude)
OPT_STABILIZE.position_cutoff = 1.5000001 OPT_STABILIZE.position_cutoff = 1.5000001
OPT_FAST_RUN.name = 'OPT_FAST_RUN' OPT_FAST_RUN.name = 'OPT_FAST_RUN'
OPT_FAST_RUN_STABLE.name = 'OPT_FAST_RUN_STABLE' OPT_FAST_RUN_STABLE.name = 'OPT_FAST_RUN_STABLE'
......
...@@ -686,7 +686,7 @@ register_specialize(local_usmm, name="local_usmm") ...@@ -686,7 +686,7 @@ register_specialize(local_usmm, name="local_usmm")
def local_usmm_csc_dense_inplace(node): def local_usmm_csc_dense_inplace(node):
if node.op == usmm_csc_dense: if node.op == usmm_csc_dense:
return [usmm_csc_dense_inplace(*node.inputs)] return [usmm_csc_dense_inplace(*node.inputs)]
register_specialize(local_usmm_csc_dense_inplace, 'inplace') register_specialize(local_usmm_csc_dense_inplace, 'cxx_only', 'inplace')
# This is tested in tests/test_basic.py:UsmmTests # This is tested in tests/test_basic.py:UsmmTests
...@@ -714,7 +714,7 @@ def local_usmm_csx(node): ...@@ -714,7 +714,7 @@ def local_usmm_csx(node):
return [usmm_csc_dense(alpha, x_val, x_ind, x_ptr, return [usmm_csc_dense(alpha, x_val, x_ind, x_ptr,
x_nsparse, y, z)] x_nsparse, y, z)]
return False return False
sparse.register_specialize(local_usmm_csx) sparse.register_specialize(local_usmm_csx, 'cxx_only')
class CSMGradC(gof.Op): class CSMGradC(gof.Op):
...@@ -850,7 +850,7 @@ def local_csm_grad_c(node): ...@@ -850,7 +850,7 @@ def local_csm_grad_c(node):
if node.op == csm_grad(None): if node.op == csm_grad(None):
return [csm_grad_c(*node.inputs)] return [csm_grad_c(*node.inputs)]
return False return False
register_specialize(local_csm_grad_c) register_specialize(local_csm_grad_c, 'cxx_only')
class MulSDCSC(gof.Op): class MulSDCSC(gof.Op):
...@@ -1117,7 +1117,7 @@ def local_mul_s_d(node): ...@@ -1117,7 +1117,7 @@ def local_mul_s_d(node):
sparse.csm_shape(svar))] sparse.csm_shape(svar))]
return False return False
sparse.register_specialize(local_mul_s_d) sparse.register_specialize(local_mul_s_d, 'cxx_only')
class MulSVCSR(gof.Op): class MulSVCSR(gof.Op):
...@@ -1259,7 +1259,7 @@ def local_mul_s_v(node): ...@@ -1259,7 +1259,7 @@ def local_mul_s_v(node):
return [CSx(c_data, s_ind, s_ptr, s_shape)] return [CSx(c_data, s_ind, s_ptr, s_shape)]
return False return False
sparse.register_specialize(local_mul_s_v) sparse.register_specialize(local_mul_s_v, 'cxx_only')
class StructuredAddSVCSR(gof.Op): class StructuredAddSVCSR(gof.Op):
...@@ -1416,7 +1416,7 @@ def local_structured_add_s_v(node): ...@@ -1416,7 +1416,7 @@ def local_structured_add_s_v(node):
return [CSx(c_data, s_ind, s_ptr, s_shape)] return [CSx(c_data, s_ind, s_ptr, s_shape)]
return False return False
sparse.register_specialize(local_structured_add_s_v) sparse.register_specialize(local_structured_add_s_v, 'cxx_only')
class SamplingDotCSR(gof.Op): class SamplingDotCSR(gof.Op):
...@@ -1656,4 +1656,5 @@ def local_sampling_dot_csr(node): ...@@ -1656,4 +1656,5 @@ def local_sampling_dot_csr(node):
return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)] return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)]
return False return False
sparse.register_specialize(local_sampling_dot_csr, sparse.register_specialize(local_sampling_dot_csr,
'cxx_only',
name='local_sampling_dot_csr') name='local_sampling_dot_csr')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论