提交 8a40497a authored 作者: Olivier Breuleux's avatar Olivier Breuleux

some cleanup

上级 f95435ea
......@@ -2004,7 +2004,7 @@ def round(a, mode="half_away_from_zero"):
else:
raise Exception("round mode %s is not implemented."%mode)
@_scal_elemwise_with_nfunc('around', 1, 0)
@_scal_elemwise_with_nfunc('around', 1, -1)
def round_half_to_even(a):
"""round_half_to_even(a)"""
......@@ -2052,15 +2052,15 @@ def erf(a):
def erfc(a):
"""complementary error function"""
@_scal_elemwise_with_nfunc('real', 1, 0)
@_scal_elemwise_with_nfunc('real', 1, -1)
def real(z):
"""Return real component of complex-valued tensor `z`"""
@_scal_elemwise_with_nfunc('imag', 1, 0)
@_scal_elemwise_with_nfunc('imag', 1, -1)
def imag(z):
"""Return imaginary component of complex-valued tensor `z`"""
@_scal_elemwise_with_nfunc('angle', 1, 0)
@_scal_elemwise_with_nfunc('angle', 1, -1)
def angle(z):
"""Return polar-coordinate angle of complex-valued tensor `z`"""
......
......@@ -361,10 +361,6 @@ class DimShufflePrinter:
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, DimShuffle), DimShufflePrinter())
def _make_nfunc(name, nin, nout):
f = getattr(numpy, name)
return f
################
### Elemwise ###
......@@ -408,11 +404,13 @@ class Elemwise(Op):
the input's storage. (Just like destroymap, but without the lists.)
* nfunc_spec: either None or a tuple of three elements, (nfunc_name, nin, nout) such
that getattr(numpy, nfunc_name) implements this operation, takes nin
inputs and nout **destination** outputs (nout == 0 if the numpy function
inputs and abs(nout) outputs (nout < 0 if the numpy function
does not provide the option of providing a numpy array to store the
results in). Note that nin cannot always be inferred from the scalar op's
own nin field because that value is sometimes 0 (meaning a variable number
of inputs), whereas the numpy function may not have varargs.
of inputs), whereas the numpy function may not have varargs. NOTE: as of
now, the sign of the nout field is ignored (some work needs to be done
to resize the destinations when needed).
"""
self.name = name
self.scalar_op = scalar_op
......@@ -423,7 +421,7 @@ class Elemwise(Op):
self.nfunc = None
self.nfunc_spec = nfunc_spec
if nfunc_spec:
self.nfunc = _make_nfunc(*nfunc_spec)
self.nfunc = getattr(numpy, nfunc_spec[0])
elif scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(scalar_op.impl, scalar_op.nin, scalar_op.nout)
......@@ -443,7 +441,7 @@ class Elemwise(Op):
self.ufunc = None
self.nfunc = None
if getattr(self, 'nfunc_spec', None):
self.nfunc = _make_nfunc(*self.nfunc_spec)
self.nfunc = getattr(numpy, self.nfunc_spec[0])
elif self.scalar_op.nin > 0:
self.ufunc = numpy.frompyfunc(self.scalar_op.impl, self.scalar_op.nin, self.scalar_op.nout)
self._rehash()
......@@ -645,8 +643,8 @@ class Elemwise(Op):
if self.nfunc and len(inputs) == self.nfunc_spec[1]:
ufunc = self.nfunc
nout = self.nfunc_spec[2]
if nout == 0:
nout = 1
if nout < 0:
nout = -nout
# Unfortunately, the else case does not allow us to
# directly feed the destination arguments to the nfunc
# since it sometimes requires resizing. Doing this
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论