提交 2a2353b3 authored 作者: AlOa's avatar AlOa

Add openmp to elemwise contiguous case

上级 85006a60
...@@ -6,7 +6,7 @@ import numpy ...@@ -6,7 +6,7 @@ import numpy
import theano import theano
from theano import gof from theano import gof
from theano.gof import Apply, Op from theano.gof import Apply, Op, OpenMPOp
from theano import scalar from theano import scalar
from theano.scalar import Scalar, get_scalar_type from theano.scalar import Scalar, get_scalar_type
from theano.printing import pprint from theano.printing import pprint
...@@ -419,7 +419,7 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), ...@@ -419,7 +419,7 @@ pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle),
### Elemwise ### ### Elemwise ###
################ ################
class Elemwise(Op): class Elemwise(OpenMPOp):
""" """
Generalizes a scalar op to tensors. Generalizes a scalar op to tensors.
...@@ -449,7 +449,7 @@ class Elemwise(Op): ...@@ -449,7 +449,7 @@ class Elemwise(Op):
""" """
def __init__(self, scalar_op, inplace_pattern=None, name=None, def __init__(self, scalar_op, inplace_pattern=None, name=None,
nfunc_spec=None): nfunc_spec=None,openmp=None):
""" """
Usage: Elemwise(scalar_op, inplace_pattern = {}) Usage: Elemwise(scalar_op, inplace_pattern = {})
...@@ -487,6 +487,7 @@ class Elemwise(Op): ...@@ -487,6 +487,7 @@ class Elemwise(Op):
#precompute the hash of this node #precompute the hash of this node
self._rehash() self._rehash()
super(Elemwise,self).__init__(openmp=openmp)
def __getstate__(self): def __getstate__(self):
d = copy(self.__dict__) d = copy(self.__dict__)
...@@ -1117,7 +1118,8 @@ class Elemwise(Op): ...@@ -1117,7 +1118,8 @@ class Elemwise(Op):
contig += """ contig += """
dtype_%(x)s& %(x)s_i = ((dtype_%(x)s*) PyArray_DATA(%(x)s))[0]; dtype_%(x)s& %(x)s_i = ((dtype_%(x)s*) PyArray_DATA(%(x)s))[0];
""" % locals() """ % locals()
if self.openmp:
contig += """#pragma omp parallel for if(n>=%d)""" % (config.openmp_minsize)
contig += """ contig += """
for(int i=0; i<n; i++){ for(int i=0; i<n; i++){
%(index)s %(index)s
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论