提交 aff96770 authored 作者: Frederic Bastien's avatar Frederic Bastien

Allow loading more nfunc path (like class method) and bug fixes.

上级 26bdc132
......@@ -392,7 +392,7 @@ second dimension
inplace_pattern = frozendict({})
self.name = name
self.scalar_op = scalar_op
self.inplace_pattern = frozendict(inplace_pattern)
self.inplace_pattern = inplace_pattern
self.destroy_map = dict((o, [i]) for o, i in self.inplace_pattern.items())
self.ufunc = None
......@@ -400,16 +400,7 @@ second dimension
if nfunc_spec is None:
nfunc_spec = getattr(scalar_op, 'nfunc_spec', None)
self.nfunc_spec = nfunc_spec
if nfunc_spec:
self.nfunc = getattr(np, nfunc_spec[0], None)
if self.nfunc is None:
# Not inside NumPy. So probably another package like scipy.
symb = 'scipy.special.erfinv'.split(".")
module = __import__('.'.join(symb[:-1]))
for sub in symb[1:-1]:
module = getattr(module, sub)
self.nfunc = getattr(module, symb[-1])
self.__setstate__(self.__dict__)
super(Elemwise, self).__init__(openmp=openmp)
def __getstate__(self):
......@@ -425,7 +416,19 @@ second dimension
self.nfunc = None
self.inplace_pattern = frozendict(self.inplace_pattern)
if getattr(self, 'nfunc_spec', None):
self.nfunc = getattr(np, self.nfunc_spec[0])
self.nfunc = getattr(np, self.nfunc_spec[0], None)
if self.nfunc is None:
# Not inside NumPy. So probably another package like scipy.
symb = self.nfunc_spec[0].split(".")
for idx in range(1, len(self.nfunc_spec[0])):
try:
module = __import__('.'.join(symb[:idx]))
except ImportError:
break
for sub in symb[1:]:
module = getattr(module, sub)
self.nfunc = module
elif 0 < self.scalar_op.nin < 32:
self.ufunc = np.frompyfunc(self.scalar_op.impl,
self.scalar_op.nin,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论