提交 7e34c538 authored 作者: Frederic Bastien's avatar Frederic Bastien

Speed up Elemwise.perform

上级 91547f2a
...@@ -811,6 +811,24 @@ class Elemwise(OpenMPOp): ...@@ -811,6 +811,24 @@ class Elemwise(OpenMPOp):
else: else:
node.tag.ufunc = ufunc node.tag.ufunc = ufunc
# Numpy ufuncs will sometimes perform operations in
# float16, in particular when the input is int8.
# This is not something that we want, and we do not
# do it in the C code, so we specify that the computation
# should be carried out in the returned dtype.
# This is done via the "sig" kwarg of the ufunc, its value
# should be something like "ff->f", where the characters
# represent the dtype of the inputs and outputs.
# NumPy 1.10.1 raise an error when giving the signature
# when the input is complex. So add it only when inputs is int.
out_dtype = node.outputs[0].dtype
if (out_dtype in float_dtypes and
isinstance(self.nfunc, numpy.ufunc) and
node.inputs[0].dtype in scalar.int_types):
char = numpy.sctype2char(out_dtype)
sig = char * node.nin + '->' + char * node.nout
node.tag.sig = sig
return super(Elemwise, node_.op).make_thunk(node_, storage_map, return super(Elemwise, node_.op).make_thunk(node_, storage_map,
compute_map, no_recycling) compute_map, no_recycling)
...@@ -860,24 +878,8 @@ class Elemwise(OpenMPOp): ...@@ -860,24 +878,8 @@ class Elemwise(OpenMPOp):
if self.nfunc and len(inputs) == self.nfunc_spec[1]: if self.nfunc and len(inputs) == self.nfunc_spec[1]:
ufunc = self.nfunc ufunc = self.nfunc
nout = self.nfunc_spec[2] nout = self.nfunc_spec[2]
# Numpy ufuncs will sometimes perform operations in if hasattr(node.tag, 'sig'):
# float16, in particular when the input is int8. ufunc_kwargs['sig'] = node.tag.sig
# This is not something that we want, and we do not
# do it in the C code, so we specify that the computation
# should be carried out in the returned dtype.
# This is done via the "sig" kwarg of the ufunc, its value
# should be something like "ff->f", where the characters
# represent the dtype of the inputs and outputs.
# NumPy 1.10.1 raise an error when giving the signature
# when the input is complex. So add it only when inputs is int.
out_dtype = node.outputs[0].dtype
if (out_dtype in float_dtypes and
isinstance(ufunc, numpy.ufunc) and
inputs[0].dtype in scalar.int_types):
char = numpy.sctype2char(out_dtype)
sig = char * node.nin + '->' + char * node.nout
ufunc_kwargs['sig'] = sig
# Unfortunately, the else case does not allow us to # Unfortunately, the else case does not allow us to
# directly feed the destination arguments to the nfunc # directly feed the destination arguments to the nfunc
# since it sometimes requires resizing. Doing this # since it sometimes requires resizing. Doing this
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论