提交 ed337d4e authored 作者: Frederic Bastien's avatar Frederic Bastien

For scalar that have nin=-1 cache the builded ufunc in the node.tag

上级 bdcb752a
...@@ -7,6 +7,7 @@ import numpy ...@@ -7,6 +7,7 @@ import numpy
import theano import theano
from theano import gof from theano import gof
from theano.compat import izip from theano.compat import izip
from theano.compat import get_unbound_function
from six import iteritems from six import iteritems
from six.moves import xrange from six.moves import xrange
from theano.gof import Apply, Op, OpenMPOp from theano.gof import Apply, Op, OpenMPOp
...@@ -506,6 +507,9 @@ class Elemwise(OpenMPOp): ...@@ -506,6 +507,9 @@ class Elemwise(OpenMPOp):
if nfunc_spec: if nfunc_spec:
self.nfunc = getattr(numpy, nfunc_spec[0]) self.nfunc = getattr(numpy, nfunc_spec[0])
elif scalar_op.nin > 0 and scalar_op.nin < 32: elif scalar_op.nin > 0 and scalar_op.nin < 32:
# NumPy ufunc support only up to 31 inputs.
# But our c code support more.
# when nin == -1, we will build the ufunc in the make_thunk.
self.ufunc = numpy.frompyfunc(scalar_op.impl, scalar_op.nin, self.ufunc = numpy.frompyfunc(scalar_op.impl, scalar_op.nin,
scalar_op.nout) scalar_op.nout)
...@@ -792,6 +796,20 @@ class Elemwise(OpenMPOp): ...@@ -792,6 +796,20 @@ class Elemwise(OpenMPOp):
return ret return ret
def make_thunk(self, node, storage_map, compute_map, no_recycling):
node_ = node
if self.ufunc is None and self.scalar_op.nin == -1:
node_ = copy(node)
assert node.op is node_.op
ufunc = numpy.frompyfunc(self.scalar_op.impl,
len(node.inputs),
self.scalar_op.nout)
node_.op.ufunc = ufunc
return super(Elemwise, node_.op).make_thunk(node_, storage_map,
compute_map, no_recycling)
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
if len(node.inputs) >= 32: if len(node.inputs) >= 32:
# Some versions of NumPy will segfault, other will raise a # Some versions of NumPy will segfault, other will raise a
...@@ -859,9 +877,11 @@ class Elemwise(OpenMPOp): ...@@ -859,9 +877,11 @@ class Elemwise(OpenMPOp):
else: else:
# the second calling form is used because in certain versions of # the second calling form is used because in certain versions of
# numpy the first (faster) version leads to segfaults # numpy the first (faster) version leads to segfaults
ufunc = (self.ufunc or if self.ufunc:
numpy.frompyfunc(self.scalar_op.impl, len(inputs), ufunc = self.ufunc
self.scalar_op.nout)) else:
ufunc = node.tag.ufunc
nout = ufunc.nout nout = ufunc.nout
variables = ufunc(*ufunc_args, **ufunc_kwargs) variables = ufunc(*ufunc_args, **ufunc_kwargs)
...@@ -1234,6 +1254,9 @@ class Elemwise(OpenMPOp): ...@@ -1234,6 +1254,9 @@ class Elemwise(OpenMPOp):
""" """
return node.outputs[0].ndim == 0 return node.outputs[0].ndim == 0
theano.compile.debugmode.default_make_thunk.append(
get_unbound_function(Elemwise.make_thunk))
# def elemwise_to_scal(fgraph): # def elemwise_to_scal(fgraph):
# TODO: why is this commented out? should it be removed? # TODO: why is this commented out? should it be removed?
# it has needed maintenance despite being commented # it has needed maintenance despite being commented
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论