提交 872c63a8 authored 作者: Frederic Bastien's avatar Frederic Bastien

Move to prepare_node, but call it in perform for corner case.

上级 84fdef4f
...@@ -109,6 +109,7 @@ class Erfcx(UnaryScalarOp): ...@@ -109,6 +109,7 @@ class Erfcx(UnaryScalarOp):
""" """
nfunc_spec = ('scipy.special.erfcx', 1, 1) nfunc_spec = ('scipy.special.erfcx', 1, 1)
def impl(self, x): def impl(self, x):
if imported_scipy_special: if imported_scipy_special:
return scipy.special.erfcx(x) return scipy.special.erfcx(x)
......
...@@ -395,8 +395,6 @@ second dimension ...@@ -395,8 +395,6 @@ second dimension
self.inplace_pattern = inplace_pattern self.inplace_pattern = inplace_pattern
self.destroy_map = dict((o, [i]) for o, i in self.inplace_pattern.items()) self.destroy_map = dict((o, [i]) for o, i in self.inplace_pattern.items())
self.ufunc = None
self.nfunc = None
if nfunc_spec is None: if nfunc_spec is None:
nfunc_spec = getattr(scalar_op, 'nfunc_spec', None) nfunc_spec = getattr(scalar_op, 'nfunc_spec', None)
self.nfunc_spec = nfunc_spec self.nfunc_spec = nfunc_spec
...@@ -415,24 +413,6 @@ second dimension ...@@ -415,24 +413,6 @@ second dimension
self.ufunc = None self.ufunc = None
self.nfunc = None self.nfunc = None
self.inplace_pattern = frozendict(self.inplace_pattern) self.inplace_pattern = frozendict(self.inplace_pattern)
if getattr(self, 'nfunc_spec', None):
self.nfunc = getattr(np, self.nfunc_spec[0], None)
if self.nfunc is None:
# Not inside NumPy. So probably another package like scipy.
symb = self.nfunc_spec[0].split(".")
for idx in range(1, len(self.nfunc_spec[0])):
try:
module = __import__('.'.join(symb[:idx]))
except ImportError:
break
for sub in symb[1:]:
module = getattr(module, sub)
self.nfunc = module
elif 0 < self.scalar_op.nin < 32:
self.ufunc = np.frompyfunc(self.scalar_op.impl,
self.scalar_op.nin,
self.scalar_op.nout)
def get_output_info(self, dim_shuffle, *inputs): def get_output_info(self, dim_shuffle, *inputs):
"""Return the outputs dtype and broadcastable pattern and the """Return the outputs dtype and broadcastable pattern and the
...@@ -665,9 +645,24 @@ second dimension ...@@ -665,9 +645,24 @@ second dimension
return ret return ret
def prepare_node(self, node, storage_map, compute_map, impl): def prepare_node(self, node, storage_map, compute_map, impl):
# Postpone the ufunc building to the last minutes # Postpone the ufunc building to the last minutes due to:
# NumPy ufunc support only up to 31 inputs. # - NumPy ufunc support only up to 31 inputs.
# But our c code support more. # But our c code support more.
# - nfunc is reused for scipy and scipy is optional
if getattr(self, 'nfunc_spec', None):
self.nfunc = getattr(np, self.nfunc_spec[0], None)
if self.nfunc is None:
# Not inside NumPy. So probably another package like scipy.
symb = self.nfunc_spec[0].split(".")
for idx in range(1, len(self.nfunc_spec[0])):
try:
module = __import__('.'.join(symb[:idx]))
except ImportError:
break
for sub in symb[1:]:
module = getattr(module, sub)
self.nfunc = module
if (len(node.inputs) < 32 and if (len(node.inputs) < 32 and
(self.nfunc is None or (self.nfunc is None or
self.scalar_op.nin != len(node.inputs)) and self.scalar_op.nin != len(node.inputs)) and
...@@ -753,6 +748,10 @@ second dimension ...@@ -753,6 +748,10 @@ second dimension
ufunc_args = inputs ufunc_args = inputs
ufunc_kwargs = {} ufunc_kwargs = {}
# We supported in the past calling manually op.perform.
# To keep that support we need to sometimes call self.prepare_node
if self.nfunc is None and self.ufunc is None:
self.prepare_node(node, None, None, 'py')
if self.nfunc and len(inputs) == self.nfunc_spec[1]: if self.nfunc and len(inputs) == self.nfunc_spec[1]:
ufunc = self.nfunc ufunc = self.nfunc
nout = self.nfunc_spec[2] nout = self.nfunc_spec[2]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论