提交 872c63a8 authored 作者: Frederic Bastien's avatar Frederic Bastien

Move to prepare_node, but call it in perform for corner case.

上级 84fdef4f
......@@ -109,6 +109,7 @@ class Erfcx(UnaryScalarOp):
"""
nfunc_spec = ('scipy.special.erfcx', 1, 1)
def impl(self, x):
if imported_scipy_special:
return scipy.special.erfcx(x)
......
......@@ -395,8 +395,6 @@ second dimension
self.inplace_pattern = inplace_pattern
self.destroy_map = dict((o, [i]) for o, i in self.inplace_pattern.items())
self.ufunc = None
self.nfunc = None
if nfunc_spec is None:
nfunc_spec = getattr(scalar_op, 'nfunc_spec', None)
self.nfunc_spec = nfunc_spec
......@@ -415,24 +413,6 @@ second dimension
self.ufunc = None
self.nfunc = None
self.inplace_pattern = frozendict(self.inplace_pattern)
if getattr(self, 'nfunc_spec', None):
self.nfunc = getattr(np, self.nfunc_spec[0], None)
if self.nfunc is None:
# Not inside NumPy. So probably another package like scipy.
symb = self.nfunc_spec[0].split(".")
for idx in range(1, len(self.nfunc_spec[0])):
try:
module = __import__('.'.join(symb[:idx]))
except ImportError:
break
for sub in symb[1:]:
module = getattr(module, sub)
self.nfunc = module
elif 0 < self.scalar_op.nin < 32:
self.ufunc = np.frompyfunc(self.scalar_op.impl,
self.scalar_op.nin,
self.scalar_op.nout)
def get_output_info(self, dim_shuffle, *inputs):
"""Return the outputs dtype and broadcastable pattern and the
......@@ -665,9 +645,24 @@ second dimension
return ret
def prepare_node(self, node, storage_map, compute_map, impl):
# Postpone the ufunc building to the last minutes
# NumPy ufunc support only up to 31 inputs.
# But our c code support more.
# Postpone the ufunc building to the last minutes due to:
# - NumPy ufunc support only up to 31 inputs.
# But our c code support more.
# - nfunc is reused for scipy and scipy is optional
if getattr(self, 'nfunc_spec', None):
self.nfunc = getattr(np, self.nfunc_spec[0], None)
if self.nfunc is None:
# Not inside NumPy. So probably another package like scipy.
symb = self.nfunc_spec[0].split(".")
for idx in range(1, len(self.nfunc_spec[0])):
try:
module = __import__('.'.join(symb[:idx]))
except ImportError:
break
for sub in symb[1:]:
module = getattr(module, sub)
self.nfunc = module
if (len(node.inputs) < 32 and
(self.nfunc is None or
self.scalar_op.nin != len(node.inputs)) and
......@@ -753,6 +748,10 @@ second dimension
ufunc_args = inputs
ufunc_kwargs = {}
# We supported in the past calling manually op.perform.
# To keep that support we need to sometimes call self.prepare_node
if self.nfunc is None and self.ufunc is None:
self.prepare_node(node, None, None, 'py')
if self.nfunc and len(inputs) == self.nfunc_spec[1]:
ufunc = self.nfunc
nout = self.nfunc_spec[2]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论